diff --git a/src/language/modules/MathModule.cpp b/src/language/modules/MathModule.cpp index aa961e4a0ae250581fad42b644f7d670ef4fdb32..fd38710a5b2232e12b54d69217880188ab292fc9 100644 --- a/src/language/modules/MathModule.cpp +++ b/src/language/modules/MathModule.cpp @@ -69,6 +69,19 @@ MathModule::MathModule() this->_addBuiltinFunction("round", std::make_shared<BuiltinFunctionEmbedder<int64_t(double)>>( [](double x) -> int64_t { return std::lround(x); })); + + this->_addBuiltinFunction("dot", + std::make_shared<BuiltinFunctionEmbedder<double(const TinyVector<1>, const TinyVector<1>)>>( + [](const TinyVector<1> x, const TinyVector<1> y) -> double { return (x, y); })); + + this->_addBuiltinFunction("dot", + std::make_shared<BuiltinFunctionEmbedder<double(const TinyVector<2>, const TinyVector<2>)>>( + [](const TinyVector<2> x, const TinyVector<2> y) -> double { return (x, y); })); + + this + ->_addBuiltinFunction("dot", + std::make_shared<BuiltinFunctionEmbedder<double(const TinyVector<3>&, const TinyVector<3>&)>>( + [](const TinyVector<3>& x, const TinyVector<3>& y) -> double { return (x, y); })); } void diff --git a/tests/test_BuiltinFunctionProcessor.cpp b/tests/test_BuiltinFunctionProcessor.cpp index dda3becc1897f2e659ed105256bc0ec09271e46d..c36399857d5a9b2235111d1ca105559001edf69d 100644 --- a/tests/test_BuiltinFunctionProcessor.cpp +++ b/tests/test_BuiltinFunctionProcessor.cpp @@ -278,6 +278,39 @@ let z:Z, z = round(-1.2); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "z", int64_t{-1}); } + { // dot + tested_function_set.insert("dot:R^1*R^1"); + std::string_view data = R"( +import math; +let x:R^1, x = -2; +let y:R^1, y = 4; +let s:R, s = dot(x,y); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", double{-2 * 4}); + } + + { // dot + tested_function_set.insert("dot:R^2*R^2"); + std::string_view data = R"( +import math; +let x:R^2, x = (-2, 3); +let y:R^2, y = (4, 3); +let s:R, s = dot(x,y); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", (TinyVector<2>{-2, 3}, TinyVector<2>{4, 3})); + } + + { // dot + tested_function_set.insert("dot:R^3*R^3"); + std::string_view data = R"( +import math; +let x:R^3, x = (-2, 3, 4); +let y:R^3, y = (4, 3, 5); +let s:R, s = dot(x,y); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", (TinyVector<3>{-2, 3, 4}, TinyVector<3>{4, 3, 5})); + } + MathModule math_module; bool missing_test = false; diff --git a/tests/test_MathModule.cpp b/tests/test_MathModule.cpp index e899190bf486ff605c1bd4f8af2ccc232cc806a8..fe87699887d6d587094bcd33999d403a20ce10b0 100644 --- a/tests/test_MathModule.cpp +++ b/tests/test_MathModule.cpp @@ -4,6 +4,8 @@ #include <language/modules/MathModule.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <set> + // clazy:excludeall=non-pod-global-static TEST_CASE("MathModule", "[language]") @@ -13,7 +15,7 @@ TEST_CASE("MathModule", "[language]") MathModule math_module; const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap(); - REQUIRE(name_builtin_function.size() == 22); + REQUIRE(name_builtin_function.size() == 25); SECTION("double -> double") { @@ -322,4 +324,61 @@ TEST_CASE("MathModule", "[language]") REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result)); } } + + SECTION("(R^d, R^d) -> double") + { + SECTION("dot:R^1*R^1") + { + TinyVector<1> arg0 = 3; + TinyVector<1> arg1 = 2; + + DataVariant arg0_variant = arg0; + DataVariant arg1_variant = arg1; + + auto i_function = name_builtin_function.find("dot:R^1*R^1"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + auto result = (arg0, arg1); + REQUIRE(std::get<decltype(result)>(result_variant) == result); + } + + SECTION("dot:R^2*R^2") + { + TinyVector<2> arg0{3, 2}; + TinyVector<2> arg1{-2, 5}; + + DataVariant arg0_variant = arg0; + DataVariant arg1_variant = arg1; + + auto i_function = name_builtin_function.find("dot:R^2*R^2"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + auto result = (arg0, arg1); + REQUIRE(std::get<decltype(result)>(result_variant) == result); + } + + SECTION("dot:R^3*R^3") + { + TinyVector<3> arg0{3, 2, 4}; + TinyVector<3> arg1{-2, 5, 2}; + + DataVariant arg0_variant = arg0; + DataVariant arg1_variant = arg1; + + auto i_function = name_builtin_function.find("dot:R^3*R^3"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + auto result = (arg0, arg1); + REQUIRE(std::get<decltype(result)>(result_variant) == result); + } + } }