diff --git a/src/algebra/TinyMatrix.hpp b/src/algebra/TinyMatrix.hpp index 8e0bbd4db39f20977afa3725177f728d7094710d..72dabc16acf19c4217e007cde8d0b10060234344 100644 --- a/src/algebra/TinyMatrix.hpp +++ b/src/algebra/TinyMatrix.hpp @@ -567,4 +567,29 @@ inverse(const TinyMatrix<3, 3, T>& A) return A_cofactors_T *= 1. / determinent; } +template <size_t M, size_t N, typename T> +PUGS_INLINE T +scalarProduct(const TinyMatrix<M, N, T>& A, const TinyMatrix<M, N, T>& B) +{ + static_assert(std::is_arithmetic<T>::value, "scalarProduct is not defined for non-arithmetic types"); + + T scalarProduct = 0; + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + scalarProduct += A(i, j) * B(i, j); + } + } + return scalarProduct; +} + +template <size_t M, size_t N, typename T> +PUGS_INLINE T +l2Norm(const TinyMatrix<M, N, T>& A) +{ + static_assert(std::is_arithmetic<T>::value, "norm is not defined for non-arithmetic types"); + + T norm = std::sqrt(scalarProduct(A, A)); + return norm; +} + #endif // TINYMATRIX_HPP diff --git a/tests/test_TinyMatrix.cpp b/tests/test_TinyMatrix.cpp index 03d3bbe5dc0ad69655cda5d3ae764d8e770b0d60..0af7e59313e28e052c90bf2a704e138953248eb1 100644 --- a/tests/test_TinyMatrix.cpp +++ b/tests/test_TinyMatrix.cpp @@ -291,6 +291,20 @@ TEST_CASE("TinyMatrix", "[algebra]") REQUIRE(Catch::Detail::stringify(TinyMatrix<1, 1, int>(7)) == "[[7]]"); } + SECTION("checking scalarProduct") + { + TinyMatrix<3, 4, int> B(0, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1); + // TinyMatrix<3, 4, int> A(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12); + + REQUIRE(scalarProduct(A, B) == -7); + } + SECTION("checking norm") + { + TinyMatrix<3, 4, int> B(0, 0, 0, -1, 1, -1, 1, -1, 1, -1, 1, -1); + // TinyMatrix<3, 4, int> A(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12); + + REQUIRE(norm(B) == 3); + } #ifndef NDEBUG SECTION("output with signaling NaN") { @@ -341,5 +355,15 @@ TEST_CASE("TinyMatrix", "[algebra]") } } } + SECTION("checking for bad initialization") + { + TinyMatrix<3, 4, int> B; + + for (size_t i = 0; i < B.numberOfRows(); ++i) { + for (size_t j = 0; j < B.numberOfColumns(); ++j) { + REQUIRE(B(i, j) == std::numeric_limits<int>::max() / 2); + } + } + } #endif // NDEBUG }