From 4ea04ec484914b5a30b6b63eb21e18ff1eae5532 Mon Sep 17 00:00:00 2001
From: labourasse <labourassee@gmail.com>
Date: Thu, 25 May 2023 08:56:09 +0200
Subject: [PATCH] add double contraction method (scalarProduct) and norm
 (l2Norm) for Tinymatrix

---
 src/algebra/TinyMatrix.hpp | 25 +++++++++++++++++++++++++
 tests/test_TinyMatrix.cpp  | 24 ++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/src/algebra/TinyMatrix.hpp b/src/algebra/TinyMatrix.hpp
index 8e0bbd4db..72dabc16a 100644
--- a/src/algebra/TinyMatrix.hpp
+++ b/src/algebra/TinyMatrix.hpp
@@ -567,4 +567,29 @@ inverse(const TinyMatrix<3, 3, T>& A)
   return A_cofactors_T *= 1. / determinent;
 }
 
+template <size_t M, size_t N, typename T>
+PUGS_INLINE T
+scalarProduct(const TinyMatrix<M, N, T>& A, const TinyMatrix<M, N, T>& B)
+{
+  static_assert(std::is_arithmetic<T>::value, "scalarProduct is not defined for non-arithmetic types");
+
+  T scalarProduct = 0;
+  for (size_t i = 0; i < M; ++i) {
+    for (size_t j = 0; j < N; ++j) {
+      scalarProduct += A(i, j) * B(i, j);
+    }
+  }
+  return scalarProduct;
+}
+
+template <size_t M, size_t N, typename T>
+PUGS_INLINE T
+l2Norm(const TinyMatrix<M, N, T>& A)
+{
+  static_assert(std::is_arithmetic<T>::value, "norm is not defined for non-arithmetic types");
+
+  T norm = std::sqrt(scalarProduct(A, A));
+  return norm;
+}
+
 #endif   // TINYMATRIX_HPP
diff --git a/tests/test_TinyMatrix.cpp b/tests/test_TinyMatrix.cpp
index 03d3bbe5d..0af7e5931 100644
--- a/tests/test_TinyMatrix.cpp
+++ b/tests/test_TinyMatrix.cpp
@@ -291,6 +291,20 @@ TEST_CASE("TinyMatrix", "[algebra]")
     REQUIRE(Catch::Detail::stringify(TinyMatrix<1, 1, int>(7)) == "[[7]]");
   }
 
+  SECTION("checking scalarProduct")
+  {
+    TinyMatrix<3, 4, int> B(0, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1);
+    //  TinyMatrix<3, 4, int> A(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
+
+    REQUIRE(scalarProduct(A, B) == -7);
+  }
+  SECTION("checking norm")
+  {
+    TinyMatrix<3, 4, int> B(0, 0, 0, -1, 1, -1, 1, -1, 1, -1, 1, -1);
+    //  TinyMatrix<3, 4, int> A(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
+
+    REQUIRE(norm(B) == 3);
+  }
 #ifndef NDEBUG
   SECTION("output with signaling NaN")
   {
@@ -341,5 +355,15 @@ TEST_CASE("TinyMatrix", "[algebra]")
       }
     }
   }
+  SECTION("checking for bad initialization")
+  {
+    TinyMatrix<3, 4, int> B;
+
+    for (size_t i = 0; i < B.numberOfRows(); ++i) {
+      for (size_t j = 0; j < B.numberOfColumns(); ++j) {
+        REQUIRE(B(i, j) == std::numeric_limits<int>::max() / 2);
+      }
+    }
+  }
 #endif   // NDEBUG
 }
-- 
GitLab