diff --git a/src/algebra/CRSMatrix.hpp b/src/algebra/CRSMatrix.hpp index 1f7fe37e571447b104b81ff61cf38dcf2144d897..4b88cd80c1e0c55b8a725f84106ddf13a283eb2f 100644 --- a/src/algebra/CRSMatrix.hpp +++ b/src/algebra/CRSMatrix.hpp @@ -27,6 +27,117 @@ class CRSMatrix Array<const DataType> m_values; Array<const IndexType> m_column_indices; + template <typename DataType2, typename BinOp> + CRSMatrix + _binOp(const CRSMatrix& A, const CRSMatrix<DataType2, IndexType>& B) const + { + static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(), + "Cannot add matrices of different types"); + Assert(A.numberOfRows() == B.numberOfRows() and A.numberOfColumns() == B.numberOfColumns(), + "cannot apply inner binary operator on matrices of different sizes"); + + Array<int> non_zeros(A.m_nb_rows); + for (IndexType i_row = 0; i_row < A.m_nb_rows; ++i_row) { + const auto A_row_begin = A.m_row_map[i_row]; + const auto A_row_end = A.m_row_map[i_row + 1]; + const auto B_row_begin = B.m_row_map[i_row]; + const auto B_row_end = B.m_row_map[i_row + 1]; + IndexType i_row_A = A_row_begin; + IndexType i_row_B = B_row_begin; + + int row_nb_columns = 0; + + while (i_row_A < A_row_end or i_row_B < B_row_end) { + const IndexType A_column_idx = [&] { + if (i_row_A < A_row_end) { + return A.m_column_indices[i_row_A]; + } else { + return std::numeric_limits<IndexType>::max(); + } + }(); + + const IndexType B_column_idx = [&] { + if (i_row_B < B_row_end) { + return B.m_column_indices[i_row_B]; + } else { + return std::numeric_limits<IndexType>::max(); + } + }(); + + if (A_column_idx == B_column_idx) { + ++row_nb_columns; + ++i_row_A; + ++i_row_B; + } else if (A_column_idx < B_column_idx) { + ++row_nb_columns; + ++i_row_A; + } else { + Assert(B_column_idx < A_column_idx); + ++row_nb_columns; + ++i_row_B; + } + } + non_zeros[i_row] = row_nb_columns; + } + + Array<IndexType> row_map(A.m_nb_rows + 1); + row_map[0] = 0; + for (IndexType i = 0; i < A.m_nb_rows; ++i) { + row_map[i + 1] = row_map[i] + non_zeros[i]; + } + + const IndexType nb_values = row_map[row_map.size() - 1]; + Array<DataType> values(nb_values); + Array<IndexType> column_indices(nb_values); + + IndexType i_value = 0; + for (IndexType i_row = 0; i_row < A.m_nb_rows; ++i_row) { + const auto A_row_begin = A.m_row_map[i_row]; + const auto A_row_end = A.m_row_map[i_row + 1]; + const auto B_row_begin = B.m_row_map[i_row]; + const auto B_row_end = B.m_row_map[i_row + 1]; + IndexType i_row_A = A_row_begin; + IndexType i_row_B = B_row_begin; + + while (i_row_A < A_row_end or i_row_B < B_row_end) { + const IndexType A_column_idx = [&] { + if (i_row_A < A_row_end) { + return A.m_column_indices[i_row_A]; + } else { + return std::numeric_limits<IndexType>::max(); + } + }(); + + const IndexType B_column_idx = [&] { + if (i_row_B < B_row_end) { + return B.m_column_indices[i_row_B]; + } else { + return std::numeric_limits<IndexType>::max(); + } + }(); + + if (A_column_idx == B_column_idx) { + column_indices[i_value] = A_column_idx; + values[i_value] = BinOp()(A.m_values[i_row_A], B.m_values[i_row_B]); + ++i_row_A; + ++i_row_B; + } else if (A_column_idx < B_column_idx) { + column_indices[i_value] = A_column_idx; + values[i_value] = A.m_values[i_row_A]; + ++i_row_A; + } else { + Assert(B_column_idx < A_column_idx); + column_indices[i_value] = B_column_idx; + values[i_value] = B.m_values[i_row_B]; + ++i_row_B; + } + ++i_value; + } + } + + return CRSMatrix(A.m_nb_rows, A.m_nb_columns, row_map, values, column_indices); + } + public: PUGS_INLINE bool @@ -72,9 +183,9 @@ class CRSMatrix operator*(const Vector<DataType2>& x) const { static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(), - "Cannot multiply matrix and vector of different type"); + "Cannot multiply matrix and vector of different types"); - Assert(x.size() - m_nb_columns == 0); + Assert(x.size() - m_nb_columns == 0, "cannot compute product: incompatible matrix and vector sizes"); Vector<MutableDataType> Ax(m_nb_rows); @@ -92,6 +203,20 @@ class CRSMatrix return Ax; } + template <typename DataType2> + CRSMatrix<DataType> + operator+(const CRSMatrix<DataType2>& B) const + { + return this->_binOp<DataType2, std::plus<DataType>>(*this, B); + } + + template <typename DataType2> + CRSMatrix<DataType> + operator-(const CRSMatrix<DataType2>& B) const + { + return this->_binOp<DataType2, std::minus<DataType>>(*this, B); + } + friend PUGS_INLINE std::ostream& operator<<(std::ostream& os, const CRSMatrix& A) { diff --git a/tests/test_CRSMatrix.cpp b/tests/test_CRSMatrix.cpp index d0d9476ab0e47ff90874391c63f39a6321adbc3b..f4bce5a429a0b88ee0c6ce2c58661c3874692e42 100644 --- a/tests/test_CRSMatrix.cpp +++ b/tests/test_CRSMatrix.cpp @@ -165,6 +165,177 @@ TEST_CASE("CRSMatrix", "[algebra]") REQUIRE(y[3] == 150); } + SECTION("square matrices sum") + { + Array<int> non_zeros_DA{4}; + non_zeros_DA.fill(2); + CRSMatrixDescriptor<int> DA{4, 4, non_zeros_DA}; + DA(0, 0) = 1; + DA(0, 1) = 2; + DA(1, 0) = 5; + DA(1, 1) = 6; + DA(2, 0) = 9; + DA(2, 3) = 12; + DA(3, 0) = 13; + DA(3, 3) = 16; + + CRSMatrix<int> A{DA.getCRSMatrix()}; + + Array<int> non_zeros_DB{4}; + non_zeros_DB.fill(2); + CRSMatrixDescriptor<int> DB{4, 4, non_zeros_DB}; + DB(0, 0) = 1; + DB(0, 2) = 2; + DB(1, 0) = 5; + DB(1, 2) = 6; + DB(2, 0) = 9; + DB(2, 2) = 12; + DB(3, 0) = 13; + DB(3, 3) = 16; + + CRSMatrix<int> B{DB.getCRSMatrix()}; + + std::ostringstream ost; + ost << A + B; + + std::string ref = R"(0| 0:2 1:2 2:2 +1| 0:10 1:6 2:6 +2| 0:18 2:12 3:12 +3| 0:26 3:32 +)"; + + REQUIRE(ost.str() == ref); + } + + SECTION("square matrices difference") + { + Array<int> non_zeros_DA{4}; + non_zeros_DA.fill(2); + CRSMatrixDescriptor<int> DA{4, 4, non_zeros_DA}; + DA(0, 0) = 1; + DA(0, 1) = 2; + DA(1, 0) = 5; + DA(1, 1) = 6; + DA(2, 0) = 9; + DA(2, 3) = 12; + DA(3, 0) = 13; + DA(3, 3) = 16; + + CRSMatrix<int> A{DA.getCRSMatrix()}; + + Array<int> non_zeros_DB{4}; + non_zeros_DB.fill(2); + CRSMatrixDescriptor<int> DB{4, 4, non_zeros_DB}; + DB(0, 0) = -1; + DB(0, 2) = 3; + DB(1, 0) = 7; + DB(1, 2) = 3; + DB(2, 0) = 7; + DB(2, 2) = 11; + DB(3, 0) = 3; + DB(3, 3) = 8; + + CRSMatrix<int> B{DB.getCRSMatrix()}; + + std::ostringstream ost; + ost << A - B; + + std::string ref = R"(0| 0:2 1:2 2:3 +1| 0:-2 1:6 2:3 +2| 0:2 2:-11 3:12 +3| 0:10 3:8 +)"; + + REQUIRE(ost.str() == ref); + } + + SECTION("rectangle matrices sum") + { + Array<int> non_zeros_DA{4}; + non_zeros_DA.fill(2); + CRSMatrixDescriptor<int> DA{4, 5, non_zeros_DA}; + DA(0, 0) = 1; + DA(0, 1) = 2; + DA(0, 4) = 2; + DA(1, 0) = 5; + DA(1, 1) = 6; + DA(2, 0) = 9; + DA(2, 3) = 12; + DA(3, 0) = 13; + DA(3, 3) = 16; + DA(3, 4) = 16; + + CRSMatrix<int> A{DA.getCRSMatrix()}; + + Array<int> non_zeros_DB{4}; + non_zeros_DB.fill(2); + CRSMatrixDescriptor<int> DB{4, 5, non_zeros_DB}; + DB(0, 0) = 1; + DB(0, 2) = 2; + DB(1, 0) = 5; + DB(1, 2) = 6; + DB(1, 4) = 3; + DB(2, 0) = 9; + DB(2, 2) = 12; + DB(3, 0) = 13; + DB(3, 3) = 16; + + CRSMatrix<int> B{DB.getCRSMatrix()}; + + std::ostringstream ost; + ost << A + B; + + std::string ref = R"(0| 0:2 1:2 2:2 4:2 +1| 0:10 1:6 2:6 4:3 +2| 0:18 2:12 3:12 +3| 0:26 3:32 4:16 +)"; + + REQUIRE(ost.str() == ref); + } + + SECTION("rectangle matrices difference") + { + Array<int> non_zeros_DA{4}; + non_zeros_DA.fill(2); + CRSMatrixDescriptor<int> DA{4, 3, non_zeros_DA}; + DA(0, 0) = 1; + DA(0, 1) = 2; + DA(1, 0) = 5; + DA(1, 1) = 6; + DA(2, 0) = 9; + DA(2, 2) = 12; + DA(3, 0) = 13; + DA(3, 2) = 16; + + CRSMatrix<int> A{DA.getCRSMatrix()}; + + Array<int> non_zeros_DB{4}; + non_zeros_DB.fill(2); + CRSMatrixDescriptor<int> DB{4, 3, non_zeros_DB}; + DB(0, 0) = -1; + DB(0, 2) = 3; + DB(1, 0) = 7; + DB(1, 2) = 3; + DB(2, 0) = 7; + DB(2, 2) = 11; + DB(3, 0) = 3; + DB(3, 1) = 8; + + CRSMatrix<int> B{DB.getCRSMatrix()}; + + std::ostringstream ost; + ost << A - B; + + std::string ref = R"(0| 0:2 1:2 2:3 +1| 0:-2 1:6 2:3 +2| 0:2 2:1 +3| 0:10 1:8 2:16 +)"; + + REQUIRE(ost.str() == ref); + } + SECTION("check values") { Array<int> non_zeros{4}; @@ -237,14 +408,45 @@ TEST_CASE("CRSMatrix", "[algebra]") } #ifndef NDEBUG - SECTION("incompatible runtime matrix/vector product") + SECTION("runtime incompatible matrix/vector product") { Array<int> non_zeros(2); non_zeros.fill(0); CRSMatrixDescriptor<int> S{2, 4, non_zeros}; CRSMatrix<int> A{S.getCRSMatrix()}; Vector<int> x{2}; - REQUIRE_THROWS_AS(A * x, AssertError); + REQUIRE_THROWS_WITH(A * x, "cannot compute product: incompatible matrix and vector sizes"); } + + SECTION("runtime incompatible matrices sum") + { + Array<int> A_non_zeros(2); + A_non_zeros.fill(0); + CRSMatrixDescriptor<int> DA{2, 4, A_non_zeros}; + CRSMatrix<int> A{DA.getCRSMatrix()}; + + Array<int> B_non_zeros(3); + B_non_zeros.fill(0); + CRSMatrixDescriptor<int> DB{3, 4, B_non_zeros}; + CRSMatrix<int> B{DB.getCRSMatrix()}; + + REQUIRE_THROWS_WITH(A + B, "cannot apply inner binary operator on matrices of different sizes"); + } + + SECTION("runtime incompatible matrices difference") + { + Array<int> A_non_zeros(2); + A_non_zeros.fill(0); + CRSMatrixDescriptor<int> DA{2, 4, A_non_zeros}; + CRSMatrix<int> A{DA.getCRSMatrix()}; + + Array<int> B_non_zeros(2); + B_non_zeros.fill(0); + CRSMatrixDescriptor<int> DB{2, 3, B_non_zeros}; + CRSMatrix<int> B{DB.getCRSMatrix()}; + + REQUIRE_THROWS_WITH(A - B, "cannot apply inner binary operator on matrices of different sizes"); + } + #endif // NDEBUG }