From c013d0d68bdea24d3956a7ca6975876bad1ff6de Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Wed, 23 Sep 2020 19:08:00 +0200
Subject: [PATCH] Use function domains to check its PugsFunctionAdapter
 compatibility

Related to issue #21
---
 src/language/utils/PugsFunctionAdapter.hpp | 53 ++++++----------------
 tests/test_PugsFunctionAdapter.cpp         |  6 +--
 2 files changed, 18 insertions(+), 41 deletions(-)

diff --git a/src/language/utils/PugsFunctionAdapter.hpp b/src/language/utils/PugsFunctionAdapter.hpp
index 660cae99e..489066488 100644
--- a/src/language/utils/PugsFunctionAdapter.hpp
+++ b/src/language/utils/PugsFunctionAdapter.hpp
@@ -36,12 +36,14 @@ class PugsFunctionAdapter<OutputType(InputType...)>
 
   template <size_t I>
   [[nodiscard]] PUGS_INLINE static bool
-  _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept
+  _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT)
   {
     using Arg = std::tuple_element_t<I, InputTuple>;
 
     constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>;
-    const ASTNodeDataType& arg_data_type                      = arg_expression.m_data_type;
+
+    Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t);
+    const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType();
 
     return isNaturalConversion(expected_input_data_type, arg_data_type);
   }
@@ -55,53 +57,28 @@ class PugsFunctionAdapter<OutputType(InputType...)>
   }
 
   [[nodiscard]] PUGS_INLINE static bool
-  _checkValidInputDataType(const ASTNode& input_expression) noexcept
+  _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept
   {
     if constexpr (NArgs == 1) {
-      return _checkValidArgumentDataType<0>(input_expression);
+      return _checkValidArgumentDataType<0>(input_domain_expression);
     } else {
-      if (input_expression.children.size() != NArgs) {
+      if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or
+          (input_domain_expression.children.size() != NArgs)) {
         return false;
       }
 
       using IndexSequence = std::make_index_sequence<NArgs>;
-      return _checkAllInputDataType(input_expression, IndexSequence{});
+      return _checkAllInputDataType(input_domain_expression, IndexSequence{});
     }
   }
 
   [[nodiscard]] PUGS_INLINE static bool
-  _checkValidOutputDataType(const ASTNode& return_expression) noexcept
+  _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT)
   {
     constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>;
-    const ASTNodeDataType& return_data_type                    = return_expression.m_data_type;
+    const ASTNodeDataType& return_data_type                    = output_domain_expression.m_data_type.contentType();
 
-    if (not isNaturalConversion(return_data_type, expected_return_data_type)) {
-      if (expected_return_data_type == ASTNodeDataType::vector_t) {
-        if (return_data_type == ASTNodeDataType::list_t) {
-          if (expected_return_data_type.dimension() != return_expression.children.size()) {
-            return false;
-          } else {
-            for (const auto& child : return_expression.children) {
-              const ASTNodeDataType& data_type = child->m_data_type;
-              if (not isNaturalConversion(data_type, ast_node_data_type_from<double>)) {
-                return false;
-              }
-            }
-          }
-        } else if ((expected_return_data_type.dimension() == 1) and
-                   isNaturalConversion(return_data_type, ast_node_data_type_from<double>)) {
-          return true;
-        } else if (return_data_type == ast_node_data_type_from<int64_t>) {
-          // 0 is the only valid value for vectors
-          return (return_expression.string() == "0");
-        } else {
-          return false;
-        }
-      } else {
-        return false;
-      }
-    }
-    return true;
+    return isNaturalConversion(return_data_type, expected_return_data_type);
   }
 
   template <typename Arg, typename... RemainingArgs>
@@ -124,10 +101,10 @@ class PugsFunctionAdapter<OutputType(InputType...)>
   PUGS_INLINE static void
   _checkFunction(const FunctionDescriptor& function)
   {
-    bool has_valid_input  = _checkValidInputDataType(*function.definitionNode().children[0]);
-    bool has_valid_output = _checkValidOutputDataType(*function.definitionNode().children[1]);
+    bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]);
+    bool has_valid_output       = _checkValidOutputDomain(*function.domainMappingNode().children[1]);
 
-    if (not(has_valid_input and has_valid_output)) {
+    if (not(has_valid_input_domain and has_valid_output)) {
       std::ostringstream error_message;
       error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow
                     << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>)
diff --git a/tests/test_PugsFunctionAdapter.cpp b/tests/test_PugsFunctionAdapter.cpp
index 45229f8b5..3a7ca2813 100644
--- a/tests/test_PugsFunctionAdapter.cpp
+++ b/tests/test_PugsFunctionAdapter.cpp
@@ -275,7 +275,7 @@ let R3toR3zero: R^3 -> R^3, x -> 0;
   {
     std::string_view data = R"(
 let R1toR1: R^1 -> R^1, x -> x;
-let R3toR3: R^3 -> R^3, x -> 1;
+let R3toR3: R^3 -> R^3, x -> 0;
 let RRRtoR3: R*R*R -> R^3, (x,y,z) -> (x,y,z);
 let R3toR2: R^3 -> R^2, x -> (x[0],x[1]+x[2]);
 let RtoNS: R -> N*string, x -> (1, "foo");
@@ -322,9 +322,9 @@ let RtoR: R -> R, x -> 2*x;
       const TinyVector<3> x{2, 1, 3};
       FunctionSymbolId function_symbol_id(std::get<uint64_t>(i_symbol->attributes().value()), symbol_table);
 
-      REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<3>(TinyVector<3>)>::one_arg(function_symbol_id, x),
+      REQUIRE_THROWS_WITH(tests_adapter::TestBinary<TinyVector<2>(TinyVector<3>)>::one_arg(function_symbol_id, x),
                           "error: invalid function type\n"
-                          "note: expecting R^3 -> R^3\n"
+                          "note: expecting R^3 -> R^2\n"
                           "note: provided function R3toR3: R^3 -> R^3");
     }
 
-- 
GitLab