diff --git a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp index 62019550d16e36056ff2447b7797590952b6bfd8..248511e6c91c2eec08a94e87faa8179c06a0bf80 100644 --- a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -112,6 +112,84 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } }; + auto get_function_argument_converter_for_matrix = + [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { + using ParameterT = std::decay_t<decltype(parameter_v)>; + + if constexpr (std::is_same_v<ParameterT, TinyMatrix<1>>) { + switch (argument_node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((argument_node_sub_data_type.m_data_type.nbRows() == 1) and + (argument_node_sub_data_type.m_data_type.nbColumns() == 1)) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::bool_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, bool>>(argument_number); + } + case ASTNodeDataType::int_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, int64_t>>(argument_number); + } + case ASTNodeDataType::unsigned_int_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, uint64_t>>(argument_number); + } + case ASTNodeDataType::double_t: { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, double>>(argument_number); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } else { + switch (argument_node_sub_data_type.m_data_type) { + case ASTNodeDataType::matrix_t: { + if ((argument_node_sub_data_type.m_data_type.nbRows() == parameter_v.nbRows()) and + (argument_node_sub_data_type.m_data_type.nbColumns() == parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::list_t: { + if (argument_node_sub_data_type.m_parent_node.children.size() == + (parameter_v.nbRows() * parameter_v.nbColumns())) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ParameterT>>(argument_number); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if (argument_node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if (std::stoi(argument_node_sub_data_type.m_parent_node.string()) == 0) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(argument_number); + } + } + [[fallthrough]]; + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: invalid argument type", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } + }; + auto get_function_argument_to_string_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { return std::make_unique<FunctionArgumentToStringConverter>(argument_number); }; @@ -195,6 +273,25 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } // LCOV_EXCL_STOP } + } + case ASTNodeDataType::matrix_t: { + Assert(arg_data_type.nbRows() == arg_data_type.nbColumns()); + switch (arg_data_type.nbRows()) { + case 1: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<1>>>(argument_number); + } + case 2: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<2>>>(argument_number); + } + case 3: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyMatrix<3>>>(argument_number); + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of vector"); + } + // LCOV_EXCL_STOP + } } // LCOV_EXCL_START default: { @@ -254,6 +351,26 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(parameter_type.nbRows() == parameter_type.nbColumns()); + switch (parameter_type.nbRows()) { + case 1: { + return get_function_argument_converter_for_matrix(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_converter_for_matrix(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_converter_for_matrix(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: undefined parameter type for function", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_argument_to_string_converter(); } @@ -300,6 +417,27 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData // LCOV_EXCL_STOP } } + case ASTNodeDataType::matrix_t: { + Assert(parameter_type.contentType().nbRows() == parameter_type.contentType().nbColumns()); + switch (parameter_type.contentType().nbRows()) { + case 1: { + return get_function_argument_to_tuple_converter(TinyMatrix<1>{}); + } + case 2: { + return get_function_argument_to_tuple_converter(TinyMatrix<2>{}); + } + case 3: { + return get_function_argument_to_tuple_converter(TinyMatrix<3>{}); + } + // LCOV_EXCL_START + default: { + throw ParseError("unexpected error: unexpected tuple content for function: '" + dataTypeName(parameter_type) + + "'", + std::vector{argument_node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::string_t: { return get_function_argument_to_string_converter(); } diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index 3f8cac14f0558d7a2fa86cccfbfebe5199389347..af08e7806a4fd5daa01a7aecc98a84e52291040f 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -196,7 +196,8 @@ class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter list_value.emplace_back(std::move(v[i])); } exec_policy.currentContext()[m_argument_id] = std::move(list_value); - } else if constexpr ((std::is_convertible_v<ContentT, ContentType>)and not is_tiny_vector_v<ContentType>) { + } else if constexpr ((std::is_convertible_v<ContentT, ContentType>)and not is_tiny_vector_v<ContentType> and + not is_tiny_matrix_v<ContentType>) { TupleType list_value; list_value.reserve(v.size()); for (size_t i = 0; i < v.size(); ++i) { @@ -209,7 +210,8 @@ class FunctionTupleArgumentConverter final : public IFunctionArgumentConverter demangle<ContentType>() + "'"); // LCOV_EXCL_STOP } - } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ContentType>) { + } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ContentType> and + not is_tiny_matrix_v<ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { throw UnexpectedError(std::string{"cannot convert '"} + demangle<ValueT>() + "' to '" + @@ -254,7 +256,7 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter using Vi_T = std::decay_t<decltype(vi)>; if constexpr (std::is_same_v<Vi_T, ContentType>) { list_value.emplace_back(vi); - } else if constexpr (is_tiny_vector_v<ContentType>) { + } else if constexpr (is_tiny_vector_v<ContentType> or is_tiny_matrix_v<ContentType>) { // LCOV_EXCL_START throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<Vi_T>() + "' to '" + demangle<ContentType>() + "'"); @@ -283,7 +285,8 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter } else if constexpr (std::is_same_v<ValueT, ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{v}); } else if constexpr (std::is_convertible_v<ValueT, ContentType> and not is_tiny_vector_v<ValueT> and - not is_tiny_vector_v<ContentType>) { + not is_tiny_vector_v<ContentType> and not is_tiny_matrix_v<ValueT> and + not is_tiny_matrix_v<ContentType>) { exec_policy.currentContext()[m_argument_id] = std::move(TupleType{static_cast<ContentType>(v)}); } else { // LCOV_EXCL_START diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index 278c7274ff702cfc5e97f7b6075b5487e440580b..9b57a908e6341291d14e5e8f9c80aa8a87275dcc 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -229,6 +229,85 @@ class PugsFunctionAdapter<OutputType(InputType...)> } // LCOV_EXCL_STOP } + } else if constexpr (is_tiny_matrix_v<OutputType>) { + switch (data_type) { + case ASTNodeDataType::list_t: { + return [](DataVariant&& result) -> OutputType { + AggregateDataVariant& v = std::get<AggregateDataVariant>(result); + OutputType x; + + for (size_t i = 0, l = 0; i < x.dimension(); ++i) { + for (size_t j = 0; j < x.dimension(); ++j, ++l) { + std::visit( + [&](auto&& Aij) { + using Aij_T = std::decay_t<decltype(Aij)>; + if constexpr (std::is_arithmetic_v<Aij_T>) { + x(i, j) = Aij; + } else { + // LCOV_EXCL_START + throw UnexpectedError("expecting arithmetic value"); + // LCOV_EXCL_STOP + } + }, + v[l]); + } + } + return x; + }; + } + case ASTNodeDataType::matrix_t: { + return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; + } + case ASTNodeDataType::bool_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return + [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::unsigned_int_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { + return OutputType(static_cast<double>(std::get<uint64_t>(result))); + }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::int_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { + return OutputType{static_cast<double>(std::get<int64_t>(result))}; + }; + } else { + // If this point is reached must be a 0 matrix + return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; }; + } + } + case ASTNodeDataType::double_t: { + if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) { + return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; + } else { + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP + } + } + // LCOV_EXCL_START + default: { + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + } + // LCOV_EXCL_STOP + } } else if constexpr (std::is_arithmetic_v<OutputType>) { switch (data_type) { case ASTNodeDataType::bool_t: { diff --git a/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp b/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp index 2954cd8db2e691ea2206b30b158f9596a8290081..25115bfc2ff3011e582029ab495e17f4a9d108e2 100644 --- a/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/tests/test_ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -133,7 +133,7 @@ RtoR(true); } } - SECTION("R -> R1") + SECTION("R -> R^1") { SECTION("from R") { @@ -201,9 +201,77 @@ RtoR1(true); } } - SECTION("R1 -> R") + SECTION("R -> R^1x1") { - SECTION("from R1") + SECTION("from R") + { + std::string_view data = R"( +RtoR11(1.); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::real:1.:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from Z") + { + std::string_view data = R"( +RtoR11(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from N") + { + std::string_view data = R"( +let n : N, n = 1; +RtoR11(n); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::name:n:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from B") + { + std::string_view data = R"( +RtoR11(true); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:RtoR11:NameProcessor) + `-(language::true_kw:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^1 -> R") + { + SECTION("from R^1") { std::string_view data = R"( let x : R^1, x = 2; @@ -286,7 +354,92 @@ R1toR(true); } } - SECTION("R2 -> R") + SECTION("R^1x1 -> R") + { + SECTION("from R^1x1") + { + std::string_view data = R"( +let x : R^1x1, x = 2; +R11toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R") + { + std::string_view data = R"( +R11toR(1.); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::real:1.:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from Z") + { + std::string_view data = R"( +R11toR(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from N") + { + std::string_view data = R"( +let n : N, n = 1; +R11toR(n); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::name:n:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from B") + { + std::string_view data = R"( +R11toR(true); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R11toR:NameProcessor) + `-(language::true_kw:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^2 -> R") { SECTION("from 0") { @@ -340,7 +493,63 @@ R2toR((1,2)); } } - SECTION("R3 -> R") + SECTION("R^2x2 -> R") + { + SECTION("from 0") + { + std::string_view data = R"( +R22toR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R^2x2") + { + std::string_view data = R"( +let x:R^2x2, x = (1,2,3,4); +R22toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from list") + { + std::string_view data = R"( +R22toR((1,2,3,4)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R22toR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("R^3 -> R") { SECTION("from 0") { @@ -395,6 +604,67 @@ R3toR((1,2,3)); } } + SECTION("R^3x3 -> R") + { + SECTION("from 0") + { + std::string_view data = R"( +R33toR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from R^3x3") + { + std::string_view data = R"( +let x:R^3x3, x = (1,2,3,4,5,6,7,8,9); +R33toR(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from list") + { + std::string_view data = R"( +R33toR((1,2,3,4,5,6,7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33toR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::integer:4:ValueProcessor) + +-(language::integer:5:ValueProcessor) + +-(language::integer:6:ValueProcessor) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + SECTION("Z -> R") { SECTION("from Z") @@ -603,6 +873,87 @@ R3R2toR((1,2,3),0); } } + SECTION("R^3x3*R^2x2 -> R") + { + SECTION("from R^3x3*R^2x2") + { + std::string_view data = R"( +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +let y : R^2x2, y = (1,2,3,4); +R33R22toR(x,y); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::name:x:NameProcessor) + `-(language::name:y:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from (R,R,R,R,R,R,R,R,R)*(R,R,R,R)") + { + std::string_view data = R"( +R33R22toR((1,2,3,4,5,6,7,8,9),(1,2,3,4)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | `-(language::integer:9:ValueProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("from (R,R,R,R,R,R,R,R,R)*(0)") + { + std::string_view data = R"( +R33R22toR((1,2,3,4,5,6,7,8,9),0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:R33R22toR:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | +-(language::integer:3:ValueProcessor) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | +-(language::integer:6:ValueProcessor) + | +-(language::integer:7:ValueProcessor) + | +-(language::integer:8:ValueProcessor) + | `-(language::integer:9:ValueProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + } + SECTION("string -> B") { std::string_view data = R"( @@ -1012,7 +1363,7 @@ tuple_builtinToB(t); CHECK_AST(data, result); } - SECTION("Z -> tuple(R1)") + SECTION("Z -> tuple(R^1)") { std::string_view data = R"( tuple_R1ToR(1); @@ -1028,7 +1379,7 @@ tuple_R1ToR(1); CHECK_AST(data, result); } - SECTION("R -> tuple(R1)") + SECTION("R -> tuple(R^1)") { std::string_view data = R"( tuple_R1ToR(1.2); @@ -1044,7 +1395,7 @@ tuple_R1ToR(1.2); CHECK_AST(data, result); } - SECTION("R1 -> tuple(R1)") + SECTION("R^1 -> tuple(R^1)") { std::string_view data = R"( let r:R^1, r = 3; @@ -1061,7 +1412,56 @@ tuple_R1ToR(r); CHECK_AST(data, result); } - SECTION("0 -> tuple(R2)") + SECTION("Z -> tuple(R^1x1)") + { + std::string_view data = R"( +tuple_R11ToR(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R -> tuple(R^1x1)") + { + std::string_view data = R"( +tuple_R11ToR(1.2); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::real:1.2:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^1x1 -> tuple(R^1x1)") + { + std::string_view data = R"( +let r:R^1x1, r = 3; +tuple_R11ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R11ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("0 -> tuple(R^2)") { std::string_view data = R"( tuple_R2ToR(0); @@ -1077,7 +1477,7 @@ tuple_R2ToR(0); CHECK_AST(data, result); } - SECTION("R2 -> tuple(R2)") + SECTION("R^2 -> tuple(R^2)") { std::string_view data = R"( let r:R^2, r = (1,2); @@ -1094,7 +1494,7 @@ tuple_R2ToR(r); CHECK_AST(data, result); } - SECTION("compound_list -> tuple(R2)") + SECTION("compound_list -> tuple(R^2)") { std::string_view data = R"( let r:R^2, r = (1,2); @@ -1113,7 +1513,59 @@ tuple_R2ToR((r,r)); CHECK_AST(data, result); } - SECTION("0 -> tuple(R3)") + SECTION("0 -> tuple(R^2x2)") + { + std::string_view data = R"( +tuple_R22ToR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^2x2 -> tuple(R^2x2)") + { + std::string_view data = R"( +let r:R^2x2, r = (1,2,3,4); +tuple_R22ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("compound_list -> tuple(R^2x2)") + { + std::string_view data = R"( +let r:R^2x2, r = (1,2,3,4); +tuple_R22ToR((r,r)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R22ToR:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::name:r:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("0 -> tuple(R^3)") { std::string_view data = R"( tuple_R3ToR(0); @@ -1129,7 +1581,7 @@ tuple_R3ToR(0); CHECK_AST(data, result); } - SECTION("R3 -> tuple(R3)") + SECTION("R^3 -> tuple(R^3)") { std::string_view data = R"( let r:R^3, r = (1,2,3); @@ -1146,6 +1598,39 @@ tuple_R3ToR(r); CHECK_AST(data, result); } + SECTION("0 -> tuple(R^3x3)") + { + std::string_view data = R"( +tuple_R33ToR(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R33ToR:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("R^3x3 -> tuple(R^3x3)") + { + std::string_view data = R"( +let r:R^3x3, r = (1,2,3,4,5,6,7,8,9); +tuple_R33ToR(r); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:BuiltinFunctionProcessor) + +-(language::name:tuple_R33ToR:NameProcessor) + `-(language::name:r:NameProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("FunctionSymbolId -> R") { std::string_view data = R"( diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index 00f757fa68d935cdffbff876205e4ba5a16d1e68..b00a9649b6cf8a5f3b57c27d9d63e74ad5ed68b6 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -59,6 +59,52 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t); } + SECTION("R*R^2 -> R^2") + { + std::function c = [](double a, TinyVector<2> x) -> TinyVector<2> { return a * x; }; + + BuiltinFunctionEmbedder<TinyVector<2>(double, TinyVector<2>)> embedded_c{c}; + + double a_arg = 2.3; + TinyVector<2> x_arg{3, 2}; + + std::vector<DataVariant> args; + args.push_back(a_arg); + args.push_back(x_arg); + + DataVariant result = embedded_c.apply(args); + + REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg)); + REQUIRE(embedded_c.numberOfParameters() == 2); + + REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t); + REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::double_t); + REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t); + } + + SECTION("R^2x2*R^2 -> R^2") + { + std::function c = [](TinyMatrix<2> A, TinyVector<2> x) -> TinyVector<2> { return A * x; }; + + BuiltinFunctionEmbedder<TinyVector<2>(TinyMatrix<2>, TinyVector<2>)> embedded_c{c}; + + TinyMatrix<2> a_arg = {2.3, 1, -2, 3}; + TinyVector<2> x_arg{3, 2}; + + std::vector<DataVariant> args; + args.push_back(a_arg); + args.push_back(x_arg); + + DataVariant result = embedded_c.apply(args); + + REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg)); + REQUIRE(embedded_c.numberOfParameters() == 2); + + REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t); + REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::matrix_t); + REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t); + } + SECTION("POD BuiltinFunctionEmbedder") { std::function c = [](double x, uint64_t i) -> bool { return x > i; }; diff --git a/tests/test_BuiltinFunctionRegister.hpp b/tests/test_BuiltinFunctionRegister.hpp index c1c1eb29b74fb9c5e6f8a170f76c7b4b2157c7b7..bc36e2ea29d37cbcd2464ddc1e080bff9e0dd591 100644 --- a/tests/test_BuiltinFunctionRegister.hpp +++ b/tests/test_BuiltinFunctionRegister.hpp @@ -68,6 +68,28 @@ class test_BuiltinFunctionRegister std::make_shared<BuiltinFunctionEmbedder<double(TinyVector<3>, TinyVector<2>)>>( [](TinyVector<3> x, TinyVector<2> y) -> double { return x[0] * y[1] + (y[0] - x[2]) * x[1]; }))); + m_name_builtin_function_map.insert( + std::make_pair("RtoR11", std::make_shared<BuiltinFunctionEmbedder<TinyMatrix<1>(double)>>( + [](double r) -> TinyMatrix<1> { return {r}; }))); + + m_name_builtin_function_map.insert( + std::make_pair("R11toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<1>)>>( + [](TinyMatrix<1> x) -> double { return x(0, 0); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R22toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<2>)>>( + [](TinyMatrix<2> x) -> double { return x(0, 0) + x(0, 1) + x(1, 0) + x(1, 1); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R33toR", std::make_shared<BuiltinFunctionEmbedder<double(const TinyMatrix<3>&)>>( + [](const TinyMatrix<3>& x) -> double { return x(0, 0) + x(1, 1) + x(2, 2); }))); + + m_name_builtin_function_map.insert( + std::make_pair("R33R22toR", std::make_shared<BuiltinFunctionEmbedder<double(TinyMatrix<3>, TinyMatrix<2>)>>( + [](TinyMatrix<3> x, TinyMatrix<2> y) -> double { + return (x(0, 0) + x(1, 1) + x(2, 2)) * (y(0, 0) + y(0, 1) + y(1, 0) + y(1, 1)); + }))); + m_name_builtin_function_map.insert( std::make_pair("fidToR", std::make_shared<BuiltinFunctionEmbedder<double(const FunctionSymbolId&)>>( [](const FunctionSymbolId&) -> double { return 0; }))); @@ -116,6 +138,21 @@ class test_BuiltinFunctionRegister m_name_builtin_function_map.insert( std::make_pair("tuple_R3ToR", std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyVector<3>>)>>( [](const std::vector<TinyVector<3>>&) -> double { return 0; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R11ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<1>>&)>>( + [](const std::vector<TinyMatrix<1>>&) -> double { return 1; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R22ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<2>>&)>>( + [](const std::vector<TinyMatrix<2>>&) -> double { return 1; }))); + + m_name_builtin_function_map.insert( + std::make_pair("tuple_R33ToR", + std::make_shared<BuiltinFunctionEmbedder<double(const std::vector<TinyMatrix<3>>)>>( + [](const std::vector<TinyMatrix<3>>&) -> double { return 0; }))); } void diff --git a/tests/test_PugsFunctionAdapter.cpp b/tests/test_PugsFunctionAdapter.cpp index 3a7ca2813845fa72b0db312b6d259195c162c09d..4b04c7f5e088c1bdec314c91414fac4924b8bd6b 100644 --- a/tests/test_PugsFunctionAdapter.cpp +++ b/tests/test_PugsFunctionAdapter.cpp @@ -77,11 +77,17 @@ let ZplusZ: Z*Z -> Z, (x,y) -> x+y; let RplusR: R*R -> R, (x,y) -> x+y; let RRtoR2: R*R -> R^2, (x,y) -> (x+y, x-y); let R3times2: R^3 -> R^3, x -> 2*x; +let R33times2: R^3x3 -> R^3x3, x -> 2*x; let BtoR1: B -> R^1, b -> not b; +let BtoR11: B -> R^1x1, b -> not b; let NtoR1: N -> R^1, n -> n*n; +let NtoR11: N -> R^1x1, n -> n*n; let ZtoR1: Z -> R^1, z -> -z; +let ZtoR11: Z -> R^1x1, z -> -z; let RtoR1: R -> R^1, x -> x*x; +let RtoR11: R -> R^1x1, x -> x*x; let R3toR3zero: R^3 -> R^3, x -> 0; +let R33toR33zero: R^3x3 -> R^3x3, x -> 0; )"; string_input input{data, "test.pgs"}; @@ -194,6 +200,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == 2 * x); } + { + auto [i_symbol, found] = symbol_table->find("R33times2", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyMatrix<3> x{2, 3, 4, 1, 6, 5, 9, 7, 8}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<3> result = tests_adapter::TestBinary<TinyMatrix<3>(TinyMatrix<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == 2 * x); + } + { auto [i_symbol, found] = symbol_table->find("BtoR1", position); REQUIRE(found); @@ -218,6 +237,30 @@ let R3toR3zero: R^3 -> R^3, x -> 0; } } + { + auto [i_symbol, found] = symbol_table->find("BtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + { + const bool b = true; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + + { + const bool b = false; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + } + { auto [i_symbol, found] = symbol_table->find("NtoR1", position); REQUIRE(found); @@ -231,6 +274,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == n * n); } + { + auto [i_symbol, found] = symbol_table->find("NtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const uint64_t n = 4; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(uint64_t)>::one_arg(function_symbol_id, n); + + REQUIRE(result == n * n); + } + { auto [i_symbol, found] = symbol_table->find("ZtoR1", position); REQUIRE(found); @@ -244,6 +300,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == -z); } + { + auto [i_symbol, found] = symbol_table->find("ZtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const int64_t z = 3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(int64_t)>::one_arg(function_symbol_id, z); + + REQUIRE(result == -z); + } + { auto [i_symbol, found] = symbol_table->find("RtoR1", position); REQUIRE(found); @@ -257,6 +326,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == x * x); } + { + auto [i_symbol, found] = symbol_table->find("RtoR11", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 3.3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<1> result = tests_adapter::TestBinary<TinyMatrix<1>(double)>::one_arg(function_symbol_id, x); + + REQUIRE(result == x * x); + } + { auto [i_symbol, found] = symbol_table->find("R3toR3zero", position); REQUIRE(found); @@ -269,6 +351,19 @@ let R3toR3zero: R^3 -> R^3, x -> 0; REQUIRE(result == TinyVector<3>{0, 0, 0}); } + + { + auto [i_symbol, found] = symbol_table->find("R33toR33zero", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyMatrix<3> x{1, 0, 0, 0, 1, 0, 0, 0, 1}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyMatrix<3> result = tests_adapter::TestBinary<TinyMatrix<3>(TinyMatrix<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == TinyMatrix<3>{0, 0, 0, 0, 0, 0, 0, 0, 0}); + } } SECTION("Errors calls") @@ -280,6 +375,7 @@ let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z); let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]); let RtoNS: R -> N*string, x -> (1, "foo"); let RtoR: R -> R, x -> 2*x; +let R33toR22: R^3x3 -> R^2x2, x -> (x[0,0], x[0,1]+x[0,2], x[2,0]*x[1,1], x[2,1]+x[2,2]); )"; string_input input{data, "test.pgs"}; @@ -385,5 +481,19 @@ let RtoR: R -> R, x -> 2*x; "note: expecting R -> R^3\n" "note: provided function RtoR: R -> R"); } + + { + auto [i_symbol, found] = symbol_table->find("R33toR22", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 1; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<double(double)>::one_arg(function_symbol_id, x), + "error: invalid function type\n" + "note: expecting R -> R\n" + "note: provided function R33toR22: R^3x3 -> R^2x2"); + } } }