Skip to content
Snippets Groups Projects
Commit ba39b67b authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add doubleDot:R^dxd*R^dxd -> R to the language

parent 42fd7137
No related branches found
No related tags found
1 merge request!198Add TinyMatrix's double-dot product
...@@ -71,6 +71,18 @@ MathModule::MathModule() ...@@ -71,6 +71,18 @@ MathModule::MathModule()
return dot(x, y); return dot(x, y);
})); }));
this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<1> A, const TinyMatrix<1> B) -> double {
return doubleDot(A, B);
}));
this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<2> A, const TinyMatrix<2> B) -> double {
return doubleDot(A, B);
}));
this->_addBuiltinFunction("doubleDot", std::function([](const TinyMatrix<3>& A, const TinyMatrix<3>& B) -> double {
return doubleDot(A, B);
}));
this->_addBuiltinFunction("tensorProduct", this->_addBuiltinFunction("tensorProduct",
std::function([](const TinyVector<1> x, const TinyVector<1> y) -> TinyMatrix<1> { std::function([](const TinyVector<1> x, const TinyVector<1> y) -> TinyMatrix<1> {
return tensorProduct(x, y); return tensorProduct(x, y);
......
...@@ -362,7 +362,43 @@ let s:R, s = dot(x,y); ...@@ -362,7 +362,43 @@ let s:R, s = dot(x,y);
CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", dot(TinyVector<3>{-2, 3, 4}, TinyVector<3>{4, 3, 5})); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", dot(TinyVector<3>{-2, 3, 4}, TinyVector<3>{4, 3, 5}));
} }
{ // dot { // double-dot
tested_function_set.insert("doubleDot:R^1x1*R^1x1");
std::string_view data = R"(
import math;
let A1:R^1x1, A1 = [[-2]];
let A2:R^1x1, A2 = [[4]];
let s:R, s = doubleDot(A1,A2);
)";
CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", double{-2 * 4});
}
{ // double-dot
tested_function_set.insert("doubleDot:R^2x2*R^2x2");
std::string_view data = R"(
import math;
let A1:R^2x2, A1 = [[-2, 3],[5,-2]];
let A2:R^2x2, A2 = [[4, 3],[7,3]];
let s:R, s = doubleDot(A1,A2);
)";
CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s",
doubleDot(TinyMatrix<2>{-2, 3, 5, -2}, TinyMatrix<2>{4, 3, 7, 3}));
}
{ // double-dot
tested_function_set.insert("doubleDot:R^3x3*R^3x3");
std::string_view data = R"(
import math;
let A1:R^3x3, A1 = [[-2, 3, 4],[1,2,3],[6,3,2]];
let A2:R^3x3, A2 = [[4, 3, 5],[2, 3, 1],[2, 6, 1]];
let s:R, s = doubleDot(A1,A2);
)";
CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s",
doubleDot(TinyMatrix<3>{-2, 3, 4, 1, 2, 3, 6, 3, 2},
TinyMatrix<3>{4, 3, 5, 2, 3, 1, 2, 6, 1}));
}
{ // tensor product
tested_function_set.insert("tensorProduct:R^1*R^1"); tested_function_set.insert("tensorProduct:R^1*R^1");
std::string_view data = R"( std::string_view data = R"(
import math; import math;
...@@ -373,7 +409,7 @@ let s:R^1x1, s = tensorProduct(x,y); ...@@ -373,7 +409,7 @@ let s:R^1x1, s = tensorProduct(x,y);
CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", TinyMatrix<1>{-2 * 4}); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", TinyMatrix<1>{-2 * 4});
} }
{ // dot { // tensor product
tested_function_set.insert("tensorProduct:R^2*R^2"); tested_function_set.insert("tensorProduct:R^2*R^2");
std::string_view data = R"( std::string_view data = R"(
import math; import math;
...@@ -384,7 +420,7 @@ let s:R^2x2, s = tensorProduct(x,y); ...@@ -384,7 +420,7 @@ let s:R^2x2, s = tensorProduct(x,y);
CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", tensorProduct(TinyVector<2>{-2, 3}, TinyVector<2>{4, 3})); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "s", tensorProduct(TinyVector<2>{-2, 3}, TinyVector<2>{4, 3}));
} }
{ // dot { // tensor product
tested_function_set.insert("tensorProduct:R^3*R^3"); tested_function_set.insert("tensorProduct:R^3*R^3");
std::string_view data = R"( std::string_view data = R"(
import math; import math;
......
...@@ -13,7 +13,7 @@ TEST_CASE("MathModule", "[language]") ...@@ -13,7 +13,7 @@ TEST_CASE("MathModule", "[language]")
MathModule math_module; MathModule math_module;
const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap(); const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap();
REQUIRE(name_builtin_function.size() == 45); REQUIRE(name_builtin_function.size() == 48);
SECTION("Z -> N") SECTION("Z -> N")
{ {
...@@ -458,6 +458,69 @@ TEST_CASE("MathModule", "[language]") ...@@ -458,6 +458,69 @@ TEST_CASE("MathModule", "[language]")
} }
} }
SECTION("(R^dxd, R^dxd) -> double")
{
SECTION("doubleDot:R^1x1*R^1x1")
{
TinyMatrix<1> arg0{3};
TinyMatrix<1> arg1{2};
DataVariant arg0_variant = arg0;
DataVariant arg1_variant = arg1;
auto i_function = name_builtin_function.find("doubleDot:R^1x1*R^1x1");
REQUIRE(i_function != name_builtin_function.end());
IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant});
const double result = doubleDot(arg0, arg1);
REQUIRE(std::get<double>(result_variant) == result);
}
SECTION("doubleDot:R^2x2*R^2x2")
{
TinyMatrix<2> arg0{+3, +2, //
-1, +4};
TinyMatrix<2> arg1{-2, -5, //
+7, +1.3};
DataVariant arg0_variant = arg0;
DataVariant arg1_variant = arg1;
auto i_function = name_builtin_function.find("doubleDot:R^2x2*R^2x2");
REQUIRE(i_function != name_builtin_function.end());
IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant});
const double result = doubleDot(arg0, arg1);
REQUIRE(std::get<double>(result_variant) == result);
}
SECTION("doubleDot:R^3x3*R^3x3")
{
TinyMatrix<3> arg0{+3, +2, +4, //
-1, +3, -6, //
+2, +5, +1};
TinyMatrix<3> arg1{-2, +5, +2, //
+1, +7, -2, //
+7, -1, +3};
DataVariant arg0_variant = arg0;
DataVariant arg1_variant = arg1;
auto i_function = name_builtin_function.find("doubleDot:R^3x3*R^3x3");
REQUIRE(i_function != name_builtin_function.end());
IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant});
const double result = doubleDot(arg0, arg1);
REQUIRE(std::get<double>(result_variant) == result);
}
}
SECTION("(R^d, R^d) -> R^dxd") SECTION("(R^d, R^d) -> R^dxd")
{ {
SECTION("tensorProduct:R^1*R^1") SECTION("tensorProduct:R^1*R^1")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment