diff --git a/src/algebra/TinyMatrix.hpp b/src/algebra/TinyMatrix.hpp index 94590ff656784d018e9ca57991674da773d64568..12122a8749384abea699a61537b9ffc437c5629e 100644 --- a/src/algebra/TinyMatrix.hpp +++ b/src/algebra/TinyMatrix.hpp @@ -504,6 +504,21 @@ trace(const TinyMatrix<N, N, T>& A) return t; } +template <size_t N, typename T> +PUGS_INLINE T +frobeniusNorm(const TinyMatrix<N, N, T>& A) +{ + static_assert(std::is_arithmetic<T>::value, "norm is not defined for non-arithmetic types"); + + T t = 0; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < N; ++j) { + t += A(i, j) * A(i, j); + } + } + return std::sqrt(t); +} + template <size_t N, typename T> PUGS_INLINE constexpr TinyMatrix<N, N, T> inverse(const TinyMatrix<N, N, T>& A); diff --git a/tests/test_TinyMatrix.cpp b/tests/test_TinyMatrix.cpp index 80159a021b54ebb162d339be2bbd5195aec815a0..f6d132bb7df158a0c2d510e798c1f5dd0ad4d9d7 100644 --- a/tests/test_TinyMatrix.cpp +++ b/tests/test_TinyMatrix.cpp @@ -209,6 +209,16 @@ TEST_CASE("TinyMatrix", "[algebra]") REQUIRE(trace(TinyMatrix<4>(1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 2, 0, 0, 2, 2)) == 1 + 0 + 1 + 2); } + SECTION("checking for norm calculations") + { + REQUIRE(frobeniusNorm(TinyMatrix<1, 1, int>(6)) == 6); + REQUIRE(frobeniusNorm(TinyMatrix<2, 2>(5, 1, -3, 6)) == Catch::Approx(std::sqrt(25 + 1 + 9 + 36)).epsilon(1E-14)); + REQUIRE(frobeniusNorm(TinyMatrix<4>(1, 2.3, 7, -6.2, 3, 4, 9, 1, 4.1, 5, 2, -3, 2, 27, 3, 17.5)) == + Catch::Approx(std::sqrt(1 + 2.3 * 2.3 + 7 * 7 + 6.2 * 6.2 + 9 + 16 + 81 + 1 + 4.1 * 4.1 + 25 + 4 + 9 + 4 + + 27 * 27 + 9 + 17.5 * 17.5)) + .epsilon(1E-14)); + } + SECTION("checking for inverse calculations") { {