From 90ac3e2cf65b29f7959a4bcb0d828526625d0256 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 2 Aug 2019 17:15:31 +0200
Subject: [PATCH] Add tests for CRSMatrix and little clean-up

---
 src/algebra/CRSMatrix.hpp |  23 ++-----
 tests/CMakeLists.txt      |   1 +
 tests/test_CRSMatrix.cpp  | 132 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 140 insertions(+), 16 deletions(-)
 create mode 100644 tests/test_CRSMatrix.cpp

diff --git a/src/algebra/CRSMatrix.hpp b/src/algebra/CRSMatrix.hpp
index 491a35242..224b09c21 100644
--- a/src/algebra/CRSMatrix.hpp
+++ b/src/algebra/CRSMatrix.hpp
@@ -21,6 +21,13 @@ class CRSMatrix
   Array<DataType> m_values;
 
  public:
+  PUGS_INLINE
+  size_t
+  numberOfRows() const
+  {
+    return m_host_matrix.numRows();
+  }
+
   Vector<DataType> operator*(const Vector<DataType>& x) const
   {
     Vector<DataType> Ax{m_host_matrix.numRows()};
@@ -40,22 +47,6 @@ class CRSMatrix
     return Ax;
   }
 
-  friend std::ostream&
-  operator<<(std::ostream& os, const CRSMatrix& M)
-  {
-    auto host_row_map = M.m_host_matrix.row_map;
-    for (IndexType i_row = 0; i_row < M.m_host_matrix.numRows(); ++i_row) {
-      const auto& row_begin = host_row_map(i_row);
-      const auto& row_end   = host_row_map(i_row + 1);
-      os << i_row << " #";
-      for (IndexType j = row_begin; j < row_end; ++j) {
-        os << ' ' << M.m_host_matrix.entries(j) << ':' << M.m_values[j];
-      }
-      os << '\n';
-    }
-    return os;
-  }
-
   CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M)
   {
     m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector());
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 1c567b6ea..2d0ec377d 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -5,6 +5,7 @@ add_executable (unit_tests
   test_main.cpp
   test_Array.cpp
   test_ArrayUtils.cpp
+  test_CRSMatrix.cpp
   test_ItemType.cpp
   test_PugsAssert.cpp
   test_RevisionInfo.cpp
diff --git a/tests/test_CRSMatrix.cpp b/tests/test_CRSMatrix.cpp
new file mode 100644
index 000000000..4933173e2
--- /dev/null
+++ b/tests/test_CRSMatrix.cpp
@@ -0,0 +1,132 @@
+#include <catch2/catch.hpp>
+
+#include <CRSMatrix.hpp>
+
+// Instantiate to ensure full coverage is performed
+template class SparseMatrixDescriptor<int, uint8_t>;
+
+TEST_CASE("CRSMatrix", "[algebra]")
+{
+  SECTION("CRSMatrix")
+  {
+    SECTION("matrix size")
+    {
+      SparseMatrixDescriptor<int, uint8_t> S{5};
+      S(0, 2) = 3;
+      S(1, 2) = 11;
+      S(1, 1) = 1;
+      S(3, 0) = 5;
+      S(3, 1) = -3;
+      S(4, 1) = 1;
+      S(2, 2) = 4;
+      S(4, 3) = 2;
+      S(4, 4) = -2;
+
+      CRSMatrix<int, uint8_t> A{S};
+
+      REQUIRE(A.numberOfRows() == S.numberOfRows());
+    }
+
+    SECTION("matrix vector product (simple)")
+    {
+      SparseMatrixDescriptor<int, uint8_t> S{5};
+      S(0, 2) = 3;
+      S(1, 2) = 11;
+      S(1, 1) = 1;
+      S(3, 0) = 5;
+      S(3, 1) = -3;
+      S(4, 1) = 1;
+      S(2, 2) = 4;
+      S(4, 3) = 2;
+      S(4, 4) = -2;
+
+      CRSMatrix<int, uint8_t> A{S};
+
+      Vector<int> x{A.numberOfRows()};
+      x    = 0;
+      x[0] = 1;
+
+      Vector<int> y = A * x;
+      REQUIRE(y[0] == 0);
+      REQUIRE(y[1] == 0);
+      REQUIRE(y[2] == 0);
+      REQUIRE(y[3] == 5);
+      REQUIRE(y[4] == 0);
+
+      x    = 0;
+      x[1] = 2;
+
+      y = A * x;
+      REQUIRE(y[0] == 0);
+      REQUIRE(y[1] == 2);
+      REQUIRE(y[2] == 0);
+      REQUIRE(y[3] == -6);
+      REQUIRE(y[4] == 2);
+
+      x    = 0;
+      x[2] = -1;
+
+      y = A * x;
+      REQUIRE(y[0] == -3);
+      REQUIRE(y[1] == -11);
+      REQUIRE(y[2] == -4);
+      REQUIRE(y[3] == 0);
+      REQUIRE(y[4] == 0);
+
+      x    = 0;
+      x[3] = 3;
+
+      y = A * x;
+      REQUIRE(y[0] == 0);
+      REQUIRE(y[1] == 0);
+      REQUIRE(y[2] == 0);
+      REQUIRE(y[3] == 0);
+      REQUIRE(y[4] == 6);
+
+      x    = 0;
+      x[4] = 1;
+
+      y = A * x;
+      REQUIRE(y[0] == 0);
+      REQUIRE(y[1] == 0);
+      REQUIRE(y[2] == 0);
+      REQUIRE(y[3] == 0);
+      REQUIRE(y[4] == -2);
+    }
+
+    SECTION("matrix vector product (complet)")
+    {
+      SparseMatrixDescriptor<int, uint8_t> S{4};
+      S(0, 0) = 1;
+      S(0, 1) = 2;
+      S(0, 2) = 3;
+      S(0, 3) = 4;
+      S(1, 0) = 5;
+      S(1, 1) = 6;
+      S(1, 2) = 7;
+      S(1, 3) = 8;
+      S(2, 0) = 9;
+      S(2, 1) = 10;
+      S(2, 2) = 11;
+      S(2, 3) = 12;
+      S(3, 0) = 13;
+      S(3, 1) = 14;
+      S(3, 2) = 15;
+      S(3, 3) = 16;
+
+      CRSMatrix<int, uint8_t> A{S};
+
+      Vector<int> x{A.numberOfRows()};
+      x[0] = 1;
+      x[1] = 2;
+      x[2] = 3;
+      x[3] = 4;
+
+      Vector<int> y = A * x;
+      REQUIRE(y[0] == 30);
+      REQUIRE(y[1] == 70);
+      REQUIRE(y[2] == 110);
+      REQUIRE(y[3] == 150);
+    }
+  }
+}
-- 
GitLab