From bc6a2d3c03a98c890ecd32b4d411dd17b58f0d12 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Tue, 28 Jan 2020 19:13:39 +0100
Subject: [PATCH] Add tuple expressions for vector initialization

This allows to write

``
let f : R -> R^3, x -> (x, 2*x+1, -2);
R^3 x = f(2.3);
``

or even

``
let f : R -> R^3*R, x -> ((x, 2*x+1, -2), x-1);
R^3*R (x,t) = f(2.3);
``

Note that it is not completely functional since
``
let f : R -> R^3, x -> (x, 2*x+1, -2);
R^3 x = 2*f(2.3);
``
does not work!
The return type of `f(2.3)` is incorrect (should be an R^3 but
remains a `typename` which is improper for these calculations).
---
 src/language/ASTBuilder.cpp                   |  1 +
 src/language/ASTNodeDataTypeBuilder.cpp       | 49 ++++++++++++++-----
 src/language/ASTNodeExpressionBuilder.cpp     | 19 +++++++
 .../ASTNodeFunctionExpressionBuilder.cpp      | 15 ++++--
 src/language/PEGGrammar.hpp                   |  4 +-
 .../node_processor/TupleToVectorProcessor.hpp | 45 +++++++++++++++++
 6 files changed, 115 insertions(+), 18 deletions(-)
 create mode 100644 src/language/node_processor/TupleToVectorProcessor.hpp

diff --git a/src/language/ASTBuilder.cpp b/src/language/ASTBuilder.cpp
index 6bacf4c07..eb853444e 100644
--- a/src/language/ASTBuilder.cpp
+++ b/src/language/ASTBuilder.cpp
@@ -223,6 +223,7 @@ using selector = parse_tree::selector<
                                 N_set,
                                 Z_set,
                                 R_set,
+                                tuple_expression,
                                 vector_type,
                                 string_type,
                                 cout_kw,
diff --git a/src/language/ASTNodeDataTypeBuilder.cpp b/src/language/ASTNodeDataTypeBuilder.cpp
index c1ad69879..0b4de8331 100644
--- a/src/language/ASTNodeDataTypeBuilder.cpp
+++ b/src/language/ASTNodeDataTypeBuilder.cpp
@@ -94,6 +94,21 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
         n.m_data_type = ASTNodeDataType::int_t;
       } else if (n.is_type<language::vector_type>()) {
         n.m_data_type = getVectorDataType(n);
+
+      } else if (n.is_type<language::tuple_expression>()) {
+        for (auto&& child : n.children) {
+          this->_buildNodeDataTypes(*child);
+        }
+        for (auto&& child : n.children) {
+          ASTNodeNaturalConversionChecker{*child, child->m_data_type, ASTNodeDataType::double_t};
+        }
+
+        if (n.children.size() <= 3) {
+          n.m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, n.children.size()};
+        } else {
+          throw parse_error("invalid vector dimension (must be lesser than 3)", n.begin());
+        }
+
       } else if (n.is_type<language::literal>()) {
         n.m_data_type = ASTNodeDataType::string_t;
       } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) {
@@ -194,23 +209,33 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
         ASTNode& image_domain_node     = *function_descriptor.domainMappingNode().children[1];
         ASTNode& image_expression_node = *function_descriptor.definitionNode().children[1];
 
+        this->_buildNodeDataTypes(image_domain_node);
+        for (auto& child : image_domain_node.children) {
+          this->_buildNodeDataTypes(*child);
+        }
+
         const size_t nb_image_domains =
           (image_domain_node.is_type<language::type_expression>()) ? image_domain_node.children.size() : 1;
         const size_t nb_image_expressions =
           (image_expression_node.is_type<language::expression_list>()) ? image_expression_node.children.size() : 1;
 
         if (nb_image_domains != nb_image_expressions) {
-          std::ostringstream message;
-          message << "note: number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow
-                  << image_domain_node.string() << rang::style::reset << rang::style::bold
-                  << " differs from number of expressions (" << nb_image_expressions << ") " << rang::fgB::yellow
-                  << image_expression_node.string() << rang::style::reset << std::ends;
-          throw parse_error(message.str(), image_domain_node.begin());
-        }
-
-        this->_buildNodeDataTypes(image_domain_node);
-        for (auto& child : image_domain_node.children) {
-          this->_buildNodeDataTypes(*child);
+          if (image_domain_node.is_type<language::vector_type>()) {
+            ASTNodeDataType image_type = getVectorDataType(image_domain_node);
+            if (image_type.dimension() != nb_image_expressions) {
+              std::ostringstream message;
+              message << "note: expecting " << image_domain_node.m_data_type.dimension() << " expressions found "
+                      << nb_image_expressions << std::ends;
+              throw parse_error(message.str(), image_domain_node.begin());
+            }
+          } else {
+            std::ostringstream message;
+            message << "note: number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow
+                    << image_domain_node.string() << rang::style::reset << rang::style::bold
+                    << " differs from number of expressions (" << nb_image_expressions << ") " << rang::fgB::yellow
+                    << image_expression_node.string() << rang::style::reset << std::ends;
+            throw parse_error(message.str(), image_domain_node.begin());
+          }
         }
 
         auto check_image_type = [&](const ASTNode& image_node) {
@@ -343,8 +368,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
 
         ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1];
 
-        Assert(image_domain_node.m_data_type == ASTNodeDataType::typename_t);
-
         ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
         if (image_domain_node.is_type<language::type_expression>()) {
           data_type = image_domain_node.m_data_type;
diff --git a/src/language/ASTNodeExpressionBuilder.cpp b/src/language/ASTNodeExpressionBuilder.cpp
index 238d61bd0..eac81ff86 100644
--- a/src/language/ASTNodeExpressionBuilder.cpp
+++ b/src/language/ASTNodeExpressionBuilder.cpp
@@ -19,6 +19,7 @@
 #include <node_processor/LocalNameProcessor.hpp>
 #include <node_processor/NameProcessor.hpp>
 #include <node_processor/OStreamProcessor.hpp>
+#include <node_processor/TupleToVectorProcessor.hpp>
 #include <node_processor/ValueProcessor.hpp>
 #include <node_processor/WhileProcessor.hpp>
 
@@ -38,6 +39,24 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n)
       ASTNodeAffectationExpressionBuilder{n};
     }
 
+  } else if (n.is_type<language::tuple_expression>()) {
+    switch (n.children.size()) {
+    case 1: {
+      n.m_node_processor = std::make_unique<TupleToVectorProcessor<1>>(n);
+      break;
+    }
+    case 2: {
+      n.m_node_processor = std::make_unique<TupleToVectorProcessor<2>>(n);
+      break;
+    }
+    case 3: {
+      n.m_node_processor = std::make_unique<TupleToVectorProcessor<3>>(n);
+      break;
+    }
+    default: {
+      throw parse_error("unexpected error: invalid tuple size", n.begin());
+    }
+    }
   } else if (n.is_type<language::function_definition>()) {
     n.m_node_processor = std::make_unique<FakeProcessor>();
 
diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp
index 34a1b8571..c6d7e3330 100644
--- a/src/language/ASTNodeFunctionExpressionBuilder.cpp
+++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp
@@ -336,11 +336,18 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
   ASTNode& function_expression   = *function_descriptor.definitionNode().children[1];
 
   if (function_expression.is_type<language::expression_list>()) {
-    Assert(function_image_domain.is_type<language::type_expression>());
-    ASTNode& image_domain_node = function_image_domain;
+    if (function_image_domain.is_type<language::vector_type>()) {
+      for (size_t i = 0; i < function_expression.children.size(); ++i) {
+        function_processor->addFunctionExpressionProcessor(
+          this->_getFunctionProcessor(function_expression.children[i]->m_data_type, ASTNodeDataType::double_t, node,
+                                      *function_expression.children[i]));
+      }
+    } else {
+      ASTNode& image_domain_node = function_image_domain;
 
-    for (size_t i = 0; i < function_expression.children.size(); ++i) {
-      add_component_expression(*function_expression.children[i], *image_domain_node.children[i]);
+      for (size_t i = 0; i < function_expression.children.size(); ++i) {
+        add_component_expression(*function_expression.children[i], *image_domain_node.children[i]);
+      }
     }
   } else {
     add_component_expression(function_expression, function_image_domain);
diff --git a/src/language/PEGGrammar.hpp b/src/language/PEGGrammar.hpp
index 0f216d9aa..d7506b5be 100644
--- a/src/language/PEGGrammar.hpp
+++ b/src/language/PEGGrammar.hpp
@@ -213,7 +213,9 @@ struct logical_or : list_must< logical_and, or_op >{};
 
 struct expression : logical_or {};
 
-struct expression_list : seq< open_parent, expression, plus< if_must< COMMA, expression > >, close_parent >{};
+struct tuple_expression : seq< open_parent, expression, plus< if_must< COMMA, expression > >, close_parent >{};
+
+struct expression_list : seq< open_parent, sor< tuple_expression, expression >, plus< if_must< COMMA, sor< tuple_expression, expression > > >, close_parent >{};
 
 struct affect_op : sor< eq_op, multiplyeq_op, divideeq_op, pluseq_op, minuseq_op > {};
 
diff --git a/src/language/node_processor/TupleToVectorProcessor.hpp b/src/language/node_processor/TupleToVectorProcessor.hpp
new file mode 100644
index 000000000..33f6bc305
--- /dev/null
+++ b/src/language/node_processor/TupleToVectorProcessor.hpp
@@ -0,0 +1,45 @@
+#ifndef TUPLE_TO_VECTOR_PROCESSOR_HPP
+#define TUPLE_TO_VECTOR_PROCESSOR_HPP
+
+#include <node_processor/INodeProcessor.hpp>
+
+#include <node_processor/ASTNodeExpressionListProcessor.hpp>
+
+template <size_t N>
+class TupleToVectorProcessor final : public INodeProcessor
+{
+ private:
+  ASTNode& m_node;
+
+  ASTNodeExpressionListProcessor m_list_processor;
+
+ public:
+  DataVariant
+  execute(ExecutionPolicy& exec_policy)
+  {
+    AggregateDataVariant v = std::get<AggregateDataVariant>(m_list_processor.execute(exec_policy));
+
+    Assert(v.size() == N);
+
+    TinyVector<N> x;
+
+    for (size_t i = 0; i < N; ++i) {
+      std::visit(
+        [&](auto&& v) {
+          using ValueT = std::decay_t<decltype(v)>;
+          if constexpr (std::is_arithmetic_v<ValueT>) {
+            x[i] = v;
+          } else {
+            Assert(false, "unexpected value type");
+          }
+        },
+        v[i]);
+    }
+
+    return DataVariant{std::move(x)};
+  }
+
+  TupleToVectorProcessor(ASTNode& node) : m_node{node}, m_list_processor{node} {}
+};
+
+#endif   // TUPLE_TO_VECTOR_PROCESSOR_HPP
-- 
GitLab