From 6b19154e963c74eabca9aa7a095f5ea7ebb22a3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com> Date: Wed, 4 Nov 2020 00:26:52 +0100 Subject: [PATCH] Add missing tests for CRSMatrix Change a bit the interface to ensure that matrix structure cannot be modified once the matrix is built. Actually the interface of the underlying `Kokkos::StaticCrsGraph` is quite crappy, especially to prevent matrix changes. Replacing `Kokkos::StaticCrsGraph` class by an appropriate builtin one should be considered: one could then provide a better interface, and offer a better coupling with external linear algebra. Also, this would reduce the direct dependency on Kokkos. --- src/algebra/CRSMatrix.hpp | 25 ++++++++++++---- src/algebra/LinearSolver.cpp | 10 +++---- tests/test_CRSMatrix.cpp | 56 +++++++++++++++++++++++++++++++++++- 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/src/algebra/CRSMatrix.hpp b/src/algebra/CRSMatrix.hpp index d9b03ac93..98297d759 100644 --- a/src/algebra/CRSMatrix.hpp +++ b/src/algebra/CRSMatrix.hpp @@ -17,7 +17,7 @@ class CRSMatrix using MutableDataType = std::remove_const_t<DataType>; private: - using HostMatrix = Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>; + using HostMatrix = Kokkos::StaticCrsGraph<const IndexType, Kokkos::HostSpace>; HostMatrix m_host_matrix; Array<const DataType> m_values; @@ -37,9 +37,15 @@ class CRSMatrix } auto - hostMatrix() const + rowIndices() const { - return m_host_matrix; + return encapsulate(m_host_matrix.row_map); + } + + auto + row(size_t i) const + { + return m_host_matrix.rowConst(i); } template <typename DataType2> @@ -70,8 +76,17 @@ class CRSMatrix CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M) { - m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector()); - m_values = M.valueArray(); + { + auto host_matrix = + Kokkos::create_staticcrsgraph<Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>>("connectivity_matrix", + M.graphVector()); + + // This is a bit crappy but it is the price to pay to avoid + m_host_matrix.entries = host_matrix.entries; + m_host_matrix.row_map = host_matrix.row_map; + m_host_matrix.row_block_offsets = host_matrix.row_block_offsets; + } + m_values = M.valueArray(); } ~CRSMatrix() = default; }; diff --git a/src/algebra/LinearSolver.cpp b/src/algebra/LinearSolver.cpp index 6a6beaa05..e11e60ae0 100644 --- a/src/algebra/LinearSolver.cpp +++ b/src/algebra/LinearSolver.cpp @@ -173,17 +173,17 @@ struct LinearSolver::Internals MatSetType(petscMat, MATAIJ); Array<PetscScalar> values = copy(A.values()); - auto hm = A.hostMatrix(); - Array<PetscInt> row_indices{hm.row_map.size()}; - for (size_t i = 0; i < hm.row_map.size(); ++i) { - row_indices[i] = hm.row_map[i]; + const auto A_row_indices = A.rowIndices(); + Array<PetscInt> row_indices{A_row_indices.size()}; + for (size_t i = 0; i < row_indices.size(); ++i) { + row_indices[i] = A_row_indices[i]; } Array<PetscInt> column_indices{values.size()}; size_t l = 0; for (size_t i = 0; i < A.numberOfRows(); ++i) { - const auto row_i = hm.rowConst(i); + const auto row_i = A.row(i); for (size_t j = 0; j < row_i.length; ++j) { column_indices[l++] = row_i.colidx(j); } diff --git a/tests/test_CRSMatrix.cpp b/tests/test_CRSMatrix.cpp index 7a7a943d1..6c8584796 100644 --- a/tests/test_CRSMatrix.cpp +++ b/tests/test_CRSMatrix.cpp @@ -94,7 +94,7 @@ TEST_CASE("CRSMatrix", "[algebra]") REQUIRE(y[4] == -2); } - SECTION("matrix vector product (complet)") + SECTION("matrix vector product (complete)") { SparseMatrixDescriptor<int, uint8_t> S{4}; S(0, 0) = 1; @@ -129,6 +129,60 @@ TEST_CASE("CRSMatrix", "[algebra]") REQUIRE(y[3] == 150); } + SECTION("check values") + { + SparseMatrixDescriptor<int, uint8_t> S{4}; + S(3, 0) = 13; + S(0, 0) = 1; + S(0, 1) = 2; + S(1, 1) = 6; + S(1, 2) = 7; + S(2, 2) = 11; + S(3, 2) = 15; + S(2, 0) = 9; + S(3, 3) = 16; + S(2, 3) = 12; + S(0, 3) = 4; + S(1, 0) = 5; + S(2, 1) = 10; + + CRSMatrix<int, uint8_t> A{S}; + + auto values = A.values(); + REQUIRE(values.size() == 13); + REQUIRE(values[0] == 1); + REQUIRE(values[1] == 2); + REQUIRE(values[2] == 4); + REQUIRE(values[3] == 5); + REQUIRE(values[4] == 6); + REQUIRE(values[5] == 7); + REQUIRE(values[6] == 9); + REQUIRE(values[7] == 10); + REQUIRE(values[8] == 11); + REQUIRE(values[9] == 12); + REQUIRE(values[10] == 13); + REQUIRE(values[11] == 15); + REQUIRE(values[12] == 16); + + auto row_indices = A.rowIndices(); + + REQUIRE(row_indices.size() == 5); + + REQUIRE(A.row(0).colidx(0) == 0); + REQUIRE(A.row(0).colidx(1) == 1); + REQUIRE(A.row(0).colidx(2) == 3); + REQUIRE(A.row(1).colidx(0) == 0); + REQUIRE(A.row(1).colidx(1) == 1); + REQUIRE(A.row(1).colidx(2) == 2); + REQUIRE(A.row(2).colidx(0) == 0); + REQUIRE(A.row(2).colidx(1) == 1); + REQUIRE(A.row(2).colidx(2) == 2); + REQUIRE(A.row(2).colidx(3) == 3); + REQUIRE(A.row(3).colidx(0) == 0); + REQUIRE(A.row(3).colidx(1) == 2); + REQUIRE(A.row(3).colidx(2) == 3); + } + #ifndef NDEBUG SECTION("incompatible runtime matrix/vector product") { -- GitLab