From b24312838a4becbc6655ccef07a0829b8664b95f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Thu, 24 Jun 2021 16:04:54 +0200
Subject: [PATCH] Improve PETSc interface

---
 src/algebra/CMakeLists.txt              |  2 +-
 src/algebra/LinearSolver.cpp            | 37 +++----------
 src/algebra/PETScUtils.cpp              | 74 +++++++++++++++++++++++++
 src/algebra/PETScUtils.hpp              | 65 ++++++++++++++++++++++
 src/utils/CMakeLists.txt                |  1 +
 src/{algebra => utils}/PETScWrapper.cpp |  2 +-
 src/{algebra => utils}/PETScWrapper.hpp |  0
 src/utils/PugsUtils.cpp                 |  2 +-
 tests/mpi_test_main.cpp                 |  2 +-
 tests/test_main.cpp                     |  2 +-
 10 files changed, 152 insertions(+), 35 deletions(-)
 create mode 100644 src/algebra/PETScUtils.cpp
 create mode 100644 src/algebra/PETScUtils.hpp
 rename src/{algebra => utils}/PETScWrapper.cpp (92%)
 rename src/{algebra => utils}/PETScWrapper.hpp (100%)

diff --git a/src/algebra/CMakeLists.txt b/src/algebra/CMakeLists.txt
index d1f01ec31..5fd848bca 100644
--- a/src/algebra/CMakeLists.txt
+++ b/src/algebra/CMakeLists.txt
@@ -4,4 +4,4 @@ add_library(
   PugsAlgebra
   LinearSolver.cpp
   LinearSolverOptions.cpp
-  PETScWrapper.cpp)
+  PETScUtils.cpp)
diff --git a/src/algebra/LinearSolver.cpp b/src/algebra/LinearSolver.cpp
index b3a4d9816..a7c6188c5 100644
--- a/src/algebra/LinearSolver.cpp
+++ b/src/algebra/LinearSolver.cpp
@@ -3,6 +3,7 @@
 
 #include <algebra/BiCGStab.hpp>
 #include <algebra/CG.hpp>
+#include <algebra/PETScUtils.hpp>
 
 #ifdef PUGS_HAS_PETSC
 #include <petsc.h>
@@ -172,42 +173,20 @@ struct LinearSolver::Internals
                         const Vector<double>& b,
                         const LinearSolverOptions& options)
   {
-    Assert(x.size() == b.size() and x.size() == A.numberOfRows());
+    Assert(x.size() == b.size() and x.size() == A.numberOfColumns() and A.isSquare());
 
     Vec petscB;
-    VecCreateMPIWithArray(PETSC_COMM_WORLD, 1, b.size(), b.size(), &b[0], &petscB);
+    VecCreateMPIWithArray(PETSC_COMM_SELF, 1, b.size(), b.size(), &b[0], &petscB);
     Vec petscX;
-    VecCreateMPIWithArray(PETSC_COMM_WORLD, 1, x.size(), x.size(), &x[0], &petscX);
+    VecCreateMPIWithArray(PETSC_COMM_SELF, 1, x.size(), x.size(), &x[0], &petscX);
 
-    Array<PetscScalar> values = copy(A.values());
-
-    const auto A_row_indices = A.rowIndices();
-    Array<PetscInt> row_indices{A_row_indices.size()};
-    for (size_t i = 0; i < row_indices.size(); ++i) {
-      row_indices[i] = A_row_indices[i];
-    }
-
-    Array<PetscInt> column_indices{values.size()};
-    size_t l = 0;
-    for (size_t i = 0; i < A.numberOfRows(); ++i) {
-      const auto row_i = A.row(i);
-      for (size_t j = 0; j < row_i.length; ++j) {
-        column_indices[l++] = row_i.colidx(j);
-      }
-    }
-
-    Mat petscMat;
-    MatCreateSeqAIJWithArrays(PETSC_COMM_WORLD, x.size(), x.size(), &row_indices[0], &column_indices[0], &values[0],
-                              &petscMat);
-
-    MatAssemblyBegin(petscMat, MAT_FINAL_ASSEMBLY);
-    MatAssemblyEnd(petscMat, MAT_FINAL_ASSEMBLY);
+    PETScAijMatrixEmbedder petscA(A);
 
     KSP ksp;
-    KSPCreate(PETSC_COMM_WORLD, &ksp);
+    KSPCreate(PETSC_COMM_SELF, &ksp);
     KSPSetTolerances(ksp, options.epsilon(), 1E-100, 1E5, options.maximumIteration());
 
-    KSPSetOperators(ksp, petscMat, petscMat);
+    KSPSetOperators(ksp, petscA, petscA);
 
     PC pc;
     KSPGetPC(ksp, &pc);
@@ -288,8 +267,6 @@ struct LinearSolver::Internals
 
     KSPSolve(ksp, petscB, petscX);
 
-    // free used memory
-    MatDestroy(&petscMat);
     VecDestroy(&petscB);
     VecDestroy(&petscX);
     KSPDestroy(&ksp);
diff --git a/src/algebra/PETScUtils.cpp b/src/algebra/PETScUtils.cpp
new file mode 100644
index 000000000..5ed261353
--- /dev/null
+++ b/src/algebra/PETScUtils.cpp
@@ -0,0 +1,74 @@
+#include <algebra/PETScUtils.hpp>
+
+#ifdef PUGS_HAS_PETSC
+
+PETScAijMatrixEmbedder::PETScAijMatrixEmbedder(const size_t nb_rows, const size_t nb_columns, const double* A)
+  : m_nb_rows{nb_rows}, m_nb_columns{nb_columns}
+{
+  MatCreate(PETSC_COMM_SELF, &m_petscMat);
+  MatSetSizes(m_petscMat, PETSC_DECIDE, PETSC_DECIDE, nb_rows, nb_columns);
+  MatSetFromOptions(m_petscMat);
+  MatSetType(m_petscMat, MATAIJ);
+  MatSetUp(m_petscMat);
+
+  {
+    Array<PetscInt> row_indices(nb_rows);
+    for (size_t i = 0; i < nb_rows; ++i) {
+      row_indices[i] = i;
+    }
+    m_row_indices = row_indices;
+  }
+
+  if (nb_rows == nb_columns) {
+    m_column_indices = m_row_indices;
+  } else {
+    Array<PetscInt> column_indices(nb_columns);
+    for (size_t i = 0; i < nb_columns; ++i) {
+      column_indices[i] = i;
+    }
+    m_column_indices = column_indices;
+  }
+
+  MatSetValuesBlocked(m_petscMat, nb_rows, &(m_row_indices[0]), nb_columns, &(m_column_indices[0]), A, INSERT_VALUES);
+
+  MatAssemblyBegin(m_petscMat, MAT_FINAL_ASSEMBLY);
+  MatAssemblyEnd(m_petscMat, MAT_FINAL_ASSEMBLY);
+}
+
+PETScAijMatrixEmbedder::PETScAijMatrixEmbedder(const CRSMatrix<double, size_t>& A)
+  : m_nb_rows{A.numberOfRows()}, m_nb_columns{A.numberOfColumns()}
+{
+  const Array<PetscReal>& values = copy(A.values());
+
+  {
+    const auto& A_row_indices = A.rowIndices();
+    Array<PetscInt> row_indices{A_row_indices.size()};
+    for (size_t i = 0; i < row_indices.size(); ++i) {
+      row_indices[i] = A_row_indices[i];
+    }
+    m_row_indices = row_indices;
+
+    Array<PetscInt> column_indices{values.size()};
+    size_t l = 0;
+    for (size_t i = 0; i < A.numberOfRows(); ++i) {
+      const auto row_i = A.row(i);
+      for (size_t j = 0; j < row_i.length; ++j) {
+        column_indices[l++] = row_i.colidx(j);
+      }
+    }
+    m_column_indices = column_indices;
+  }
+
+  MatCreateSeqAIJWithArrays(PETSC_COMM_SELF, A.numberOfRows(), A.numberOfColumns(), &m_row_indices[0],
+                            &m_column_indices[0], &values[0], &m_petscMat);
+
+  MatAssemblyBegin(m_petscMat, MAT_FINAL_ASSEMBLY);
+  MatAssemblyEnd(m_petscMat, MAT_FINAL_ASSEMBLY);
+}
+
+PETScAijMatrixEmbedder::~PETScAijMatrixEmbedder()
+{
+  MatDestroy(&m_petscMat);
+}
+
+#endif   // PUGS_HAS_PETSC
diff --git a/src/algebra/PETScUtils.hpp b/src/algebra/PETScUtils.hpp
new file mode 100644
index 000000000..cdb092545
--- /dev/null
+++ b/src/algebra/PETScUtils.hpp
@@ -0,0 +1,65 @@
+#ifndef PETSC_UTILS_HPP
+#define PETSC_UTILS_HPP
+
+#include <utils/pugs_config.hpp>
+
+#ifdef PUGS_HAS_PETSC
+
+#include <algebra/CRSMatrix.hpp>
+#include <algebra/DenseMatrix.hpp>
+#include <algebra/TinyMatrix.hpp>
+
+#include <petsc.h>
+
+class PETScAijMatrixEmbedder
+{
+ private:
+  Mat m_petscMat;
+  Array<PetscInt> m_row_indices;
+  Array<PetscInt> m_column_indices;
+  const size_t m_nb_rows;
+  const size_t m_nb_columns;
+
+  PETScAijMatrixEmbedder(const size_t nb_rows, const size_t nb_columns, const double* A);
+
+ public:
+  PUGS_INLINE
+  size_t
+  numberOfRows() const
+  {
+    return m_nb_rows;
+  }
+
+  PUGS_INLINE
+  size_t
+  numberOfColumns() const
+  {
+    return m_nb_columns;
+  }
+
+  PUGS_INLINE
+  operator Mat&()
+  {
+    return m_petscMat;
+  }
+
+  PUGS_INLINE
+  operator const Mat&() const
+  {
+    return m_petscMat;
+  }
+
+  template <size_t N>
+  PETScAijMatrixEmbedder(const TinyMatrix<N>& A) : PETScAijMatrixEmbedder{N, N, &A(0, 0)}
+  {}
+
+  PETScAijMatrixEmbedder(const DenseMatrix<double>& A) : PETScAijMatrixEmbedder{A.nbRows(), A.nbColumns(), &A(0, 0)} {}
+
+  PETScAijMatrixEmbedder(const CRSMatrix<double, size_t>& A);
+
+  ~PETScAijMatrixEmbedder();
+};
+
+#endif   // PUGS_HAS_PETSC
+
+#endif   // PETSC_UTILS_HPP
diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt
index fe9b5e6fe..3e2e1c15b 100644
--- a/src/utils/CMakeLists.txt
+++ b/src/utils/CMakeLists.txt
@@ -10,6 +10,7 @@ add_library(
   FPEManager.cpp
   Messenger.cpp
   Partitioner.cpp
+  PETScWrapper.cpp
   PugsUtils.cpp
   RevisionInfo.cpp
   SignalManager.cpp)
diff --git a/src/algebra/PETScWrapper.cpp b/src/utils/PETScWrapper.cpp
similarity index 92%
rename from src/algebra/PETScWrapper.cpp
rename to src/utils/PETScWrapper.cpp
index dd11dc977..f88bbfd8a 100644
--- a/src/algebra/PETScWrapper.cpp
+++ b/src/utils/PETScWrapper.cpp
@@ -1,4 +1,4 @@
-#include <algebra/PETScWrapper.hpp>
+#include <utils/PETScWrapper.hpp>
 
 #include <utils/pugs_config.hpp>
 
diff --git a/src/algebra/PETScWrapper.hpp b/src/utils/PETScWrapper.hpp
similarity index 100%
rename from src/algebra/PETScWrapper.hpp
rename to src/utils/PETScWrapper.hpp
diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp
index 4c16444b3..fbf2b04e3 100644
--- a/src/utils/PugsUtils.cpp
+++ b/src/utils/PugsUtils.cpp
@@ -1,10 +1,10 @@
 #include <utils/PugsUtils.hpp>
 
-#include <algebra/PETScWrapper.hpp>
 #include <utils/BuildInfo.hpp>
 #include <utils/ConsoleManager.hpp>
 #include <utils/FPEManager.hpp>
 #include <utils/Messenger.hpp>
+#include <utils/PETScWrapper.hpp>
 #include <utils/RevisionInfo.hpp>
 #include <utils/SignalManager.hpp>
 #include <utils/pugs_build_info.hpp>
diff --git a/tests/mpi_test_main.cpp b/tests/mpi_test_main.cpp
index 68736ebe3..51a01b0cf 100644
--- a/tests/mpi_test_main.cpp
+++ b/tests/mpi_test_main.cpp
@@ -2,13 +2,13 @@
 
 #include <Kokkos_Core.hpp>
 
-#include <algebra/PETScWrapper.hpp>
 #include <language/utils/OperatorRepository.hpp>
 #include <mesh/DiamondDualConnectivityManager.hpp>
 #include <mesh/DiamondDualMeshManager.hpp>
 #include <mesh/MeshDataManager.hpp>
 #include <mesh/SynchronizerManager.hpp>
 #include <utils/Messenger.hpp>
+#include <utils/PETScWrapper.hpp>
 #include <utils/pugs_config.hpp>
 
 #include <MeshDataBaseForTests.hpp>
diff --git a/tests/test_main.cpp b/tests/test_main.cpp
index ee1d0769b..89e55c3d7 100644
--- a/tests/test_main.cpp
+++ b/tests/test_main.cpp
@@ -2,13 +2,13 @@
 
 #include <Kokkos_Core.hpp>
 
-#include <algebra/PETScWrapper.hpp>
 #include <language/utils/OperatorRepository.hpp>
 #include <mesh/DiamondDualConnectivityManager.hpp>
 #include <mesh/DiamondDualMeshManager.hpp>
 #include <mesh/MeshDataManager.hpp>
 #include <mesh/SynchronizerManager.hpp>
 #include <utils/Messenger.hpp>
+#include <utils/PETScWrapper.hpp>
 
 #include <MeshDataBaseForTests.hpp>
 
-- 
GitLab