From ba39b67bc0fa93424ff6e89db6bb829925c95e5f Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 15 Nov 2024 00:49:38 +0100
Subject: [PATCH] Add doubleDot:R^dxd*R^dxd -> R to the language

---
 src/language/modules/MathModule.cpp     | 12 +++++
 tests/test_BuiltinFunctionProcessor.cpp | 42 ++++++++++++++--
 tests/test_MathModule.cpp               | 65 ++++++++++++++++++++++++-
 3 files changed, 115 insertions(+), 4 deletions(-)

diff --git a/src/language/modules/MathModule.cpp b/src/language/modules/MathModule.cpp
index af1dfeafb..9d0a44a7c 100644
--- a/src/language/modules/MathModule.cpp
+++ b/src/language/modules/MathModule.cpp
@@ -71,6 +71,18 @@ MathModule::MathModule()
                               return dot(x, y);
                             }));
 
+  this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<1> A, const TinyMatrix<1> B) -> double {
+                              return doubleDot(A, B);
+                            }));
+
+  this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<2> A, const TinyMatrix<2> B) -> double {
+                              return doubleDot(A, B);
+                            }));
+
+  this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<3>& A, const TinyMatrix<3>& B) -> double {
+                              return doubleDot(A, B);
+                            }));
+
   this->_addBuiltinFunction("tensorProduct",
                             std::function([](const TinyVector<1> x, const TinyVector<1> y) -> TinyMatrix<1> {
                               return tensorProduct(x, y);
diff --git a/tests/test_BuiltinFunctionProcessor.cpp b/tests/test_BuiltinFunctionProcessor.cpp
index f600bfa44..5556eca19 100644
--- a/tests/test_BuiltinFunctionProcessor.cpp
+++ b/tests/test_BuiltinFunctionProcessor.cpp
@@ -362,7 +362,43 @@ let s:R, s = dot(x,y);
       CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", dot(TinyVector<3>{-2, 3, 4}, TinyVector<3>{4, 3, 5}));
     }
 
-    {   // dot
+    {   // double-dot
+      tested_function_set.insert("doubleDot:R^1x1*R^1x1");
+      std::string_view data = R"(
+import math;
+let A1:R^1x1, A1 = [[-2]];
+let A2:R^1x1, A2 = [[4]];
+let s:R, s = doubleDot(A1,A2);
+)";
+      CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", double{-2 * 4});
+    }
+
+    {   // double-dot
+      tested_function_set.insert("doubleDot:R^2x2*R^2x2");
+      std::string_view data = R"(
+import math;
+let A1:R^2x2, A1 = [[-2, 3],[5,-2]];
+let A2:R^2x2, A2 = [[4, 3],[7,3]];
+let s:R, s = doubleDot(A1,A2);
+)";
+      CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s",
+                                               doubleDot(TinyMatrix<2>{-2, 3, 5, -2}, TinyMatrix<2>{4, 3, 7, 3}));
+    }
+
+    {   // double-dot
+      tested_function_set.insert("doubleDot:R^3x3*R^3x3");
+      std::string_view data = R"(
+import math;
+let A1:R^3x3, A1 = [[-2, 3, 4],[1,2,3],[6,3,2]];
+let A2:R^3x3, A2 = [[4, 3, 5],[2, 3, 1],[2, 6, 1]];
+let s:R, s = doubleDot(A1,A2);
+)";
+      CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s",
+                                               doubleDot(TinyMatrix<3>{-2, 3, 4, 1, 2, 3, 6, 3, 2},
+                                                         TinyMatrix<3>{4, 3, 5, 2, 3, 1, 2, 6, 1}));
+    }
+
+    {   // tensor product
       tested_function_set.insert("tensorProduct:R^1*R^1");
       std::string_view data = R"(
 import math;
@@ -373,7 +409,7 @@ let s:R^1x1, s = tensorProduct(x,y);
       CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", TinyMatrix<1>{-2 * 4});
     }
 
-    {   // dot
+    {   // tensor product
       tested_function_set.insert("tensorProduct:R^2*R^2");
       std::string_view data = R"(
 import math;
@@ -384,7 +420,7 @@ let s:R^2x2, s = tensorProduct(x,y);
       CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", tensorProduct(TinyVector<2>{-2, 3}, TinyVector<2>{4, 3}));
     }
 
-    {   // dot
+    {   // tensor product
       tested_function_set.insert("tensorProduct:R^3*R^3");
       std::string_view data = R"(
 import math;
diff --git a/tests/test_MathModule.cpp b/tests/test_MathModule.cpp
index 89a480e1a..54befcd82 100644
--- a/tests/test_MathModule.cpp
+++ b/tests/test_MathModule.cpp
@@ -13,7 +13,7 @@ TEST_CASE("MathModule", "[language]")
   MathModule math_module;
   const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap();
 
-  REQUIRE(name_builtin_function.size() == 45);
+  REQUIRE(name_builtin_function.size() == 48);
 
   SECTION("Z -> N")
   {
@@ -458,6 +458,69 @@ TEST_CASE("MathModule", "[language]")
     }
   }
 
+  SECTION("(R^dxd, R^dxd) -> double")
+  {
+    SECTION("doubleDot:R^1x1*R^1x1")
+    {
+      TinyMatrix<1> arg0{3};
+      TinyMatrix<1> arg1{2};
+
+      DataVariant arg0_variant = arg0;
+      DataVariant arg1_variant = arg1;
+
+      auto i_function = name_builtin_function.find("doubleDot:R^1x1*R^1x1");
+      REQUIRE(i_function != name_builtin_function.end());
+
+      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
+      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});
+
+      const double result = doubleDot(arg0, arg1);
+      REQUIRE(std::get<double>(result_variant) == result);
+    }
+
+    SECTION("doubleDot:R^2x2*R^2x2")
+    {
+      TinyMatrix<2> arg0{+3, +2,   //
+                         -1, +4};
+      TinyMatrix<2> arg1{-2, -5,   //
+                         +7, +1.3};
+
+      DataVariant arg0_variant = arg0;
+      DataVariant arg1_variant = arg1;
+
+      auto i_function = name_builtin_function.find("doubleDot:R^2x2*R^2x2");
+      REQUIRE(i_function != name_builtin_function.end());
+
+      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
+      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});
+
+      const double result = doubleDot(arg0, arg1);
+      REQUIRE(std::get<double>(result_variant) == result);
+    }
+
+    SECTION("doubleDot:R^3x3*R^3x3")
+    {
+      TinyMatrix<3> arg0{+3, +2, +4,   //
+                         -1, +3, -6,   //
+                         +2, +5, +1};
+      TinyMatrix<3> arg1{-2, +5, +2,   //
+                         +1, +7, -2,   //
+                         +7, -1, +3};
+
+      DataVariant arg0_variant = arg0;
+      DataVariant arg1_variant = arg1;
+
+      auto i_function = name_builtin_function.find("doubleDot:R^3x3*R^3x3");
+      REQUIRE(i_function != name_builtin_function.end());
+
+      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
+      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});
+
+      const double result = doubleDot(arg0, arg1);
+      REQUIRE(std::get<double>(result_variant) == result);
+    }
+  }
+
   SECTION("(R^d, R^d) -> R^dxd")
   {
     SECTION("tensorProduct:R^1*R^1")
-- 
GitLab