From a4a979d90ee01d289532915bb7aa7698ae73d066 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 31 Jan 2020 11:19:04 +0100
Subject: [PATCH] Add array subscript operator for R^d elements (for d in
 {1,2,3})

One can now write
``
R^2 x = (3, 2);
R y = x[1]; // y will contain the value 2
``
---
 ...ASTNodeArraySubscriptExpressionBuilder.cpp | 32 +++++++++++++
 ...ASTNodeArraySubscriptExpressionBuilder.hpp | 12 +++++
 src/language/ASTNodeExpressionBuilder.cpp     |  4 ++
 src/language/CMakeLists.txt                   |  1 +
 .../ArraySubscriptProcessor.hpp               | 47 +++++++++++++++++++
 5 files changed, 96 insertions(+)
 create mode 100644 src/language/ASTNodeArraySubscriptExpressionBuilder.cpp
 create mode 100644 src/language/ASTNodeArraySubscriptExpressionBuilder.hpp
 create mode 100644 src/language/node_processor/ArraySubscriptProcessor.hpp

diff --git a/src/language/ASTNodeArraySubscriptExpressionBuilder.cpp b/src/language/ASTNodeArraySubscriptExpressionBuilder.cpp
new file mode 100644
index 000000000..fd214c54d
--- /dev/null
+++ b/src/language/ASTNodeArraySubscriptExpressionBuilder.cpp
@@ -0,0 +1,32 @@
+#include <ASTNodeArraySubscriptExpressionBuilder.hpp>
+
+#include <../algebra/TinyVector.hpp>
+
+#include <node_processor/ArraySubscriptProcessor.hpp>
+
+ASTNodeArraySubscriptExpressionBuilder::ASTNodeArraySubscriptExpressionBuilder(ASTNode& node)
+{
+  auto& array_expression = *node.children[0];
+
+  if (array_expression.m_data_type == ASTNodeDataType::vector_t) {
+    switch (array_expression.m_data_type.dimension()) {
+    case 1: {
+      node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyVector<1>>>(node);
+      break;
+    }
+    case 2: {
+      node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyVector<2>>>(node);
+      break;
+    }
+    case 3: {
+      node.m_node_processor = std::make_unique<ArraySubscriptProcessor<TinyVector<3>>>(node);
+      break;
+    }
+    default: {
+      break;
+    }
+    }
+  } else {
+    throw parse_error("unexpected error: invalid array type", array_expression.begin());
+  }
+}
diff --git a/src/language/ASTNodeArraySubscriptExpressionBuilder.hpp b/src/language/ASTNodeArraySubscriptExpressionBuilder.hpp
new file mode 100644
index 000000000..d93ee2461
--- /dev/null
+++ b/src/language/ASTNodeArraySubscriptExpressionBuilder.hpp
@@ -0,0 +1,12 @@
+#ifndef AST_NODE_ARRAY_SUBSCRIPT_EXPRESSION_BUILDER_HPP
+#define AST_NODE_ARRAY_SUBSCRIPT_EXPRESSION_BUILDER_HPP
+
+#include <ASTNode.hpp>
+
+class ASTNodeArraySubscriptExpressionBuilder
+{
+ public:
+  ASTNodeArraySubscriptExpressionBuilder(ASTNode& node);
+};
+
+#endif   // AST_NODE_ARRAY_SUBSCRIPT_EXPRESSION_BUILDER_HPP
diff --git a/src/language/ASTNodeExpressionBuilder.cpp b/src/language/ASTNodeExpressionBuilder.cpp
index 716926170..cc20137ed 100644
--- a/src/language/ASTNodeExpressionBuilder.cpp
+++ b/src/language/ASTNodeExpressionBuilder.cpp
@@ -3,6 +3,7 @@
 #include <ASTNodeAffectationExpressionBuilder.hpp>
 #include <ASTNodeListAffectationExpressionBuilder.hpp>
 
+#include <ASTNodeArraySubscriptExpressionBuilder.hpp>
 #include <ASTNodeBinaryOperatorExpressionBuilder.hpp>
 #include <ASTNodeFunctionEvaluationExpressionBuilder.hpp>
 #include <ASTNodeIncDecExpressionBuilder.hpp>
@@ -63,6 +64,9 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n)
   } else if (n.is_type<language::function_evaluation>()) {
     ASTNodeFunctionEvaluationExpressionBuilder{n};
 
+  } else if (n.is_type<language::subscript_expression>()) {
+    ASTNodeArraySubscriptExpressionBuilder{n};
+
   } else if (n.is_type<language::real>()) {
     n.m_node_processor = std::make_unique<ValueProcessor>(n);
   } else if (n.is_type<language::integer>()) {
diff --git a/src/language/CMakeLists.txt b/src/language/CMakeLists.txt
index 368fa5a9f..acd8b1e57 100644
--- a/src/language/CMakeLists.txt
+++ b/src/language/CMakeLists.txt
@@ -11,6 +11,7 @@ add_library(
   ASTDotPrinter.cpp
   ASTModulesImporter.cpp
   ASTNodeAffectationExpressionBuilder.cpp
+  ASTNodeArraySubscriptExpressionBuilder.cpp
   ASTNodeBinaryOperatorExpressionBuilder.cpp
   ASTNodeCFunctionExpressionBuilder.cpp
   ASTNodeDataType.cpp
diff --git a/src/language/node_processor/ArraySubscriptProcessor.hpp b/src/language/node_processor/ArraySubscriptProcessor.hpp
new file mode 100644
index 000000000..c0426e2f0
--- /dev/null
+++ b/src/language/node_processor/ArraySubscriptProcessor.hpp
@@ -0,0 +1,47 @@
+#ifndef ARRAY_SUBSCRIPT_PROCESSOR_HPP
+#define ARRAY_SUBSCRIPT_PROCESSOR_HPP
+
+#include <node_processor/INodeProcessor.hpp>
+
+template <typename ArrayTypeT>
+class ArraySubscriptProcessor : public INodeProcessor
+{
+ private:
+  ASTNode& m_array_subscript_expression;
+
+ public:
+  DataVariant
+  execute(ExecutionPolicy& exec_policy)
+  {
+    auto& index_expression = *m_array_subscript_expression.children[1];
+
+    const int64_t index_value = [&](DataVariant&& value_variant) -> int64_t {
+      int64_t index_value = 0;
+      std::visit(
+        [&](auto&& value) {
+          using ValueT = std::decay_t<decltype(value)>;
+          if constexpr (std::is_integral_v<ValueT>) {
+            index_value = value;
+          } else {
+            throw parse_error("unexpected error: invalid index type", std::vector{index_expression.begin()});
+          }
+        },
+        value_variant);
+      return index_value;
+    }(index_expression.execute(exec_policy));
+
+    auto& array_expression = *m_array_subscript_expression.children[0];
+
+    const ArrayTypeT& array = std::get<ArrayTypeT>(array_expression.execute(exec_policy));
+
+    return array[index_value];
+  }
+
+  ArraySubscriptProcessor(ASTNode& array_subscript_expression)
+    : m_array_subscript_expression{array_subscript_expression}
+  {}
+
+  virtual ~ArraySubscriptProcessor() = default;
+};
+
+#endif   // ARRAY_SUBSCRIPT_PROCESSOR_HPP
-- 
GitLab