From e7da47a7716088cd51c46f87019413c794bec7b9 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Tue, 28 Jul 2020 16:15:52 +0200 Subject: [PATCH] Add a few missing assertions Note that some of them were already set by Kokkos (in debug mode), but it seems better to only rely on `pugs` definitions --- src/algebra/SparseMatrixDescriptor.hpp | 16 ++++++++++++---- src/mesh/ItemToItemMatrix.hpp | 12 +++++++++--- tests/test_SparseMatrixDescriptor.cpp | 26 ++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/algebra/SparseMatrixDescriptor.hpp b/src/algebra/SparseMatrixDescriptor.hpp index 3533a4017..100000366 100644 --- a/src/algebra/SparseMatrixDescriptor.hpp +++ b/src/algebra/SparseMatrixDescriptor.hpp @@ -28,14 +28,16 @@ class SparseMatrixDescriptor return m_id_value_map.size(); } - const DataType& operator[](const IndexType& j) const + const DataType& + operator[](const IndexType& j) const { auto i_value = m_id_value_map.find(j); Assert(i_value != m_id_value_map.end()); return i_value->second; } - DataType& operator[](const IndexType& j) + DataType& + operator[](const IndexType& j) { auto i_value = m_id_value_map.find(j); if (i_value != m_id_value_map.end()) { @@ -51,7 +53,7 @@ class SparseMatrixDescriptor operator<<(std::ostream& os, const SparseRowDescriptor& row) { for (auto [j, value] : row.m_id_value_map) { - os << ' ' << j << ':' << value; + os << ' ' << static_cast<size_t>(j) << ':' << value; } return os; } @@ -73,24 +75,30 @@ class SparseMatrixDescriptor SparseRowDescriptor& row(const IndexType i) { + Assert(i < m_row_array.size()); return m_row_array[i]; } const SparseRowDescriptor& row(const IndexType i) const { + Assert(i < m_row_array.size()); return m_row_array[i]; } DataType& operator()(const IndexType& i, const IndexType& j) { + Assert(i < m_row_array.size()); + Assert(j < m_row_array.size()); return m_row_array[i][j]; } const DataType& operator()(const IndexType& i, const IndexType& j) const { + Assert(i < m_row_array.size()); + Assert(j < m_row_array.size()); const auto& r = m_row_array[i]; // split to ensure const-ness of call return r[j]; } @@ -99,7 +107,7 @@ class SparseMatrixDescriptor operator<<(std::ostream& os, const SparseMatrixDescriptor& M) { for (IndexType i = 0; i < M.m_row_array.size(); ++i) { - os << i << " |" << M.m_row_array[i] << '\n'; + os << static_cast<size_t>(i) << " |" << M.m_row_array[i] << '\n'; } return os; } diff --git a/src/mesh/ItemToItemMatrix.hpp b/src/mesh/ItemToItemMatrix.hpp index e8d7071af..1cb20d3ee 100644 --- a/src/mesh/ItemToItemMatrix.hpp +++ b/src/mesh/ItemToItemMatrix.hpp @@ -29,8 +29,10 @@ class ItemToItemMatrix } PUGS_INLINE - TargetItemId operator[](size_t j) const + TargetItemId + operator[](size_t j) const { + Assert(j < m_row.length); return m_row(j); } @@ -73,15 +75,19 @@ class ItemToItemMatrix } PUGS_INLINE - auto operator[](const SourceItemId& source_id) const + auto + operator[](const SourceItemId& source_id) const { + Assert(source_id < m_connectivity_matrix.numRows()); using RowType = decltype(m_connectivity_matrix.rowConst(source_id)); return SubItemList<RowType>(m_connectivity_matrix.rowConst(source_id)); } template <typename IndexType> - PUGS_INLINE const auto& operator[](const IndexType& source_id) const + PUGS_INLINE const auto& + operator[](const IndexType& source_id) const { + Assert(source_id < m_connectivity_matrix.numRows()); static_assert(std::is_same_v<IndexType, SourceItemId>, "ItemToItemMatrix must be indexed using correct ItemId"); using RowType = decltype(m_connectivity_matrix.rowConst(source_id)); return SubItemList<RowType>(m_connectivity_matrix.rowConst(source_id)); diff --git a/tests/test_SparseMatrixDescriptor.cpp b/tests/test_SparseMatrixDescriptor.cpp index 4e8b90134..cd66380b1 100644 --- a/tests/test_SparseMatrixDescriptor.cpp +++ b/tests/test_SparseMatrixDescriptor.cpp @@ -168,4 +168,30 @@ TEST_CASE("SparseMatrixDescriptor", "[algebra]") REQUIRE(value_array[6] == 1); REQUIRE(value_array[7] == -2); } + + SECTION("output") + { + SparseMatrixDescriptor<int, uint8_t> S{5}; + S(0, 2) = 3; + S(1, 2) = 11; + S(1, 1) = 1; + S(0, 2) += 2; + S(3, 3) = 5; + S(3, 1) = -3; + S(4, 1) = 1; + S(2, 2) = 4; + S(4, 4) = -2; + + std::ostringstream output; + output << '\n' << S; + + std::string expected_output = R"( +0 | 2:5 +1 | 1:1 2:11 +2 | 2:4 +3 | 1:-3 3:5 +4 | 1:1 4:-2 +)"; + REQUIRE(output.str() == expected_output); + } } -- GitLab