diff --git a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp index a381123884250e9cd574495d7c3add369863aa59..656000c69cd45b85bcf72907850230c7e0f9615b 100644 --- a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp @@ -15,7 +15,7 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy { const size_t parameter_id = std::get<size_t>(parameter_symbol.attributes().value()); - ASTNodeNaturalConversionChecker{node_sub_data_type, parameter_symbol.attributes().dataType()}; + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{node_sub_data_type, parameter_symbol.attributes().dataType()}; auto get_function_argument_converter_for = [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { @@ -78,13 +78,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy // LCOV_EXCL_STOP } } + case ASTNodeDataType::bool_t: { + if ((parameter_v.dimension() == 1)) { + return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, bool>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::int_t: { - if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if ((parameter_v.dimension() == 1)) { + return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, int64_t>>(parameter_id); + } else if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, ZeroType>>(parameter_id); } } - [[fallthrough]]; + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + case ASTNodeDataType::unsigned_int_t: { + if ((parameter_v.dimension() == 1)) { + return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, uint64_t>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::double_t: { + if ((parameter_v.dimension() == 1)) { + return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, double>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument dimension", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } } // LCOV_EXCL_START default: { @@ -110,13 +145,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy // LCOV_EXCL_STOP } } + case ASTNodeDataType::bool_t: { + if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, bool>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } case ASTNodeDataType::int_t: { - if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { + if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, int64_t>>(parameter_id); + } else if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(parameter_id); } } - [[fallthrough]]; + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + case ASTNodeDataType::unsigned_int_t: { + if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, uint64_t>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } + } + case ASTNodeDataType::double_t: { + if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) { + return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, double>>(parameter_id); + } else { + // LCOV_EXCL_START + throw ParseError("unexpected error: invalid argument type", + std::vector{node_sub_data_type.m_parent_node.begin()}); + // LCOV_EXCL_STOP + } } // LCOV_EXCL_START default: { diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index bb125b821361989e6e8bb898ba0b56208c32ef84..6a4075307e3af9f225cd493176fbd0a7873a3c75 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -116,10 +116,23 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver value); } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; + } else if constexpr (std::is_same_v<ExpectedValueType, TinyVector<1>>) { + if constexpr (std::is_same_v<ProvidedValueType, bool>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else if constexpr (std::is_same_v<ProvidedValueType, int64_t>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else if constexpr (std::is_same_v<ProvidedValueType, uint64_t>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else if constexpr (std::is_same_v<ProvidedValueType, double>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else { + static_assert(std::is_same_v<ExpectedValueType, TinyVector<1>>); + exec_policy.currentContext()[m_argument_id] = + std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); + } } else { - static_assert(std::is_same_v<ExpectedValueType, TinyVector<1>>); - exec_policy.currentContext()[m_argument_id] = - std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ProvidedValueType>() + "' to '" + + demangle<ExpectedValueType>() + "'"); } return {}; } @@ -165,11 +178,25 @@ class FunctionTinyMatrixArgumentConverter final : public IFunctionArgumentConver value); } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; + } else if constexpr (std::is_same_v<ExpectedValueType, TinyMatrix<1>>) { + if constexpr (std::is_same_v<ProvidedValueType, bool>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else if constexpr (std::is_same_v<ProvidedValueType, int64_t>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else if constexpr (std::is_same_v<ProvidedValueType, uint64_t>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else if constexpr (std::is_same_v<ProvidedValueType, double>) { + exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value)); + } else { + static_assert(std::is_same_v<ExpectedValueType, TinyMatrix<1>>); + exec_policy.currentContext()[m_argument_id] = + std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); + } } else { - static_assert(std::is_same_v<ExpectedValueType, TinyMatrix<1>>); - exec_policy.currentContext()[m_argument_id] = - std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); + throw UnexpectedError(std::string{"cannot convert '"} + demangle<ProvidedValueType>() + "' to '" + + demangle<ExpectedValueType>() + "'"); } + return {}; } diff --git a/tests/test_ASTNodeFunctionExpressionBuilder.cpp b/tests/test_ASTNodeFunctionExpressionBuilder.cpp index f19b6585613db943b77d684008974e737140686d..762202a671b1f7708b193ca8cd6c7aa599bfc780 100644 --- a/tests/test_ASTNodeFunctionExpressionBuilder.cpp +++ b/tests/test_ASTNodeFunctionExpressionBuilder.cpp @@ -370,13 +370,30 @@ cat("foo", 2.5e-3); let f : R^1 -> R^1, x -> x+x; let x : R^1, x = 1; f(x); +let n:N, n=1; +f(true); +f(n); +f(2); +f(1.4); )"; std::string_view result = R"( (root:ASTNodeListProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::name:x:NameProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::true_kw:ValueProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::name:n:NameProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::integer:2:ValueProcessor) `-(language::function_evaluation:FunctionProcessor) +-(language::name:f:NameProcessor) - `-(language::name:x:NameProcessor) + `-(language::real:1.4:ValueProcessor) )"; CHECK_AST(data, result); @@ -424,13 +441,30 @@ f(x); let f : R^1x1 -> R^1x1, x -> x+x; let x : R^1x1, x = 1; f(x); +let n:N, n=1; +f(true); +f(n); +f(2); +f(1.4); )"; std::string_view result = R"( (root:ASTNodeListProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::name:x:NameProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::true_kw:ValueProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::name:n:NameProcessor) + +-(language::function_evaluation:FunctionProcessor) + | +-(language::name:f:NameProcessor) + | `-(language::integer:2:ValueProcessor) `-(language::function_evaluation:FunctionProcessor) +-(language::name:f:NameProcessor) - `-(language::name:x:NameProcessor) + `-(language::real:1.4:ValueProcessor) )"; CHECK_AST(data, result); @@ -1115,6 +1149,170 @@ prev(3 + .24); CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> Z"}); } + + SECTION("B -> R^2") + { + std::string_view data = R"( +let f : R^2 -> R^2, x -> x; +f(true); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^2"}); + } + + SECTION("N -> R^2") + { + std::string_view data = R"( +let f : R^2 -> R^2, x -> x; +let n : N, n = 2; +f(n); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^2"}); + } + + SECTION("Z -> R^2") + { + std::string_view data = R"( +let f : R^2 -> R^2, x -> x; +f(-2); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^2"}); + } + + SECTION("R -> R^2") + { + std::string_view data = R"( +let f : R^2 -> R^2, x -> x; +f(1.3); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^2"}); + } + + SECTION("B -> R^3") + { + std::string_view data = R"( +let f : R^3 -> R^3, x -> x; +f(true); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^3"}); + } + + SECTION("N -> R^3") + { + std::string_view data = R"( +let f : R^3 -> R^3, x -> x; +let n : N, n = 2; +f(n); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^3"}); + } + + SECTION("Z -> R^3") + { + std::string_view data = R"( +let f : R^3 -> R^3, x -> x; +f(-2); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^3"}); + } + + SECTION("R -> R^3") + { + std::string_view data = R"( +let f : R^3 -> R^3, x -> x; +f(1.3); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^3"}); + } + + SECTION("B -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x; +f(true); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^2x2"}); + } + + SECTION("N -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x; +let n : N, n = 2; +f(n); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^2x2"}); + } + + SECTION("Z -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x; +f(-2); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^2x2"}); + } + + SECTION("R -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x; +f(1.3); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^2x2"}); + } + + SECTION("B -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x; +f(true); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^3x3"}); + } + + SECTION("N -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x; +let n : N, n = 2; +f(n); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^3x3"}); + } + + SECTION("Z -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x; +f(-2); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^3x3"}); + } + + SECTION("R -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x; +f(1.3); +)"; + + CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^3x3"}); + } } SECTION("arguments invalid tuple -> R^d conversion") diff --git a/tests/test_FunctionProcessor.cpp b/tests/test_FunctionProcessor.cpp index 03c8baa8ed5509c19520bb04d2860402e7308984..97cc8d6be42199f684739dba21b1b3f9a2cab644 100644 --- a/tests/test_FunctionProcessor.cpp +++ b/tests/test_FunctionProcessor.cpp @@ -374,6 +374,43 @@ let fx:R^1, fx = f(x); CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{3})); } + SECTION(" R^1 -> R^1 called with B argument") + { + std::string_view data = R"( +let f : R^1 -> R^1, x -> 2*x; +let fx:R^1, fx = f(true); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{true})); + } + + SECTION(" R^1 -> R^1 called with N argument") + { + std::string_view data = R"( +let f : R^1 -> R^1, x -> 2*x; +let n:N, n = 3; +let fx:R^1, fx = f(n); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{3})); + } + + SECTION(" R^1 -> R^1 called with Z argument") + { + std::string_view data = R"( +let f : R^1 -> R^1, x -> 2*x; +let fx:R^1, fx = f(-2); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{-2})); + } + + SECTION(" R^1 -> R^1 called with R argument") + { + std::string_view data = R"( +let f : R^1 -> R^1, x -> 2*x; +let fx:R^1, fx = f(1.3); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{1.3})); + } + SECTION(" R^2 -> R^2") { std::string_view data = R"( @@ -439,6 +476,43 @@ let fx:R^1x1, fx = f(x); CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3})); } + SECTION(" R^1x1 -> R^1x1 called with B argument") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> 2*x; +let fx:R^1x1, fx = f(true); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{true})); + } + + SECTION(" R^1x1 -> R^1x1 called with N argument") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> 2*x; +let n:N, n = 3; +let fx:R^1x1, fx = f(n); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3})); + } + + SECTION(" R^1x1 -> R^1x1 called with Z argument") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> 2*x; +let fx:R^1x1, fx = f(-4); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{-4})); + } + + SECTION(" R^1x1 -> R^1x1 called with R argument") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> 2*x; +let fx:R^1x1, fx = f(-2.3); +)"; + CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{-2.3})); + } + SECTION(" R^2x2 -> R^2x2") { std::string_view data = R"(