#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <utils/PugsAssert.hpp>

#include <algebra/DenseMatrix.hpp>
#include <algebra/Vector.hpp>

// Instantiate to ensure full coverage is performed
template class DenseMatrix<int>;

// clazy:excludeall=non-pod-global-static

// Unit tests for DenseMatrix: sizing, element access, fill, copies,
// scalar/matrix arithmetic, transposition and matrix products.
TEST_CASE("DenseMatrix", "[algebra]")
{
  SECTION("size")
  {
    DenseMatrix<int> A{2, 3};
    REQUIRE(A.nbRows() == 2);
    REQUIRE(A.nbColumns() == 3);
  }

  SECTION("write access")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    REQUIRE(A(0, 0) == 0);
    REQUIRE(A(0, 1) == 1);
    REQUIRE(A(0, 2) == 2);
    REQUIRE(A(1, 0) == 3);
    REQUIRE(A(1, 1) == 4);
    REQUIRE(A(1, 2) == 5);

    // The written values must also be readable through a const-element view
    DenseMatrix<const int> const_A = A;
    REQUIRE(const_A(0, 0) == 0);
    REQUIRE(const_A(0, 1) == 1);
    REQUIRE(const_A(0, 2) == 2);
    REQUIRE(const_A(1, 0) == 3);
    REQUIRE(const_A(1, 1) == 4);
    REQUIRE(const_A(1, 2) == 5);
  }

  SECTION("fill")
  {
    DenseMatrix<int> A{2, 3};
    A.fill(2);

    REQUIRE(A(0, 0) == 2);
    REQUIRE(A(0, 1) == 2);
    REQUIRE(A(0, 2) == 2);
    REQUIRE(A(1, 0) == 2);
    REQUIRE(A(1, 1) == 2);
    REQUIRE(A(1, 2) == 2);
  }

  SECTION("copy constructor (shallow)")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    const DenseMatrix<int> B = A;
    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 1);
    REQUIRE(B(0, 2) == 2);
    REQUIRE(B(1, 0) == 3);
    REQUIRE(B(1, 1) == 4);
    REQUIRE(B(1, 2) == 5);
  }

  SECTION("copy (deep)")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    const DenseMatrix<int> B = copy(A);

    // Overwrite A after the copy: B must keep the original values
    A(0, 0) = 10;
    A(0, 1) = 11;
    A(0, 2) = 12;
    A(1, 0) = 13;
    A(1, 1) = 14;
    A(1, 2) = 15;

    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 1);
    REQUIRE(B(0, 2) == 2);
    REQUIRE(B(1, 0) == 3);
    REQUIRE(B(1, 1) == 4);
    REQUIRE(B(1, 2) == 5);

    REQUIRE(A(0, 0) == 10);
    REQUIRE(A(0, 1) == 11);
    REQUIRE(A(0, 2) == 12);
    REQUIRE(A(1, 0) == 13);
    REQUIRE(A(1, 1) == 14);
    REQUIRE(A(1, 2) == 15);
  }

  SECTION("self scalar multiplication")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    A *= 2;
    REQUIRE(A(0, 0) == 0);
    REQUIRE(A(0, 1) == 2);
    REQUIRE(A(0, 2) == 4);
    REQUIRE(A(1, 0) == 6);
    REQUIRE(A(1, 1) == 8);
    REQUIRE(A(1, 2) == 10);
  }

  SECTION("left scalar multiplication")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    const DenseMatrix<int> B = 2 * A;
    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 2);
    REQUIRE(B(0, 2) == 4);
    REQUIRE(B(1, 0) == 6);
    REQUIRE(B(1, 1) == 8);
    REQUIRE(B(1, 2) == 10);
  }

  SECTION("product matrix vector")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 6;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    Vector<int> x{3};
    x[0] = 7;
    x[1] = 3;
    x[2] = 4;

    // y[0] = 6*7 + 1*3 + 2*4 = 53 ; y[1] = 3*7 + 4*3 + 5*4 = 53
    Vector y = A * x;
    REQUIRE(y[0] == 53);
    REQUIRE(y[1] == 53);
  }

  SECTION("self scalar division")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 2;
    A(0, 2) = 4;
    A(1, 0) = 6;
    A(1, 1) = 8;
    A(1, 2) = 10;

    A /= 2;
    REQUIRE(A(0, 0) == 0);
    REQUIRE(A(0, 1) == 1);
    REQUIRE(A(0, 2) == 2);
    REQUIRE(A(1, 0) == 3);
    REQUIRE(A(1, 1) == 4);
    REQUIRE(A(1, 2) == 5);
  }

  SECTION("self minus")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 5;
    B(0, 1) = 6;
    B(0, 2) = 4;
    B(1, 0) = 2;
    B(1, 1) = 1;
    B(1, 2) = 3;

    A -= B;
    REQUIRE(A(0, 0) == -5);
    REQUIRE(A(0, 1) == -5);
    REQUIRE(A(0, 2) == -2);
    REQUIRE(A(1, 0) == 1);
    REQUIRE(A(1, 1) == 3);
    REQUIRE(A(1, 2) == 2);
  }

  SECTION("self sum")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 5;
    B(0, 1) = 6;
    B(0, 2) = 4;
    B(1, 0) = 2;
    B(1, 1) = 1;
    B(1, 2) = 3;

    A += B;
    REQUIRE(A(0, 0) == 5);
    REQUIRE(A(0, 1) == 7);
    REQUIRE(A(0, 2) == 6);
    REQUIRE(A(1, 0) == 5);
    REQUIRE(A(1, 1) == 5);
    REQUIRE(A(1, 2) == 8);
  }

  SECTION("sum")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 6;
    A(0, 1) = 5;
    A(0, 2) = 4;
    A(1, 0) = 3;
    A(1, 1) = 2;
    A(1, 2) = 1;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 0;
    B(0, 1) = 1;
    B(0, 2) = 2;
    B(1, 0) = 3;
    B(1, 1) = 4;
    B(1, 2) = 5;

    // Entries of A and B were chosen so that every entry of A + B is 6
    DenseMatrix C = A + B;
    REQUIRE(C(0, 0) == 6);
    REQUIRE(C(0, 1) == 6);
    REQUIRE(C(0, 2) == 6);
    REQUIRE(C(1, 0) == 6);
    REQUIRE(C(1, 1) == 6);
    REQUIRE(C(1, 2) == 6);
  }

  SECTION("difference")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 6;
    A(0, 1) = 5;
    A(0, 2) = 4;
    A(1, 0) = 3;
    A(1, 1) = 2;
    A(1, 2) = 1;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 0;
    B(0, 1) = 1;
    B(0, 2) = 2;
    B(1, 0) = 3;
    B(1, 1) = 4;
    B(1, 2) = 5;

    DenseMatrix C = A - B;
    REQUIRE(C(0, 0) == 6);
    REQUIRE(C(0, 1) == 4);
    REQUIRE(C(0, 2) == 2);
    REQUIRE(C(1, 0) == 0);
    REQUIRE(C(1, 1) == -2);
    REQUIRE(C(1, 2) == -4);
  }

  SECTION("transpose")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    // A is 2x3, so its transpose is 3x2 with B(j, i) == A(i, j)
    DenseMatrix B = transpose(A);
    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 3);
    REQUIRE(B(1, 0) == 1);
    REQUIRE(B(1, 1) == 4);
    REQUIRE(B(2, 0) == 2);
    REQUIRE(B(2, 1) == 5);
  }

  // Previously mislabelled "product matrix vector", which duplicated the
  // name of the genuine matrix-vector section above.
  SECTION("product matrix matrix")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 1;
    A(0, 1) = 2;
    A(0, 2) = 3;
    A(1, 0) = 4;
    A(1, 1) = 5;
    A(1, 2) = 6;

    DenseMatrix<int> B{3, 2};
    B(0, 0) = 2;
    B(0, 1) = 8;
    B(1, 0) = 4;
    B(1, 1) = 9;
    B(2, 0) = 6;
    B(2, 1) = 10;

    // (2x3) * (3x2) -> 2x2:
    //   C(0,0) = 1*2 + 2*4 + 3*6  = 28   C(0,1) = 1*8 + 2*9 + 3*10 = 56
    //   C(1,0) = 4*2 + 5*4 + 6*6  = 64   C(1,1) = 4*8 + 5*9 + 6*10 = 137
    DenseMatrix C = A * B;
    REQUIRE(C(0, 0) == 28);
    REQUIRE(C(0, 1) == 56);
    REQUIRE(C(1, 0) == 64);
    REQUIRE(C(1, 1) == 137);
  }
}