From ba39b67bc0fa93424ff6e89db6bb829925c95e5f Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Fri, 15 Nov 2024 00:49:38 +0100 Subject: [PATCH] Add doubleDot:R^dxd*R^dxd -> R to the language --- src/language/modules/MathModule.cpp | 12 +++++ tests/test_BuiltinFunctionProcessor.cpp | 42 ++++++++++++++-- tests/test_MathModule.cpp | 65 ++++++++++++++++++++++++- 3 files changed, 115 insertions(+), 4 deletions(-) diff --git a/src/language/modules/MathModule.cpp b/src/language/modules/MathModule.cpp index af1dfeafb..9d0a44a7c 100644 --- a/src/language/modules/MathModule.cpp +++ b/src/language/modules/MathModule.cpp @@ -71,6 +71,18 @@ MathModule::MathModule() return dot(x, y); })); + this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<1> A, const TinyMatrix<1> B) -> double { + return doubleDot(A, B); + })); + + this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<2> A, const TinyMatrix<2> B) -> double { + return doubleDot(A, B); + })); + + this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<3>& A, const TinyMatrix<3>& B) -> double { + return doubleDot(A, B); + })); + this->_addBuiltinFunction("tensorProduct", std::function([](const TinyVector<1> x, const TinyVector<1> y) -> TinyMatrix<1> { return tensorProduct(x, y); diff --git a/tests/test_BuiltinFunctionProcessor.cpp b/tests/test_BuiltinFunctionProcessor.cpp index f600bfa44..5556eca19 100644 --- a/tests/test_BuiltinFunctionProcessor.cpp +++ b/tests/test_BuiltinFunctionProcessor.cpp @@ -362,7 +362,43 @@ let s:R, s = dot(x,y); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", dot(TinyVector<3>{-2, 3, 4}, TinyVector<3>{4, 3, 5})); } - { // dot + { // double-dot + tested_function_set.insert("doubleDot:R^1x1*R^1x1"); + std::string_view data = R"( +import math; +let A1:R^1x1, A1 = [[-2]]; +let A2:R^1x1, A2 = [[4]]; +let s:R, s = doubleDot(A1,A2); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", double{-2 * 4}); + } + + { // double-dot + tested_function_set.insert("doubleDot:R^2x2*R^2x2"); + std::string_view data = R"( +import math; +let A1:R^2x2, A1 = [[-2, 3],[5,-2]]; +let A2:R^2x2, A2 = [[4, 3],[7,3]]; +let s:R, s = doubleDot(A1,A2); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", + doubleDot(TinyMatrix<2>{-2, 3, 5, -2}, TinyMatrix<2>{4, 3, 7, 3})); + } + + { // double-dot + tested_function_set.insert("doubleDot:R^3x3*R^3x3"); + std::string_view data = R"( +import math; +let A1:R^3x3, A1 = [[-2, 3, 4],[1,2,3],[6,3,2]]; +let A2:R^3x3, A2 = [[4, 3, 5],[2, 3, 1],[2, 6, 1]]; +let s:R, s = doubleDot(A1,A2); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", + doubleDot(TinyMatrix<3>{-2, 3, 4, 1, 2, 3, 6, 3, 2}, + TinyMatrix<3>{4, 3, 5, 2, 3, 1, 2, 6, 1})); + } + + { // tensor product tested_function_set.insert("tensorProduct:R^1*R^1"); std::string_view data = R"( import math; @@ -373,7 +409,7 @@ let s:R^1x1, s = tensorProduct(x,y); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", TinyMatrix<1>{-2 * 4}); } - { // dot + { // tensor product tested_function_set.insert("tensorProduct:R^2*R^2"); std::string_view data = R"( import math; @@ -384,7 +420,7 @@ let s:R^2x2, s = tensorProduct(x,y); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", tensorProduct(TinyVector<2>{-2, 3}, TinyVector<2>{4, 3})); } - { // dot + { // tensor product tested_function_set.insert("tensorProduct:R^3*R^3"); std::string_view data = R"( import math; diff --git a/tests/test_MathModule.cpp b/tests/test_MathModule.cpp index 89a480e1a..54befcd82 100644 --- a/tests/test_MathModule.cpp +++ b/tests/test_MathModule.cpp @@ -13,7 +13,7 @@ TEST_CASE("MathModule", "[language]") MathModule math_module; const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap(); - REQUIRE(name_builtin_function.size() == 45); + REQUIRE(name_builtin_function.size() == 48); SECTION("Z -> N") { @@ -458,6 +458,69 @@ TEST_CASE("MathModule", "[language]") } } + SECTION("(R^dxd, R^dxd) -> double") + { + SECTION("doubleDot:R^1x1*R^1x1") + { + TinyMatrix<1> arg0{3}; + TinyMatrix<1> arg1{2}; + + DataVariant arg0_variant = arg0; + DataVariant arg1_variant = arg1; + + auto i_function = name_builtin_function.find("doubleDot:R^1x1*R^1x1"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + const double result = doubleDot(arg0, arg1); + REQUIRE(std::get<double>(result_variant) == result); + } + + SECTION("doubleDot:R^2x2*R^2x2") + { + TinyMatrix<2> arg0{+3, +2, // + -1, +4}; + TinyMatrix<2> arg1{-2, -5, // + +7, +1.3}; + + DataVariant arg0_variant = arg0; + DataVariant arg1_variant = arg1; + + auto i_function = name_builtin_function.find("doubleDot:R^2x2*R^2x2"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + const double result = doubleDot(arg0, arg1); + REQUIRE(std::get<double>(result_variant) == result); + } + + SECTION("doubleDot:R^3x3*R^3x3") + { + TinyMatrix<3> arg0{+3, +2, +4, // + -1, +3, -6, // + +2, +5, +1}; + TinyMatrix<3> arg1{-2, +5, +2, // + +1, +7, -2, // + +7, -1, +3}; + + DataVariant arg0_variant = arg0; + DataVariant arg1_variant = arg1; + + auto i_function = name_builtin_function.find("doubleDot:R^3x3*R^3x3"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + const double result = doubleDot(arg0, arg1); + REQUIRE(std::get<double>(result_variant) == result); + } + } + SECTION("(R^d, R^d) -> R^dxd") { SECTION("tensorProduct:R^1*R^1") -- GitLab