diff --git a/tests/test_CRSMatrix.cpp b/tests/test_CRSMatrix.cpp index 4933173e2fee4d5a38cf2e82517f5abf5fe4013c..4d555b2810467edc761a1758299887c2d851a1b7 100644 --- a/tests/test_CRSMatrix.cpp +++ b/tests/test_CRSMatrix.cpp @@ -7,126 +7,132 @@ template class SparseMatrixDescriptor<int, uint8_t>; TEST_CASE("CRSMatrix", "[algebra]") { - SECTION("CRSMatrix") + SECTION("matrix size") { - SECTION("matrix size") - { - SparseMatrixDescriptor<int, uint8_t> S{5}; - S(0, 2) = 3; - S(1, 2) = 11; - S(1, 1) = 1; - S(3, 0) = 5; - S(3, 1) = -3; - S(4, 1) = 1; - S(2, 2) = 4; - S(4, 3) = 2; - S(4, 4) = -2; - - CRSMatrix<int, uint8_t> A{S}; - - REQUIRE(A.numberOfRows() == S.numberOfRows()); - } - - SECTION("matrix vector product (simple)") - { - SparseMatrixDescriptor<int, uint8_t> S{5}; - S(0, 2) = 3; - S(1, 2) = 11; - S(1, 1) = 1; - S(3, 0) = 5; - S(3, 1) = -3; - S(4, 1) = 1; - S(2, 2) = 4; - S(4, 3) = 2; - S(4, 4) = -2; - - CRSMatrix<int, uint8_t> A{S}; - - Vector<int> x{A.numberOfRows()}; - x = 0; - x[0] = 1; - - Vector<int> y = A * x; - REQUIRE(y[0] == 0); - REQUIRE(y[1] == 0); - REQUIRE(y[2] == 0); - REQUIRE(y[3] == 5); - REQUIRE(y[4] == 0); - - x = 0; - x[1] = 2; - - y = A * x; - REQUIRE(y[0] == 0); - REQUIRE(y[1] == 2); - REQUIRE(y[2] == 0); - REQUIRE(y[3] == -6); - REQUIRE(y[4] == 2); - - x = 0; - x[2] = -1; - - y = A * x; - REQUIRE(y[0] == -3); - REQUIRE(y[1] == -11); - REQUIRE(y[2] == -4); - REQUIRE(y[3] == 0); - REQUIRE(y[4] == 0); - - x = 0; - x[3] = 3; - - y = A * x; - REQUIRE(y[0] == 0); - REQUIRE(y[1] == 0); - REQUIRE(y[2] == 0); - REQUIRE(y[3] == 0); - REQUIRE(y[4] == 6); - - x = 0; - x[4] = 1; - - y = A * x; - REQUIRE(y[0] == 0); - REQUIRE(y[1] == 0); - REQUIRE(y[2] == 0); - REQUIRE(y[3] == 0); - REQUIRE(y[4] == -2); - } - - SECTION("matrix vector product (complet)") - { - SparseMatrixDescriptor<int, uint8_t> S{4}; - S(0, 0) = 1; - S(0, 1) = 2; - S(0, 2) = 3; - S(0, 3) = 4; - S(1, 0) = 5; - S(1, 1) = 6; - S(1, 2) = 7; - S(1, 3) = 8; - S(2, 0) = 9; - S(2, 1) = 10; - S(2, 2) = 11; - S(2, 3) = 12; - S(3, 0) = 13; - S(3, 1) = 14; - S(3, 2) = 15; - S(3, 3) = 16; - - CRSMatrix<int, uint8_t> A{S}; - - Vector<int> x{A.numberOfRows()}; - x[0] = 1; - x[1] = 2; - x[2] = 3; - x[3] = 4; - - Vector<int> y = A * x; - REQUIRE(y[0] == 30); - REQUIRE(y[1] == 70); - REQUIRE(y[2] == 110); - REQUIRE(y[3] == 150); - } + SparseMatrixDescriptor<int, uint8_t> S{5}; + S(0, 2) = 3; + S(1, 2) = 11; + S(1, 1) = 1; + S(3, 0) = 5; + S(3, 1) = -3; + S(4, 1) = 1; + S(2, 2) = 4; + S(4, 3) = 2; + S(4, 4) = -2; + + CRSMatrix<int, uint8_t> A{S}; + + REQUIRE(A.numberOfRows() == S.numberOfRows()); } + + SECTION("matrix vector product (simple)") + { + SparseMatrixDescriptor<int, uint8_t> S{5}; + S(0, 2) = 3; + S(1, 2) = 11; + S(1, 1) = 1; + S(3, 0) = 5; + S(3, 1) = -3; + S(4, 1) = 1; + S(2, 2) = 4; + S(4, 3) = 2; + S(4, 4) = -2; + + CRSMatrix<int, uint8_t> A{S}; + + Vector<int> x{A.numberOfRows()}; + x = 0; + x[0] = 1; + + Vector<int> y = A * x; + REQUIRE(y[0] == 0); + REQUIRE(y[1] == 0); + REQUIRE(y[2] == 0); + REQUIRE(y[3] == 5); + REQUIRE(y[4] == 0); + + x = 0; + x[1] = 2; + + y = A * x; + REQUIRE(y[0] == 0); + REQUIRE(y[1] == 2); + REQUIRE(y[2] == 0); + REQUIRE(y[3] == -6); + REQUIRE(y[4] == 2); + + x = 0; + x[2] = -1; + + y = A * x; + REQUIRE(y[0] == -3); + REQUIRE(y[1] == -11); + REQUIRE(y[2] == -4); + REQUIRE(y[3] == 0); + REQUIRE(y[4] == 0); + + x = 0; + x[3] = 3; + + y = A * x; + REQUIRE(y[0] == 0); + REQUIRE(y[1] == 0); + REQUIRE(y[2] == 0); + REQUIRE(y[3] == 0); + REQUIRE(y[4] == 6); + + x = 0; + x[4] = 1; + + y = A * x; + REQUIRE(y[0] == 0); + REQUIRE(y[1] == 0); + REQUIRE(y[2] == 0); + REQUIRE(y[3] == 0); + REQUIRE(y[4] == -2); + } + + SECTION("matrix vector product (complet)") + { + SparseMatrixDescriptor<int, uint8_t> S{4}; + S(0, 0) = 1; + S(0, 1) = 2; + S(0, 2) = 3; + S(0, 3) = 4; + S(1, 0) = 5; + S(1, 1) = 6; + S(1, 2) = 7; + S(1, 3) = 8; + S(2, 0) = 9; + S(2, 1) = 10; + S(2, 2) = 11; + S(2, 3) = 12; + S(3, 0) = 13; + S(3, 1) = 14; + S(3, 2) = 15; + S(3, 3) = 16; + + CRSMatrix<int, uint8_t> A{S}; + + Vector<int> x{A.numberOfRows()}; + x[0] = 1; + x[1] = 2; + x[2] = 3; + x[3] = 4; + + Vector<int> y = A * x; + REQUIRE(y[0] == 30); + REQUIRE(y[1] == 70); + REQUIRE(y[2] == 110); + REQUIRE(y[3] == 150); + } + +#ifndef NDEBUG + SECTION("incompatible runtime matrix/vector product") + { + CRSMatrix<int, uint8_t> A{SparseMatrixDescriptor<int, uint8_t>{4}}; + Vector<int> x{2}; + REQUIRE_THROWS_AS(A * x, AssertError); + } +#endif // NDEBUG }