diff --git a/src/algebra/CRSMatrix.hpp b/src/algebra/CRSMatrix.hpp index d9b03ac93fb2f1c4baeff98c90daac17547ff7cf..98297d759675fdb7fdbf8be1e638f19cba4cc562 100644 --- a/src/algebra/CRSMatrix.hpp +++ b/src/algebra/CRSMatrix.hpp @@ -17,7 +17,7 @@ class CRSMatrix using MutableDataType = std::remove_const_t<DataType>; private: - using HostMatrix = Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>; + using HostMatrix = Kokkos::StaticCrsGraph<const IndexType, Kokkos::HostSpace>; HostMatrix m_host_matrix; Array<const DataType> m_values; @@ -37,9 +37,15 @@ class CRSMatrix } auto - hostMatrix() const + rowIndices() const { - return m_host_matrix; + return encapsulate(m_host_matrix.row_map); + } + + auto + row(size_t i) const + { + return m_host_matrix.rowConst(i); } template <typename DataType2> @@ -70,8 +76,17 @@ class CRSMatrix CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M) { - m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector()); - m_values = M.valueArray(); + { + auto host_matrix = + Kokkos::create_staticcrsgraph<Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>>("connectivity_matrix", + M.graphVector()); + + // This is a bit crappy but it is the price to pay to avoid + m_host_matrix.entries = host_matrix.entries; + m_host_matrix.row_map = host_matrix.row_map; + m_host_matrix.row_block_offsets = host_matrix.row_block_offsets; + } + m_values = M.valueArray(); } ~CRSMatrix() = default; }; diff --git a/src/algebra/LinearSolver.cpp b/src/algebra/LinearSolver.cpp index 6a6beaa05bd3ac46b1a81f6e8645721003fe5577..e11e60ae0c8b72ef111e1bc64c69ffe931d98621 100644 --- a/src/algebra/LinearSolver.cpp +++ b/src/algebra/LinearSolver.cpp @@ -173,17 +173,17 @@ struct LinearSolver::Internals MatSetType(petscMat, MATAIJ); Array<PetscScalar> values = copy(A.values()); - auto hm = A.hostMatrix(); - Array<PetscInt> row_indices{hm.row_map.size()}; - for (size_t i = 0; i < hm.row_map.size(); ++i) { - row_indices[i] = hm.row_map[i]; + const auto A_row_indices = A.rowIndices(); + Array<PetscInt> row_indices{A_row_indices.size()}; + for (size_t i = 0; i < row_indices.size(); ++i) { + row_indices[i] = A_row_indices[i]; } Array<PetscInt> column_indices{values.size()}; size_t l = 0; for (size_t i = 0; i < A.numberOfRows(); ++i) { - const auto row_i = hm.rowConst(i); + const auto row_i = A.row(i); for (size_t j = 0; j < row_i.length; ++j) { column_indices[l++] = row_i.colidx(j); } diff --git a/tests/test_CRSMatrix.cpp b/tests/test_CRSMatrix.cpp index 7a7a943d14c943f5a315544fc6eab947ffc28463..6c8584796b8c71991941870c723cc23617c0c2ac 100644 --- a/tests/test_CRSMatrix.cpp +++ b/tests/test_CRSMatrix.cpp @@ -94,7 +94,7 @@ TEST_CASE("CRSMatrix", "[algebra]") REQUIRE(y[4] == -2); } - SECTION("matrix vector product (complet)") + SECTION("matrix vector product (complete)") { SparseMatrixDescriptor<int, uint8_t> S{4}; S(0, 0) = 1; @@ -129,6 +129,60 @@ TEST_CASE("CRSMatrix", "[algebra]") REQUIRE(y[3] == 150); } + SECTION("check values") + { + SparseMatrixDescriptor<int, uint8_t> S{4}; + S(3, 0) = 13; + S(0, 0) = 1; + S(0, 1) = 2; + S(1, 1) = 6; + S(1, 2) = 7; + S(2, 2) = 11; + S(3, 2) = 15; + S(2, 0) = 9; + S(3, 3) = 16; + S(2, 3) = 12; + S(0, 3) = 4; + S(1, 0) = 5; + S(2, 1) = 10; + + CRSMatrix<int, uint8_t> A{S}; + + auto values = A.values(); + REQUIRE(values.size() == 13); + REQUIRE(values[0] == 1); + REQUIRE(values[1] == 2); + REQUIRE(values[2] == 4); + REQUIRE(values[3] == 5); + REQUIRE(values[4] == 6); + REQUIRE(values[5] == 7); + REQUIRE(values[6] == 9); + REQUIRE(values[7] == 10); + REQUIRE(values[8] == 11); + REQUIRE(values[9] == 12); + REQUIRE(values[10] == 13); + REQUIRE(values[11] == 15); + REQUIRE(values[12] == 16); + + auto row_indices = A.rowIndices(); + + REQUIRE(row_indices.size() == 5); + + REQUIRE(A.row(0).colidx(0) == 0); + REQUIRE(A.row(0).colidx(1) == 1); + REQUIRE(A.row(0).colidx(2) == 3); + REQUIRE(A.row(1).colidx(0) == 0); + REQUIRE(A.row(1).colidx(1) == 1); + REQUIRE(A.row(1).colidx(2) == 2); + REQUIRE(A.row(2).colidx(0) == 0); + REQUIRE(A.row(2).colidx(1) == 1); + REQUIRE(A.row(2).colidx(2) == 2); + REQUIRE(A.row(2).colidx(3) == 3); + REQUIRE(A.row(3).colidx(0) == 0); + REQUIRE(A.row(3).colidx(1) == 2); + REQUIRE(A.row(3).colidx(2) == 3); + } + #ifndef NDEBUG SECTION("incompatible runtime matrix/vector product") {