From af4ba50fecc1b4929c99654d5a54f032fd152c04 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Thu, 4 Mar 2021 18:31:27 +0100
Subject: [PATCH] Allow compound types to be composed of type_id

This means that now the following code is valid
```
import scheme;
import mesh;
import math;

let m:mesh, m = cartesian1dMesh(-1,1,100);

let pi:R, pi = acos(-1);

let f: R^1 -> R, x -> x[0]*x[0];
let g: R^1 -> R, x -> sin(pi*x[0]);

let (fh,gh): Vh*Vh, (fh,gh) = (interpolate(m, P0(), f), interpolate(m, P0(), g));
```
This allows to define builtin functions that return composed
results. For instance the result of the calculation of a time step of
Lagrangian schemes. One can now return (rho,u,E,...).
---
 src/language/ast/ASTNodeDataTypeBuilder.cpp   | 22 ++++--
 ...STNodeListAffectationExpressionBuilder.cpp | 42 ++++++++++-
 .../node_processor/AffectationProcessor.hpp   | 73 +++++++++----------
 3 files changed, 90 insertions(+), 47 deletions(-)

diff --git a/src/language/ast/ASTNodeDataTypeBuilder.cpp b/src/language/ast/ASTNodeDataTypeBuilder.cpp
index 018bc94f7..b5f65437c 100644
--- a/src/language/ast/ASTNodeDataTypeBuilder.cpp
+++ b/src/language/ast/ASTNodeDataTypeBuilder.cpp
@@ -333,6 +333,22 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
           value_type = getMatrixDataType(image_node);
         } else if (image_node.is_type<language::string_type>()) {
           value_type = ASTNodeDataType::build<ASTNodeDataType::string_t>();
+        } else if (image_node.is_type<language::type_name_id>()) {
+          const std::string& type_name_id = image_node.string();
+
+          auto& symbol_table = *image_node.m_symbol_table;
+
+          const auto [i_type_symbol, found] = symbol_table.find(type_name_id, image_node.begin());
+          if (not found) {
+            throw ParseError("undefined type identifier", std::vector{image_node.begin()});
+          } else if (i_type_symbol->attributes().dataType() != ASTNodeDataType::type_name_id_t) {
+            std::ostringstream os;
+            os << "invalid type identifier, '" << type_name_id << "' was previously defined as a '"
+               << dataTypeName(i_type_symbol->attributes().dataType()) << '\'';
+            throw ParseError(os.str(), std::vector{image_node.begin()});
+          }
+
+          value_type = ASTNodeDataType::build<ASTNodeDataType::type_id_t>(type_name_id);
         }
 
         // LCOV_EXCL_START
@@ -348,7 +364,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
       }
       n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::typename_t>(
         ASTNodeDataType::build<ASTNodeDataType::list_t>(sub_data_type_list));
-
     } else if (n.is_type<language::for_post>() or n.is_type<language::for_init>() or
                n.is_type<language::for_statement_block>()) {
       n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>();
@@ -361,13 +376,11 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
 
       const ASTNode& test_node = *n.children[0];
       ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::build<ASTNodeDataType::bool_t>()};
-
     } else if (n.is_type<language::do_while_statement>()) {
       n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::void_t>();
 
       const ASTNode& test_node = *n.children[1];
       ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::build<ASTNodeDataType::bool_t>()};
-
     } else if (n.is_type<language::unary_not>() or n.is_type<language::unary_minus>()) {
       auto& operator_repository = OperatorRepository::instance();
 
@@ -394,7 +407,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
                 << rang::style::reset;
         throw ParseError(message.str(), n.begin());
       }
-
     } else if (n.is_type<language::unary_plusplus>() or n.is_type<language::unary_minusminus>() or
                n.is_type<language::post_plusplus>() or n.is_type<language::post_minusminus>()) {
       auto& operator_repository = OperatorRepository::instance();
@@ -428,7 +440,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
                 << rang::style::reset;
         throw ParseError(message.str(), n.begin());
       }
-
     } else if (n.is_type<language::plus_op>() or n.is_type<language::minus_op>() or
                n.is_type<language::multiply_op>() or n.is_type<language::divide_op>() or
                n.is_type<language::lesser_op>() or n.is_type<language::lesser_or_eq_op>() or
@@ -496,7 +507,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
                 << "note: incompatible operand types " << dataTypeName(type_0) << " and " << dataTypeName(type_1);
         throw ParseError(message.str(), n.begin());
       }
-
     } else if (n.is_type<language::function_evaluation>()) {
       if (n.children[0]->m_data_type == ASTNodeDataType::function_t) {
         const std::string& function_name = n.children[0]->string();
diff --git a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp
index b135bbcd3..6b53db72a 100644
--- a/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp
+++ b/src/language/ast/ASTNodeListAffectationExpressionBuilder.cpp
@@ -6,6 +6,9 @@
 #include <language/utils/ASTNodeNaturalConversionChecker.hpp>
 #include <language/utils/ParseError.hpp>
 
+#include <language/utils/AffectationMangler.hpp>
+#include <language/utils/OperatorRepository.hpp>
+
 template <typename OperatorT>
 void
 ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor(
@@ -115,6 +118,25 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor(
     }
   };
 
+  auto add_affectation_processor_for_embedded_data = [&](const ASTNodeSubDataType& node_sub_data_type) {
+    if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
+      switch (node_sub_data_type.m_data_type) {
+      case ASTNodeDataType::type_id_t: {
+        list_affectation_processor->template add<EmbeddedData, EmbeddedData>(value_node);
+        break;
+      }
+        // LCOV_EXCL_START
+      default: {
+        throw ParseError("unexpected error:invalid operand type for embedded data affectation",
+                         std::vector{node_sub_data_type.m_parent_node.begin()});
+      }
+        // LCOV_EXCL_STOP
+      }
+    } else {
+      throw ParseError("unexpected error: undefined operator type for string affectation", std::vector{m_node.begin()});
+    }
+  };
+
   auto add_affectation_processor_for_string_data = [&](const ASTNodeSubDataType& node_sub_data_type) {
     if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
       switch (node_sub_data_type.m_data_type) {
@@ -192,6 +214,10 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor(
       add_affectation_processor_for_data(double{}, node_sub_data_type);
       break;
     }
+    case ASTNodeDataType::type_id_t: {
+      add_affectation_processor_for_embedded_data(node_sub_data_type);
+      break;
+    }
     case ASTNodeDataType::vector_t: {
       switch (value_type.dimension()) {
       case 1: {
@@ -251,7 +277,21 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor(
 
   ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(rhs_node_sub_data_type, value_node.m_data_type);
 
-  add_affectation_processor_for_value(value_node.m_data_type, rhs_node_sub_data_type);
+  const std::string affectation_name =
+    affectationMangler<language::eq_op>(value_node.m_data_type, rhs_node_sub_data_type.m_data_type);
+
+  const auto& optional_processor_builder =
+    OperatorRepository::instance().getAffectationProcessorBuilder(affectation_name);
+
+  if (optional_processor_builder.has_value()) {
+    add_affectation_processor_for_value(value_node.m_data_type, rhs_node_sub_data_type);
+  } else {
+    std::ostringstream error_message;
+    error_message << "undefined affectation type: ";
+    error_message << rang::fgB::red << affectation_name << rang::fg::reset;
+
+    throw ParseError(error_message.str(), std::vector{m_node.children[0]->begin()});
+  }
 }
 
 template <typename OperatorT>
diff --git a/src/language/node_processor/AffectationProcessor.hpp b/src/language/node_processor/AffectationProcessor.hpp
index db5d0fc75..07a52de44 100644
--- a/src/language/node_processor/AffectationProcessor.hpp
+++ b/src/language/node_processor/AffectationProcessor.hpp
@@ -121,7 +121,7 @@ class AffectationExecutor final : public IAffectationExecutor
           }
         } else {
           if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
-            if constexpr (std::is_convertible_v<ValueT, DataT>) {
+            if constexpr (std::is_convertible_v<DataT, ValueT>) {
               m_lhs = std::get<DataT>(rhs);
             } else if constexpr (std::is_same_v<DataT, AggregateDataVariant>) {
               const AggregateDataVariant& v = std::get<AggregateDataVariant>(rhs);
@@ -275,11 +275,8 @@ class MatrixComponentAffectationExecutor final : public IAffectationExecutor
         }
       } else {
         if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
-          if constexpr (std::is_same_v<ValueT, DataT>) {
-            m_lhs_array(index0_value, index1_value) = std::get<DataT>(rhs);
-          } else {
-            m_lhs_array(index0_value, index1_value) = static_cast<ValueT>(std::get<DataT>(rhs));
-          }
+          static_assert(std::is_convertible_v<DataT, ValueT>, "unexpected types");
+          m_lhs_array(index0_value, index1_value) = static_cast<ValueT>(std::get<DataT>(rhs));
         } else {
           AffOp<OperatorT>().eval(m_lhs_array(index0_value, index1_value), std::get<DataT>(rhs));
         }
@@ -354,11 +351,8 @@ class VectorComponentAffectationExecutor final : public IAffectationExecutor
         }
       } else {
         if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
-          if constexpr (std::is_same_v<ValueT, DataT>) {
-            m_lhs_array[index_value] = std::get<DataT>(rhs);
-          } else {
-            m_lhs_array[index_value] = static_cast<ValueT>(std::get<DataT>(rhs));
-          }
+          static_assert(std::is_convertible_v<DataT, ValueT>, "incompatible data and value types");
+          m_lhs_array[index_value] = static_cast<ValueT>(std::get<DataT>(rhs));
         } else {
           AffOp<OperatorT>().eval(m_lhs_array[index_value], std::get<DataT>(rhs));
         }
@@ -375,16 +369,8 @@ class AffectationProcessor final : public INodeProcessor
 
   std::unique_ptr<IAffectationExecutor> m_affectation_executor;
 
- public:
-  DataVariant
-  execute(ExecutionPolicy& exec_policy)
-  {
-    m_affectation_executor->affect(exec_policy, m_rhs_node.execute(exec_policy));
-
-    return {};
-  }
-
-  AffectationProcessor(ASTNode& lhs_node, ASTNode& rhs_node) : m_rhs_node{rhs_node}
+  std::unique_ptr<IAffectationExecutor>
+  _buildAffectationExecutor(ASTNode& lhs_node)
   {
     if (lhs_node.is_type<language::name>()) {
       const std::string& symbol = lhs_node.string();
@@ -397,7 +383,7 @@ class AffectationProcessor final : public INodeProcessor
       }
 
       using AffectationExecutorT = AffectationExecutor<OperatorT, ValueT, DataT>;
-      m_affectation_executor     = std::make_unique<AffectationExecutorT>(lhs_node, std::get<ValueT>(value));
+      return std::make_unique<AffectationExecutorT>(lhs_node, std::get<ValueT>(value));
     } else if (lhs_node.is_type<language::subscript_expression>()) {
       auto& array_expression = *lhs_node.children[0];
       Assert(array_expression.is_type<language::name>());
@@ -420,9 +406,8 @@ class AffectationProcessor final : public INodeProcessor
             value = ArrayTypeT{};
           }
           using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
-          m_affectation_executor =
-            std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression);
-          break;
+
+          return std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression);
         }
         case 2: {
           using ArrayTypeT = TinyVector<2>;
@@ -430,9 +415,8 @@ class AffectationProcessor final : public INodeProcessor
             value = ArrayTypeT{};
           }
           using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
-          m_affectation_executor =
-            std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression);
-          break;
+
+          return std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression);
         }
         case 3: {
           using ArrayTypeT = TinyVector<3>;
@@ -440,9 +424,8 @@ class AffectationProcessor final : public INodeProcessor
             value = ArrayTypeT{};
           }
           using AffectationExecutorT = VectorComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
-          m_affectation_executor =
-            std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression);
-          break;
+
+          return std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression);
         }
           // LCOV_EXCL_START
         default: {
@@ -465,9 +448,8 @@ class AffectationProcessor final : public INodeProcessor
           }
           using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
 
-          m_affectation_executor = std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value),
-                                                                          index0_expression, index1_expression);
-          break;
+          return std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression,
+                                                        index1_expression);
         }
         case 2: {
           using ArrayTypeT = TinyMatrix<2>;
@@ -476,9 +458,8 @@ class AffectationProcessor final : public INodeProcessor
           }
           using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
 
-          m_affectation_executor = std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value),
-                                                                          index0_expression, index1_expression);
-          break;
+          return std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression,
+                                                        index1_expression);
         }
         case 3: {
           using ArrayTypeT = TinyMatrix<3>;
@@ -487,9 +468,8 @@ class AffectationProcessor final : public INodeProcessor
           }
           using AffectationExecutorT = MatrixComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
 
-          m_affectation_executor = std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value),
-                                                                          index0_expression, index1_expression);
-          break;
+          return std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index0_expression,
+                                                        index1_expression);
         }
           // LCOV_EXCL_START
         default: {
@@ -508,6 +488,19 @@ class AffectationProcessor final : public INodeProcessor
       // LCOV_EXCL_STOP
     }
   }
+
+ public:
+  DataVariant
+  execute(ExecutionPolicy& exec_policy)
+  {
+    m_affectation_executor->affect(exec_policy, m_rhs_node.execute(exec_policy));
+
+    return {};
+  }
+
+  AffectationProcessor(ASTNode& lhs_node, ASTNode& rhs_node)
+    : m_rhs_node{rhs_node}, m_affectation_executor{this->_buildAffectationExecutor(lhs_node)}
+  {}
 };
 
 class AffectationToDataVariantProcessorBase : public INodeProcessor
-- 
GitLab