From c013d0d68bdea24d3956a7ca6975876bad1ff6de Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 23 Sep 2020 19:08:00 +0200 Subject: [PATCH] Use function domains to check its PugsFunctionAdapter compatibility Related to issue #21 --- src/language/utils/PugsFunctionAdapter.hpp | 53 ++++++---------------- tests/test_PugsFunctionAdapter.cpp | 6 +-- 2 files changed, 18 insertions(+), 41 deletions(-) diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp index 660cae99e..489066488 100644 --- a/src/language/utils/PugsFunctionAdapter.hpp +++ b/src/language/utils/PugsFunctionAdapter.hpp @@ -36,12 +36,14 @@ class PugsFunctionAdapter<OutputType(InputType...)> template <size_t I> [[nodiscard]] PUGS_INLINE static bool - _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept + _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT) { using Arg = std::tuple_element_t<I, InputTuple>; constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>; - const ASTNodeDataType& arg_data_type = arg_expression.m_data_type; + + Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t); + const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType(); return isNaturalConversion(expected_input_data_type, arg_data_type); } @@ -55,53 +57,28 @@ class PugsFunctionAdapter<OutputType(InputType...)> } [[nodiscard]] PUGS_INLINE static bool - _checkValidInputDataType(const ASTNode& input_expression) noexcept + _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept { if constexpr (NArgs == 1) { - return _checkValidArgumentDataType<0>(input_expression); + return _checkValidArgumentDataType<0>(input_domain_expression); } else { - if (input_expression.children.size() != NArgs) { + if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or + (input_domain_expression.children.size() != NArgs)) { return false; } using IndexSequence = std::make_index_sequence<NArgs>; - return _checkAllInputDataType(input_expression, IndexSequence{}); + return _checkAllInputDataType(input_domain_expression, IndexSequence{}); } } [[nodiscard]] PUGS_INLINE static bool - _checkValidOutputDataType(const ASTNode& return_expression) noexcept + _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT) { constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>; - const ASTNodeDataType& return_data_type = return_expression.m_data_type; + const ASTNodeDataType& return_data_type = output_domain_expression.m_data_type.contentType(); - if (not isNaturalConversion(return_data_type, expected_return_data_type)) { - if (expected_return_data_type == ASTNodeDataType::vector_t) { - if (return_data_type == ASTNodeDataType::list_t) { - if (expected_return_data_type.dimension() != return_expression.children.size()) { - return false; - } else { - for (const auto& child : return_expression.children) { - const ASTNodeDataType& data_type = child->m_data_type; - if (not isNaturalConversion(data_type, ast_node_data_type_from<double>)) { - return false; - } - } - } - } else if ((expected_return_data_type.dimension() == 1) and - isNaturalConversion(return_data_type, ast_node_data_type_from<double>)) { - return true; - } else if (return_data_type == ast_node_data_type_from<int64_t>) { - // 0 is the only valid value for vectors - return (return_expression.string() == "0"); - } else { - return false; - } - } else { - return false; - } - } - return true; + return isNaturalConversion(return_data_type, expected_return_data_type); } template <typename Arg, typename... RemainingArgs> @@ -124,10 +101,10 @@ class PugsFunctionAdapter<OutputType(InputType...)> PUGS_INLINE static void _checkFunction(const FunctionDescriptor& function) { - bool has_valid_input = _checkValidInputDataType(*function.definitionNode().children[0]); - bool has_valid_output = _checkValidOutputDataType(*function.definitionNode().children[1]); + bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]); + bool has_valid_output = _checkValidOutputDomain(*function.domainMappingNode().children[1]); - if (not(has_valid_input and has_valid_output)) { + if (not(has_valid_input_domain and has_valid_output)) { std::ostringstream error_message; error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>) diff --git a/tests/test_PugsFunctionAdapter.cpp b/tests/test_PugsFunctionAdapter.cpp index 45229f8b5..3a7ca2813 100644 --- a/tests/test_PugsFunctionAdapter.cpp +++ b/tests/test_PugsFunctionAdapter.cpp @@ -275,7 +275,7 @@ let R3toR3zero: R^3 -> R^3, x -> 0; { std::string_view data = R"( let R1toR1: R^1 -> R^1, x -> x; -let R3toR3: R^3 -> R^3, x -> 1; +let R3toR3: R^3 -> R^3, x -> 0; let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z); let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]); let RtoNS: R -> N*string, x -> (1, "foo"); @@ -322,9 +322,9 @@ let RtoR: R -> R, x -> 2*x; const TinyVector<3> x{2, 1, 3}; FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table); - REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(TinyVector<3>)>::one_arg(function_symbol_id, x), + REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<2>(TinyVector<3>)>::one_arg(function_symbol_id, x), "error: invalid function type\n" - "note: expecting R^3 -> R^3\n" + "note: expecting R^3 -> R^2\n" "note: provided function R3toR3: R^3 -> R^3"); } -- GitLab