From 6a15ee1c7f4e59e1c10e98ef56947fdc1cd07d04 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 1 May 2020 17:14:22 +0200
Subject: [PATCH] Define a first mechanism to call user functions in C++

This is a concept proof. It requires many improvements:
- code is crappy
- only works for one kind of function (R^3->R^3 with a result defined
as a list of 3 R)
- no check is performed
- execution context (necessary for multi-threading  is not properly
defined)

Also, one should define an handier way to define such mechanism
---
 ...STNodeBuiltinFunctionExpressionBuilder.cpp |  5 +-
 src/language/BuiltinFunctionEmbedder.hpp      |  8 ----
 src/language/DataVariant.hpp                  |  2 +
 src/language/FunctionSymbolId.hpp             | 48 +++++++++++++++++++
 src/language/FunctionTable.hpp                |  7 ---
 src/language/MeshModule.cpp                   | 45 ++++++++++-------
 .../FunctionArgumentConverter.hpp             | 20 ++++++++
 7 files changed, 102 insertions(+), 33 deletions(-)
 create mode 100644 src/language/FunctionSymbolId.hpp

diff --git a/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp
index 17704a98d..332f957e4 100644
--- a/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp
+++ b/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp
@@ -79,7 +79,10 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData
   auto get_function_argument_to_function_id_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> {
     switch (argument_node_sub_data_type.m_data_type) {
     case ASTNodeDataType::function_t: {
-      return std::make_unique<FunctionArgumentConverter<FunctionId, FunctionId>>(argument_number);
+      const ASTNode& parent_node = argument_node_sub_data_type.m_parent_node;
+      auto symbol_table          = parent_node.m_symbol_table;
+
+      return std::make_unique<FunctionArgumentToFunctionSymbolIdConverter>(argument_number, symbol_table);
     }
       // LCOV_EXCL_START
     default: {
diff --git a/src/language/BuiltinFunctionEmbedder.hpp b/src/language/BuiltinFunctionEmbedder.hpp
index 31517c95b..79eb78e4c 100644
--- a/src/language/BuiltinFunctionEmbedder.hpp
+++ b/src/language/BuiltinFunctionEmbedder.hpp
@@ -76,14 +76,6 @@ class BuiltinFunctionEmbedder : public IBuiltinFunctionEmbedder
           } else {
             throw UnexpectedError("unexpected argument types while casting: expecting EmbeddedData");
           }
-        } else if constexpr (std::is_same_v<Ti_Type, FunctionId>) {
-          if constexpr (std::is_same_v<Vi_Type, uint64_t>) {
-            std::get<I>(t) = FunctionId{v_i};
-            throw NotImplementedError(
-              "Should get better descriptor, FunctionId should at least refer to the symbol table.");
-          } else {
-            throw UnexpectedError("unexpected argument types while casting: expecting uint64");
-          }
         } else {
           throw UnexpectedError("Unexpected argument types while casting " + demangle<Vi_Type>() + " -> " +
                                 demangle<Ti_Type>());
diff --git a/src/language/DataVariant.hpp b/src/language/DataVariant.hpp
index fd2e69b9b..04dac5293 100644
--- a/src/language/DataVariant.hpp
+++ b/src/language/DataVariant.hpp
@@ -3,6 +3,7 @@
 
 #include <algebra/TinyVector.hpp>
 #include <language/EmbeddedData.hpp>
+#include <language/FunctionSymbolId.hpp>
 #include <utils/PugsAssert.hpp>
 
 #include <tuple>
@@ -19,6 +20,7 @@ using DataVariant = std::variant<std::monostate,
                                  std::string,
                                  EmbeddedData,
                                  AggregateDataVariant,
+                                 FunctionSymbolId,
                                  TinyVector<1>,
                                  TinyVector<2>,
                                  TinyVector<3>>;
diff --git a/src/language/FunctionSymbolId.hpp b/src/language/FunctionSymbolId.hpp
new file mode 100644
index 000000000..d6f0973f7
--- /dev/null
+++ b/src/language/FunctionSymbolId.hpp
@@ -0,0 +1,48 @@
+#ifndef FUNCTION_SYMBOL_ID_HPP
+#define FUNCTION_SYMBOL_ID_HPP
+
+#include <utils/PugsAssert.hpp>
+#include <utils/PugsMacros.hpp>
+
+#include <cstddef>
+#include <iostream>
+#include <memory>
+
+class SymbolTable;
+class FunctionSymbolId
+{
+ private:
+  uint64_t m_function_id;
+  std::shared_ptr<SymbolTable> m_symbol_table = nullptr;
+
+ public:
+  PUGS_INLINE uint64_t
+  id() const noexcept
+  {
+    return m_function_id;
+  }
+
+  PUGS_INLINE
+  const SymbolTable&
+  symbolTable() const
+  {
+    Assert(m_symbol_table, "FunctionSymbolId is not initialized properly");
+    return *m_symbol_table;
+  }
+
+  friend std::ostream&
+  operator<<(std::ostream& os, const FunctionSymbolId& function_symbol_id)
+  {
+    os << function_symbol_id.m_function_id;
+    return os;
+  }
+
+  FunctionSymbolId() = default;
+  FunctionSymbolId(uint64_t function_id, const std::shared_ptr<SymbolTable>& symbol_table)
+    : m_function_id(function_id), m_symbol_table(symbol_table)
+  {}
+
+  ~FunctionSymbolId() = default;
+};
+
+#endif   // FUNCTION_SYMBOL_ID_HPP
diff --git a/src/language/FunctionTable.hpp b/src/language/FunctionTable.hpp
index 72b19f437..a2b7b95af 100644
--- a/src/language/FunctionTable.hpp
+++ b/src/language/FunctionTable.hpp
@@ -8,13 +8,6 @@
 
 #include <pegtl/position.hpp>
 
-#include <iostream>
-
-struct FunctionId
-{
-  uint64_t m_function_id;
-};
-
 class FunctionDescriptor
 {
   std::unique_ptr<ASTNode> m_domain_mapping_node;
diff --git a/src/language/MeshModule.cpp b/src/language/MeshModule.cpp
index 5c1721fa9..696799455 100644
--- a/src/language/MeshModule.cpp
+++ b/src/language/MeshModule.cpp
@@ -2,19 +2,19 @@
 
 #include <language/BuiltinFunctionEmbedder.hpp>
 #include <language/FunctionTable.hpp>
+#include <language/SymbolTable.hpp>
 #include <language/TypeDescriptor.hpp>
+#include <language/node_processor/ExecutionPolicy.hpp>
 #include <mesh/Connectivity.hpp>
 #include <mesh/GmshReader.hpp>
 #include <mesh/Mesh.hpp>
 #include <utils/Exceptions.hpp>
 
-#include <output/VTKWriter.hpp>
-
 template <>
 inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNodeDataType::type_id_t, "mesh"};
 
 template <>
-inline ASTNodeDataType ast_node_data_type_from<FunctionId> = {ASTNodeDataType::function_t};
+inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
 
 MeshModule::MeshModule()
 {
@@ -34,10 +34,16 @@ MeshModule::MeshModule()
   this
     ->_addBuiltinFunction("transform",
                           std::make_shared<
-                            BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, FunctionId>>(
-                            std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionId)>{
+                            BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, FunctionSymbolId>>(
+                            std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{
+
+                              [](std::shared_ptr<IMesh> p_mesh,
+                                 FunctionSymbolId function_id) -> std::shared_ptr<IMesh> {
+                                auto& symbol_table = function_id.symbolTable();
+                                auto& function_expression =
+                                  *symbol_table.functionTable()[function_id.id()].definitionNode().children[1];
+                                auto& function_context = function_expression.m_symbol_table->context();
 
-                              [](std::shared_ptr<IMesh> p_mesh, FunctionId function_id) -> std::shared_ptr<IMesh> {
                                 switch (p_mesh->dimension()) {
                                 case 1: {
                                   throw NotImplementedError("not implemented in 1d");
@@ -48,23 +54,28 @@ MeshModule::MeshModule()
                                   break;
                                 }
                                 case 3: {
-                                  std::cout << "Using function " << function_id.m_function_id << '\n';
-
                                   using MeshType                          = Mesh<Connectivity3D>;
                                   const MeshType& given_mesh              = dynamic_cast<const MeshType&>(*p_mesh);
                                   NodeValue<const TinyVector<3>> given_xr = given_mesh.xr();
 
                                   NodeValue<TinyVector<3>> xr(given_mesh.connectivity());
 
-                                  parallel_for(
-                                    given_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId r) {
-                                      const TinyVector<3> shift{0, 0.05, 0.05};
-                                      const auto x   = given_xr[r] - shift;
-                                      const double c = std::cos(0.01 * x[0]);
-                                      const double s = std::sin(0.01 * x[0]);
-                                      const TinyVector<3> transformed{x[0], x[1] * c + x[2] * s, -x[1] * s + x[2] * c};
-                                      xr[r] = transformed + shift;
-                                    });
+                                  parallel_for(given_mesh.numberOfNodes(), [=, &function_expression,
+                                                                            &function_context](NodeId r) {
+                                    ExecutionPolicy::Context context{function_context.id(),
+                                                                     std::make_shared<ExecutionPolicy::Context::Values>(
+                                                                       function_context.size())};
+
+                                    ExecutionPolicy execution_policy;
+                                    ExecutionPolicy context_execution_policy{execution_policy, context};
+
+                                    context_execution_policy.currentContext()[0] = given_xr[r];
+
+                                    auto&& value = function_expression.execute(context_execution_policy);
+
+                                    AggregateDataVariant& v = std::get<AggregateDataVariant>(value);
+                                    xr[r] = {std::get<double>(v[0]), std::get<double>(v[1]), std::get<double>(v[2])};
+                                  });
 
                                   return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr);
                                 }
diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp
index 990fcde28..467bc7b24 100644
--- a/src/language/node_processor/FunctionArgumentConverter.hpp
+++ b/src/language/node_processor/FunctionArgumentConverter.hpp
@@ -43,4 +43,24 @@ class FunctionArgumentConverter final : public IFunctionArgumentConverter
   FunctionArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {}
 };
 
+class FunctionArgumentToFunctionSymbolIdConverter final : public IFunctionArgumentConverter
+{
+ private:
+  size_t m_argument_id;
+  std::shared_ptr<SymbolTable> m_symbol_table;
+
+ public:
+  DataVariant
+  convert(ExecutionPolicy& exec_policy, DataVariant&& value)
+  {
+    exec_policy.currentContext()[m_argument_id] = FunctionSymbolId{std::get<uint64_t>(value), m_symbol_table};
+
+    return {};
+  }
+
+  FunctionArgumentToFunctionSymbolIdConverter(size_t argument_id, const std::shared_ptr<SymbolTable>& symbol_table)
+    : m_argument_id{argument_id}, m_symbol_table{symbol_table}
+  {}
+};
+
 #endif   // FUNCTION_ARGUMENT_CONVERTER_HPP
-- 
GitLab