diff --git a/src/language/modules/MeshModule.cpp b/src/language/modules/MeshModule.cpp index 8394d940dfd774f58fcd7a16f54c91e0eb5334ea..2d973e7052bfa946cceb0ef1503c2d32a68dc8c8 100644 --- a/src/language/modules/MeshModule.cpp +++ b/src/language/modules/MeshModule.cpp @@ -30,8 +30,6 @@ class MeshTransformation<OutputType(InputType...)> : public PugsFunctionAdapter< using MeshType = Mesh<Connectivity<Dimension>>; const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); - const auto flatten_args = Adapter::getFlattenArgs(function_symbol_id); - auto& expression = Adapter::getFunctionExpression(function_symbol_id); auto convert_result = Adapter::getResultConverter(expression.m_data_type); @@ -43,12 +41,12 @@ class MeshTransformation<OutputType(InputType...)> : public PugsFunctionAdapter< using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens; - parallel_for(given_mesh.numberOfNodes(), [=, &expression, &flatten_args, &tokens](NodeId r) { + parallel_for(given_mesh.numberOfNodes(), [=, &expression, &tokens](NodeId r) { const int32_t t = tokens.acquire(); auto& execution_policy = context_list[t]; - Adapter::convertArgs(execution_policy.currentContext(), flatten_args, given_xr[r]); + Adapter::convertArgs(execution_policy.currentContext(), given_xr[r]); auto result = expression.execute(execution_policy); xr[r] = convert_result(std::move(result)); diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp index a1fedff2f0168da1d2fe58c97a5bfc91cd6d4154..16d1b550921be0d7a307924583c7b9f22f70136c 100644 --- a/src/language/modules/SchemeModule.cpp +++ b/src/language/modules/SchemeModule.cpp @@ -30,8 +30,6 @@ class InterpolateItemValue<OutputType(InputType)> : public PugsFunctionAdapter<O static inline ItemValue<OutputType, item_type> interpolate(const FunctionSymbolId& function_symbol_id, const ItemValue<const InputType, item_type>& position) { - const auto flatten_args = Adapter::getFlattenArgs(function_symbol_id); - auto& expression = Adapter::getFunctionExpression(function_symbol_id); auto convert_result = Adapter::getResultConverter(expression.m_data_type); @@ -44,12 +42,12 @@ class InterpolateItemValue<OutputType(InputType)> : public PugsFunctionAdapter<O ItemValue<OutputType, item_type> value(connectivity); using ItemId = ItemIdT<item_type>; - parallel_for(connectivity.template numberOf<item_type>(), [=, &expression, &flatten_args, &tokens](ItemId i) { + parallel_for(connectivity.template numberOf<item_type>(), [=, &expression, &tokens](ItemId i) { const int32_t t = tokens.acquire(); auto& execution_policy = context_list[t]; - Adapter::convertArgs(execution_policy.currentContext(), flatten_args, position[i]); + Adapter::convertArgs(execution_policy.currentContext(), position[i]); auto result = expression.execute(execution_policy); value[i] = convert_result(std::move(result)); diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index 262977cdf45ee775833c0689cace9713019c3a69..90b7e209f2e2d1bf4a2076e52dbc4aea212dff7d 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -20,81 +20,53 @@ template <typename OutputType, typename... InputType> class PugsFunctionAdapter<OutputType(InputType...)> { protected: - using FlattenList = std::array<int32_t, sizeof...(InputType)>; + using InputTuple = std::tuple<std::decay_t<InputType>...>; + constexpr static size_t NArgs = std::tuple_size_v<InputTuple>; private: - template <typename T> - PUGS_INLINE static void - _flattenArgT(const T&, ExecutionPolicy::Context&, size_t&) - { - throw UnexpectedError("cannot flatten type " + demangle<T>()); - } - - template <size_t N> - PUGS_INLINE static void - _flattenArgT(const TinyVector<N>& t, ExecutionPolicy::Context& context, size_t& i_context) - { - for (size_t i = 0; i < N; ++i) { - context[i_context + i] = t[i]; - } - } - template <typename T, typename... Args> PUGS_INLINE static void - _convertArgs(const T& t, - const Args&&... args, - ExecutionPolicy::Context& context, - const FlattenList& flatten, - size_t i_context) + _convertArgs(ExecutionPolicy::Context& context, size_t i_context, const T& t, Args&&... args) { - if (flatten[sizeof...(args)]) { - _flattenArgT(t, context, i_context); - } else { - context[i_context++] = t; - } - if constexpr (sizeof...(args) > 0) { - _convertArgs(std::forward<Args>(args)..., context, flatten, i_context); + context[i_context++] = t; + if constexpr (sizeof...(Args) > 0) { + _convertArgs(context, i_context, std::forward<Args>(args)...); } } - template <typename Arg, typename... RemainingArgs> + template <size_t I> [[nodiscard]] PUGS_INLINE static bool - _checkValidArgumentDataType(const ASTNode& input_expression, FlattenList& flatten_list) noexcept + _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept { + using Arg = std::tuple_element_t<I, InputTuple>; + constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>; - const ASTNodeDataType& input_data_type = input_expression.m_data_type; + const ASTNodeDataType& arg_data_type = arg_expression.m_data_type; - constexpr size_t i_argument = sizeof...(InputType) - 1 - sizeof...(RemainingArgs); - flatten_list[i_argument] = false; + return isNaturalConversion(expected_input_data_type, arg_data_type); + } - if (not isNaturalConversion(expected_input_data_type, input_data_type)) { - if ((expected_input_data_type == ASTNodeDataType::vector_t) and (input_data_type == ASTNodeDataType::list_t)) { - flatten_list[i_argument] = true; - if (expected_input_data_type.dimension() != input_expression.children.size()) { - return false; - } else { - for (const auto& child : input_expression.children) { - const ASTNodeDataType& data_type = child->m_data_type; - if (not isNaturalConversion(ast_node_data_type_from<double>, data_type)) { - return false; - } - } - } - } else { - return false; - } - } - if constexpr (sizeof...(RemainingArgs) == 0) { - return true; - } else { - return false; - } + template <size_t... I> + [[nodiscard]] PUGS_INLINE static bool + _checkAllInputDataType(const ASTNode& input_expression, std::index_sequence<I...>) + { + Assert(NArgs == input_expression.children.size()); + return (_checkValidArgumentDataType<I>(*input_expression.children[I]) and ...); } [[nodiscard]] PUGS_INLINE static bool - _checkValidInputDataType(const ASTNode& input_expression, FlattenList& flatten_list) noexcept + _checkValidInputDataType(const ASTNode& input_expression) noexcept { - return _checkValidArgumentDataType<InputType...>(input_expression, flatten_list); + if constexpr (NArgs == 1) { + return _checkValidArgumentDataType<0>(input_expression); + } else { + if (input_expression.children.size() != NArgs) { + return false; + } + + using IndexSequence = std::make_index_sequence<NArgs>; + return _checkAllInputDataType(input_expression, IndexSequence{}); + } } [[nodiscard]] PUGS_INLINE static bool @@ -116,7 +88,17 @@ class PugsFunctionAdapter<OutputType(InputType...)> } } } + } else if ((expected_return_data_type.dimension() == 1) and + isNaturalConversion(return_data_type, ast_node_data_type_from<double>)) { + return true; + } else if (return_data_type == ast_node_data_type_from<int64_t>) { + // 0 is the only valid value for vectors + return (return_expression.string() == "0"); + } else { + return false; } + } else { + return false; } } return true; @@ -127,7 +109,7 @@ class PugsFunctionAdapter<OutputType(InputType...)> _getCompoundTypeName() { if constexpr (sizeof...(RemainingArgs) > 0) { - return dataTypeName(ast_node_data_type_from<Arg>) + _getCompoundTypeName<RemainingArgs...>(); + return dataTypeName(ast_node_data_type_from<Arg>) + '*' + _getCompoundTypeName<RemainingArgs...>(); } else { return dataTypeName(ast_node_data_type_from<Arg>); } @@ -139,15 +121,10 @@ class PugsFunctionAdapter<OutputType(InputType...)> return _getCompoundTypeName<InputType...>(); } - protected: - [[nodiscard]] PUGS_INLINE static FlattenList - getFlattenArgs(const FunctionSymbolId& function_symbol_id) + PUGS_INLINE static void + _checkFunction(const FunctionDescriptor& function) { - auto& function = function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()]; - - FlattenList flatten_list; - - bool has_valid_input = _checkValidInputDataType(*function.definitionNode().children[0], flatten_list); + bool has_valid_input = _checkValidInputDataType(*function.definitionNode().children[0]); bool has_valid_output = _checkValidOutputDataType(*function.definitionNode().children[1]); if (not(has_valid_input and has_valid_output)) { @@ -159,14 +136,14 @@ class PugsFunctionAdapter<OutputType(InputType...)> << function.domainMappingNode().string() << rang::style::reset << std::ends; throw NormalError(error_message.str()); } - - return flatten_list; } + protected: [[nodiscard]] PUGS_INLINE static auto& getFunctionExpression(const FunctionSymbolId& function_symbol_id) { auto& function = function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()]; + _checkFunction(function); return *function.definitionNode().children[1]; } @@ -188,10 +165,11 @@ class PugsFunctionAdapter<OutputType(InputType...)> template <typename... Args> PUGS_INLINE static void - convertArgs(ExecutionPolicy::Context& context, const FlattenList& flatten, const Args&... args) + convertArgs(ExecutionPolicy::Context& context, Args&&... args) { - static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type"); - _convertArgs(args..., context, flatten, 0); + static_assert(std::is_same_v<std::tuple<std::decay_t<InputType>...>, std::tuple<std::decay_t<Args>...>>, + "unexpected input type"); + _convertArgs(context, 0, args...); } [[nodiscard]] PUGS_INLINE static std::function<OutputType(DataVariant&& result)> @@ -211,7 +189,9 @@ class PugsFunctionAdapter<OutputType(InputType...)> if constexpr (std::is_arithmetic_v<Vi_T>) { x[i] = vi; } else { + // LCOV_EXCL_START throw UnexpectedError("expecting arithmetic value"); + // LCOV_EXCL_STOP } }, v[i]); @@ -227,7 +207,10 @@ class PugsFunctionAdapter<OutputType(InputType...)> return [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; }; } else { - throw UnexpectedError("unexpected data_type"); + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP } } case ASTNodeDataType::unsigned_int_t: { @@ -236,7 +219,10 @@ class PugsFunctionAdapter<OutputType(InputType...)> return OutputType(static_cast<double>(std::get<uint64_t>(result))); }; } else { - throw UnexpectedError("unexpected data_type"); + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP } } case ASTNodeDataType::int_t: { @@ -253,12 +239,18 @@ class PugsFunctionAdapter<OutputType(InputType...)> if constexpr (std::is_same_v<OutputType, TinyVector<1>>) { return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; } else { - throw UnexpectedError("unexpected data_type"); + // LCOV_EXCL_START + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); + // LCOV_EXCL_STOP } } + // LCOV_EXCL_START default: { - throw UnexpectedError("unexpected data_type"); + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); } + // LCOV_EXCL_STOP } } else if constexpr (std::is_arithmetic_v<OutputType>) { switch (data_type) { @@ -274,14 +266,20 @@ class PugsFunctionAdapter<OutputType(InputType...)> case ASTNodeDataType::double_t: { return [](DataVariant&& result) -> OutputType { return std::get<double>(result); }; } + // LCOV_EXCL_START default: { - throw UnexpectedError("unexpected data_type"); + throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); } + // LCOV_EXCL_STOP } } else { static_assert(std::is_arithmetic_v<OutputType>, "unexpected output type"); } } + + PugsFunctionAdapter() = delete; + virtual ~PugsFunctionAdapter() = delete; }; #endif // PUGS_FUNCTION_ADAPTER_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 195aa699acf565f8edd2ef5a80a515ede3092348..2dfd3234486b352e5b9f27e13e10001846008bc1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -68,6 +68,7 @@ add_executable (unit_tests test_NameProcessor.cpp test_OStreamProcessor.cpp test_PCG.cpp + test_PugsFunctionAdapter.cpp test_PugsAssert.cpp test_RevisionInfo.cpp test_SparseMatrixDescriptor.cpp diff --git a/tests/test_PugsFunctionAdapter.cpp b/tests/test_PugsFunctionAdapter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..45229f8b599a1517fcd4c2182de563601643250d --- /dev/null +++ b/tests/test_PugsFunctionAdapter.cpp @@ -0,0 +1,389 @@ +#include <catch2/catch.hpp> + +#include <language/ast/ASTBuilder.hpp> +#include <language/utils/PugsFunctionAdapter.hpp> +#include <language/utils/SymbolTable.hpp> + +#include <language/ast/ASTBuilder.hpp> +#include <language/ast/ASTModulesImporter.hpp> +#include <language/ast/ASTNodeDataTypeBuilder.hpp> +#include <language/ast/ASTNodeExpressionBuilder.hpp> +#include <language/ast/ASTNodeFunctionEvaluationExpressionBuilder.hpp> +#include <language/ast/ASTNodeFunctionExpressionBuilder.hpp> +#include <language/ast/ASTNodeTypeCleaner.hpp> +#include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/utils/ASTPrinter.hpp> +#include <utils/Demangle.hpp> + +#include <pegtl/string_input.hpp> + +// clazy:excludeall=non-pod-global-static + +namespace tests_adapter +{ +template <typename T> +class TestBinary; +template <typename OutputType, typename... InputType> +class TestBinary<OutputType(InputType...)> : public PugsFunctionAdapter<OutputType(InputType...)> +{ + using Adapter = PugsFunctionAdapter<OutputType(InputType...)>; + + public: + template <typename ArgT> + static auto + one_arg(const FunctionSymbolId& function_symbol_id, const ArgT& x) + { + auto& expression = Adapter::getFunctionExpression(function_symbol_id); + auto convert_result = Adapter::getResultConverter(expression.m_data_type); + + Array<ExecutionPolicy> context_list = Adapter::getContextList(expression); + + auto& execution_policy = context_list[0]; + + Adapter::convertArgs(execution_policy.currentContext(), x); + auto result = expression.execute(execution_policy); + + return convert_result(std::move(result)); + } + + template <typename Arg1T, typename Arg2T> + static auto + two_args(const FunctionSymbolId& function_symbol_id, const Arg1T& x, const Arg2T& y) + { + auto& expression = Adapter::getFunctionExpression(function_symbol_id); + auto convert_result = Adapter::getResultConverter(expression.m_data_type); + + Array<ExecutionPolicy> context_list = Adapter::getContextList(expression); + + auto& execution_policy = context_list[0]; + + Adapter::convertArgs(execution_policy.currentContext(), x, y); + auto result = expression.execute(execution_policy); + + return convert_result(std::move(result)); + } +}; +} // namespace tests_adapter + +TEST_CASE("PugsFunctionAdapter", "[language]") +{ + SECTION("Valid calls") + { + std::string_view data = R"( +let Rtimes2: R -> R, x -> 2*x; +let BandB: B*B -> B, (a,b) -> a and b; +let NplusN: N*N -> N, (x,y) -> x+y; +let ZplusZ: Z*Z -> Z, (x,y) -> x+y; +let RplusR: R*R -> R, (x,y) -> x+y; +let RRtoR2: R*R -> R^2, (x,y) -> (x+y, x-y); +let R3times2: R^3 -> R^3, x -> 2*x; +let BtoR1: B -> R^1, b -> not b; +let NtoR1: N -> R^1, n -> n*n; +let ZtoR1: Z -> R^1, z -> -z; +let RtoR1: R -> R^1, x -> x*x; +let R3toR3zero: R^3 -> R^3, x -> 0; +)"; + string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + position position{internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this + + { + auto [i_symbol, found] = symbol_table->find("Rtimes2", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 2; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + double result = tests_adapter::TestBinary<double(double)>::one_arg(function_symbol_id, x); + + REQUIRE(result == (2 * x)); + } + + { + auto [i_symbol, found] = symbol_table->find("BandB", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const bool a = true; + const bool b = false; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + double result = tests_adapter::TestBinary<bool(bool, bool)>::two_args(function_symbol_id, a, b); + + REQUIRE(result == (a and b)); + } + + { + auto [i_symbol, found] = symbol_table->find("NplusN", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const uint64_t x = 2; + const uint64_t y = 3; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + double result = tests_adapter::TestBinary<uint64_t(uint64_t, uint64_t)>::two_args(function_symbol_id, x, y); + + REQUIRE(result == (x + y)); + } + + { + auto [i_symbol, found] = symbol_table->find("ZplusZ", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const int64_t x = 2; + const int64_t y = 3; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + double result = tests_adapter::TestBinary<int64_t(int64_t, int64_t)>::two_args(function_symbol_id, x, y); + + REQUIRE(result == (x + y)); + } + + { + auto [i_symbol, found] = symbol_table->find("RplusR", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 2; + const double y = 3; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + double result = tests_adapter::TestBinary<double(double, double)>::two_args(function_symbol_id, x, y); + + REQUIRE(result == (x + y)); + } + + { + auto [i_symbol, found] = symbol_table->find("RRtoR2", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 2; + const double y = 3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<2> result = + tests_adapter::TestBinary<TinyVector<2>(double, double)>::two_args(function_symbol_id, x, y); + + REQUIRE(result == TinyVector<2>{x + y, x - y}); + } + + { + auto [i_symbol, found] = symbol_table->find("R3times2", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyVector<3> x{2, 3, 4}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<3> result = tests_adapter::TestBinary<TinyVector<3>(TinyVector<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == 2 * x); + } + + { + auto [i_symbol, found] = symbol_table->find("BtoR1", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + { + const bool b = true; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<1> result = tests_adapter::TestBinary<TinyVector<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + + { + const bool b = false; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<1> result = tests_adapter::TestBinary<TinyVector<1>(bool)>::one_arg(function_symbol_id, b); + + REQUIRE(result == not b); + } + } + + { + auto [i_symbol, found] = symbol_table->find("NtoR1", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const uint64_t n = 4; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<1> result = tests_adapter::TestBinary<TinyVector<1>(uint64_t)>::one_arg(function_symbol_id, n); + + REQUIRE(result == n * n); + } + + { + auto [i_symbol, found] = symbol_table->find("ZtoR1", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const int64_t z = 3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<1> result = tests_adapter::TestBinary<TinyVector<1>(int64_t)>::one_arg(function_symbol_id, z); + + REQUIRE(result == -z); + } + + { + auto [i_symbol, found] = symbol_table->find("RtoR1", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 3.3; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<1> result = tests_adapter::TestBinary<TinyVector<1>(double)>::one_arg(function_symbol_id, x); + + REQUIRE(result == x * x); + } + + { + auto [i_symbol, found] = symbol_table->find("R3toR3zero", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyVector<3> x{1, 1, 1}; + + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + TinyVector<3> result = tests_adapter::TestBinary<TinyVector<3>(TinyVector<3>)>::one_arg(function_symbol_id, x); + + REQUIRE(result == TinyVector<3>{0, 0, 0}); + } + } + + SECTION("Errors calls") + { + std::string_view data = R"( +let R1toR1: R^1 -> R^1, x -> x; +let R3toR3: R^3 -> R^3, x -> 1; +let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z); +let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]); +let RtoNS: R -> N*string, x -> (1, "foo"); +let RtoR: R -> R, x -> 2*x; +)"; + string_input input{data, "test.pgs"}; + + auto ast = ASTBuilder::build(input); + + ASTModulesImporter{*ast}; + ASTNodeTypeCleaner<language::import_instruction>{*ast}; + + ASTSymbolTableBuilder{*ast}; + ASTNodeDataTypeBuilder{*ast}; + + ASTNodeTypeCleaner<language::var_declaration>{*ast}; + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; + ASTNodeExpressionBuilder{*ast}; + + std::shared_ptr<SymbolTable> symbol_table = ast->m_symbol_table; + + position position{internal::iterator{"fixture"}, "fixture"}; + position.byte = data.size(); // ensure that variables are declared at this + + { + auto [i_symbol, found] = symbol_table->find("R1toR1", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyVector<1> x{2}; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<double(TinyVector<1>)>::one_arg(function_symbol_id, x), + "error: invalid function type\n" + "note: expecting R^1 -> R\n" + "note: provided function R1toR1: R^1 -> R^1"); + } + + { + auto [i_symbol, found] = symbol_table->find("R3toR3", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const TinyVector<3> x{2, 1, 3}; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(TinyVector<3>)>::one_arg(function_symbol_id, x), + "error: invalid function type\n" + "note: expecting R^3 -> R^3\n" + "note: provided function R3toR3: R^3 -> R^3"); + } + + { + auto [i_symbol, found] = symbol_table->find("RRRtoR3", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 1; + const double y = 2; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(double, double)>::two_args(function_symbol_id, x, y), + "error: invalid function type\n" + "note: expecting R*R -> R^3\n" + "note: provided function RRRtoR3: R*R*R -> R^3"); + } + + { + auto [i_symbol, found] = symbol_table->find("R3toR2", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 1; + const double y = 2; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(double, double)>::two_args(function_symbol_id, x, y), + "error: invalid function type\n" + "note: expecting R*R -> R^3\n" + "note: provided function R3toR2: R^3 -> R^2"); + } + + { + auto [i_symbol, found] = symbol_table->find("RtoNS", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 1; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<2>(double)>::one_arg(function_symbol_id, x), + "error: invalid function type\n" + "note: expecting R -> R^2\n" + "note: provided function RtoNS: R -> N*string"); + } + + { + auto [i_symbol, found] = symbol_table->find("RtoR", position); + REQUIRE(found); + REQUIRE(i_symbol->attributes().dataType() == ASTNodeDataType::function_t); + + const double x = 1; + FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); + + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(double)>::one_arg(function_symbol_id, x), + "error: invalid function type\n" + "note: expecting R -> R^3\n" + "note: provided function RtoR: R -> R"); + } + } +}