From 6b2407fceaea1b3bd36ff6e6b8681e8a90bc5657 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Mon, 27 Jan 2020 19:06:59 +0100
Subject: [PATCH] Add vector treatment in functions (arguments and returned
 types)

---
 src/language/ASTNodeDataType.cpp              | 17 +++++
 src/language/ASTNodeDataType.hpp              |  2 +
 src/language/ASTNodeDataTypeBuilder.cpp       | 26 +++----
 src/language/ASTNodeDataTypeFlattener.cpp     |  2 +
 .../ASTNodeFunctionExpressionBuilder.cpp      | 74 +++++++++++++++++++
 ...STNodeListAffectationExpressionBuilder.cpp | 30 ++++++++
 6 files changed, 137 insertions(+), 14 deletions(-)

diff --git a/src/language/ASTNodeDataType.cpp b/src/language/ASTNodeDataType.cpp
index 46490c580..c0d118711 100644
--- a/src/language/ASTNodeDataType.cpp
+++ b/src/language/ASTNodeDataType.cpp
@@ -1,5 +1,22 @@
+#include <ASTNode.hpp>
 #include <ASTNodeDataType.hpp>
 
+#include <PEGGrammar.hpp>
+
+ASTNodeDataType
+getVectorDataType(const ASTNode& type_node)
+{
+  if (not(type_node.is_type<language::vector_type>() and (type_node.children.size() == 2))) {
+    throw parse_error("unexpected node type", type_node.begin());
+  }
+  ASTNode& dimension_node = *type_node.children[1];
+  if (not dimension_node.is_type<language::integer>()) {
+    throw parse_error("unexpected non integer constant dimension", dimension_node.begin());
+  }
+  const size_t dimension = std::stol(dimension_node.string());
+  return ASTNodeDataType{ASTNodeDataType::vector_t, dimension};
+}
+
 std::string
 dataTypeName(const ASTNodeDataType& data_type)
 {
diff --git a/src/language/ASTNodeDataType.hpp b/src/language/ASTNodeDataType.hpp
index 79ae74453..2d220a4d9 100644
--- a/src/language/ASTNodeDataType.hpp
+++ b/src/language/ASTNodeDataType.hpp
@@ -52,6 +52,8 @@ class ASTNodeDataType
   ~ASTNodeDataType() = default;
 };
 
+ASTNodeDataType getVectorDataType(const ASTNode& type_node);
+
 std::string dataTypeName(const ASTNodeDataType& data_type);
 
 ASTNodeDataType dataTypePromotion(const ASTNodeDataType& data_type_1, const ASTNodeDataType& data_type_2);
diff --git a/src/language/ASTNodeDataTypeBuilder.cpp b/src/language/ASTNodeDataTypeBuilder.cpp
index b7e349374..c1ad69879 100644
--- a/src/language/ASTNodeDataTypeBuilder.cpp
+++ b/src/language/ASTNodeDataTypeBuilder.cpp
@@ -38,12 +38,7 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo
     } else if (type_node.is_type<language::R_set>()) {
       data_type = ASTNodeDataType::double_t;
     } else if (type_node.is_type<language::vector_type>()) {
-      ASTNode& dimension_node = *type_node.children[1];
-      if (not dimension_node.is_type<language::integer>()) {
-        throw parse_error("unexpected non integer constant dimension", dimension_node.begin());
-      }
-      const size_t dimension = std::stol(dimension_node.string());
-      data_type              = ASTNodeDataType{ASTNodeDataType::vector_t, dimension};
+      data_type = getVectorDataType(type_node);
     } else if (type_node.is_type<language::string_type>()) {
       data_type = ASTNodeDataType::string_t;
     }
@@ -98,7 +93,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
       } else if (n.is_type<language::integer>()) {
         n.m_data_type = ASTNodeDataType::int_t;
       } else if (n.is_type<language::vector_type>()) {
-        n.m_data_type = ASTNodeDataType::vector_t;
+        n.m_data_type = getVectorDataType(n);
       } else if (n.is_type<language::literal>()) {
         n.m_data_type = ASTNodeDataType::string_t;
       } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) {
@@ -134,7 +129,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
         }
 
         const size_t nb_parameter_domains =
-          (parameters_domain_node.children.size() > 0) ? parameters_domain_node.children.size() : 1;
+          (parameters_domain_node.is_type<language::type_expression>()) ? parameters_domain_node.children.size() : 1;
         const size_t nb_parameter_names =
           (parameters_name_node.children.size() > 0) ? parameters_name_node.children.size() : 1;
 
@@ -159,7 +154,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
           } else if (type_node.is_type<language::R_set>()) {
             data_type = ASTNodeDataType::double_t;
           } else if (type_node.is_type<language::vector_type>()) {
-            data_type = ASTNodeDataType::vector_t;
+            data_type = getVectorDataType(type_node);
           } else if (type_node.is_type<language::string_type>()) {
             data_type = ASTNodeDataType::string_t;
           }
@@ -180,10 +175,10 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
           i_symbol->attributes().setDataType(data_type);
         };
 
-        if (parameters_domain_node.children.size() == 0) {
+        if (nb_parameter_domains == 1) {
           simple_type_allocator(parameters_domain_node, parameters_name_node);
         } else {
-          for (size_t i = 0; i < parameters_domain_node.children.size(); ++i) {
+          for (size_t i = 0; i < nb_parameter_domains; ++i) {
             simple_type_allocator(*parameters_domain_node.children[i], *parameters_name_node.children[i]);
           }
         }
@@ -229,7 +224,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
           } else if (image_node.is_type<language::R_set>()) {
             value_type = ASTNodeDataType::double_t;
           } else if (image_node.is_type<language::vector_type>()) {
-            value_type = ASTNodeDataType::vector_t;
+            value_type = getVectorDataType(image_node);
           } else if (image_node.is_type<language::string_type>()) {
             value_type = ASTNodeDataType::string_t;
           }
@@ -351,7 +346,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
         Assert(image_domain_node.m_data_type == ASTNodeDataType::typename_t);
 
         ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
-        if (image_domain_node.children.size() > 0) {
+        if (image_domain_node.is_type<language::type_expression>()) {
           data_type = image_domain_node.m_data_type;
         } else {
           if (image_domain_node.is_type<language::B_set>()) {
@@ -362,6 +357,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
             data_type = ASTNodeDataType::unsigned_int_t;
           } else if (image_domain_node.is_type<language::R_set>()) {
             data_type = ASTNodeDataType::double_t;
+          } else if (image_domain_node.is_type<language::vector_type>()) {
+            data_type = getVectorDataType(image_domain_node);
           } else if (image_domain_node.is_type<language::string_type>()) {
             data_type = ASTNodeDataType::string_t;
           }
@@ -390,7 +387,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
         throw parse_error(message.str(), n.begin());
       }
     } else if (n.is_type<language::B_set>() or n.is_type<language::Z_set>() or n.is_type<language::N_set>() or
-               n.is_type<language::R_set>() or n.is_type<language::string_type>()) {
+               n.is_type<language::R_set>() or n.is_type<language::string_type>() or
+               n.is_type<language::vector_type>()) {
       n.m_data_type = ASTNodeDataType::typename_t;
     } else if (n.is_type<language::name_list>() or n.is_type<language::function_argument_list>()) {
       n.m_data_type = ASTNodeDataType::void_t;
diff --git a/src/language/ASTNodeDataTypeFlattener.cpp b/src/language/ASTNodeDataTypeFlattener.cpp
index 0545aa25c..1f1002f80 100644
--- a/src/language/ASTNodeDataTypeFlattener.cpp
+++ b/src/language/ASTNodeDataTypeFlattener.cpp
@@ -39,6 +39,8 @@ ASTNodeDataTypeFlattener::ASTNodeDataTypeFlattener(ASTNode& node, FlattenedDataT
             data_type = ASTNodeDataType::unsigned_int_t;
           } else if (image_sub_domain->is_type<language::R_set>()) {
             data_type = ASTNodeDataType::double_t;
+          } else if (image_sub_domain->is_type<language::vector_type>()) {
+            data_type = getVectorDataType(*image_sub_domain);
           } else if (image_sub_domain->is_type<language::string_type>()) {
             data_type = ASTNodeDataType::string_t;
           }
diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp
index 8b3787fd5..34a1b8571 100644
--- a/src/language/ASTNodeFunctionExpressionBuilder.cpp
+++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp
@@ -43,6 +43,27 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy
     }
   };
 
+  auto get_function_argument_converter_for_vector =
+    [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> {
+    using ParameterT = std::decay_t<decltype(parameter_v)>;
+    switch (node_sub_data_type.m_data_type) {
+    case ASTNodeDataType::vector_t: {
+      if (node_sub_data_type.m_data_type.dimension() == parameter_v.dimension()) {
+        return std::make_unique<FunctionArgumentConverter<ParameterT, ParameterT>>(parameter_id);
+      } else {
+        throw parse_error("invalid argument dimension (expected " + std::to_string(parameter_v.dimension()) +
+                            ", provided " + std::to_string(node_sub_data_type.m_data_type.dimension()) + ")",
+                          std::vector{node_sub_data_type.m_parent_node.begin()});
+      }
+    }
+      // LCOV_EXCL_START
+    default: {
+      throw parse_error("invalid argument type", std::vector{node_sub_data_type.m_parent_node.begin()});
+    }
+      // LCOV_EXCL_STOP
+    }
+  };
+
   auto get_function_argument_converter_for_string = [&]() -> std::unique_ptr<IFunctionArgumentConverter> {
     switch (node_sub_data_type.m_data_type) {
     case ASTNodeDataType::bool_t: {
@@ -85,6 +106,22 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy
     case ASTNodeDataType::string_t: {
       return get_function_argument_converter_for_string();
     }
+    case ASTNodeDataType::vector_t: {
+      switch (parameter_symbol.attributes().dataType().dimension()) {
+      case 1: {
+        return get_function_argument_converter_for_vector(TinyVector<1>{});
+      }
+      case 2: {
+        return get_function_argument_converter_for_vector(TinyVector<2>{});
+      }
+      case 3: {
+        return get_function_argument_converter_for_vector(TinyVector<3>{});
+      }
+      default: {
+        throw parse_error("unexpected error: invalid parameter dimension", std::vector{m_node.begin()});
+      }
+      }
+    }
 
       // LCOV_EXCL_START
     default: {
@@ -194,6 +231,25 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType ex
     }
   };
 
+  auto get_function_processor_for_expression_vector = [&](const auto& return_v) -> std::unique_ptr<INodeProcessor> {
+    using ReturnT = std::decay_t<decltype(return_v)>;
+    switch (expression_value_type) {
+    case ASTNodeDataType::vector_t: {
+      if (expression_value_type.dimension() == return_v.dimension()) {
+        return std::make_unique<FunctionExpressionProcessor<ReturnT, ReturnT>>(function_component_expression);
+      } else {
+        throw parse_error("invalid dimension for returned vector", std::vector{function_component_expression.begin()});
+      }
+    }
+      // LCOV_EXCL_START
+    default: {
+      throw parse_error("unexpected error: undefined expression value type for function",
+                        std::vector{node.children[1]->begin()});
+    }
+      // LCOV_EXCL_STOP
+    }
+  };
+
   auto get_function_processor_for_value = [&]() {
     switch (return_value_type) {
     case ASTNodeDataType::bool_t: {
@@ -208,6 +264,22 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType ex
     case ASTNodeDataType::double_t: {
       return get_function_processor_for_expression_value(double{});
     }
+    case ASTNodeDataType::vector_t: {
+      switch (return_value_type.dimension()) {
+      case 1: {
+        return get_function_processor_for_expression_vector(TinyVector<1>{});
+      }
+      case 2: {
+        return get_function_processor_for_expression_vector(TinyVector<2>{});
+      }
+      case 3: {
+        return get_function_processor_for_expression_vector(TinyVector<3>{});
+      }
+      default: {
+        throw parse_error("unexpected error: invalid dimension in returned type", std::vector{node.begin()});
+      }
+      }
+    }
     case ASTNodeDataType::string_t: {
       return get_function_processor_for_expression_value(std::string{});
     }
@@ -248,6 +320,8 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
       return_value_type = ASTNodeDataType::unsigned_int_t;
     } else if (image_domain_node.is_type<language::R_set>()) {
       return_value_type = ASTNodeDataType::double_t;
+    } else if (image_domain_node.is_type<language::vector_type>()) {
+      return_value_type = getVectorDataType(image_domain_node);
     } else if (image_domain_node.is_type<language::string_type>()) {
       return_value_type = ASTNodeDataType::string_t;
     }
diff --git a/src/language/ASTNodeListAffectationExpressionBuilder.cpp b/src/language/ASTNodeListAffectationExpressionBuilder.cpp
index a0bbeb88a..a39f53c6a 100644
--- a/src/language/ASTNodeListAffectationExpressionBuilder.cpp
+++ b/src/language/ASTNodeListAffectationExpressionBuilder.cpp
@@ -41,6 +41,16 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor(
     }
   };
 
+  auto add_affectation_processor_for_vector_data = [&](const auto& value,
+                                                       const ASTNodeSubDataType& node_sub_data_type) {
+    using ValueT = std::decay_t<decltype(value)>;
+    if (node_sub_data_type.m_data_type.dimension() == value.dimension()) {
+      list_affectation_processor->template add<ValueT, ValueT>(value_node);
+    } else {
+      throw parse_error("invalid dimension", std::vector{node_sub_data_type.m_parent_node.begin()});
+    }
+  };
+
   auto add_affectation_processor_for_string_data = [&](const ASTNodeSubDataType& node_sub_data_type) {
     if constexpr (std::is_same_v<OperatorT, language::eq_op> or std::is_same_v<OperatorT, language::pluseq_op>) {
       switch (node_sub_data_type.m_data_type) {
@@ -98,6 +108,26 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor(
       add_affectation_processor_for_data(double{}, node_sub_data_type);
       break;
     }
+    case ASTNodeDataType::vector_t: {
+      switch (value_type.dimension()) {
+      case 1: {
+        add_affectation_processor_for_vector_data(TinyVector<1>{}, node_sub_data_type);
+        break;
+      }
+      case 2: {
+        add_affectation_processor_for_vector_data(TinyVector<2>{}, node_sub_data_type);
+        break;
+      }
+      case 3: {
+        add_affectation_processor_for_vector_data(TinyVector<3>{}, node_sub_data_type);
+        break;
+      }
+      default: {
+        throw parse_error("invalid dimension", std::vector{value_node.begin()});
+      }
+      }
+      break;
+    }
     case ASTNodeDataType::string_t: {
       add_affectation_processor_for_string_data(node_sub_data_type);
       break;
-- 
GitLab