From caab4b21c6c7fa8df897dbdbc7f395459b587391 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Wed, 6 May 2020 19:16:21 +0200
Subject: [PATCH] Improve the code (design, genericity and readability)

- Rename FunctionAdapter -> PugsFunctionAdapter to improve readability
- PugsFunctionAdapter is now defined in its own file
---
 src/language/MeshModule.cpp          | 105 ++++++---------------------
 src/language/PugsFunctionAdapter.hpp |  90 +++++++++++++++++++++++
 2 files changed, 114 insertions(+), 81 deletions(-)
 create mode 100644 src/language/PugsFunctionAdapter.hpp

diff --git a/src/language/MeshModule.cpp b/src/language/MeshModule.cpp
index 29d1ca111..cd987e244 100644
--- a/src/language/MeshModule.cpp
+++ b/src/language/MeshModule.cpp
@@ -2,6 +2,7 @@
 
 #include <language/BuiltinFunctionEmbedder.hpp>
 #include <language/FunctionTable.hpp>
+#include <language/PugsFunctionAdapter.hpp>
 #include <language/SymbolTable.hpp>
 #include <language/TypeDescriptor.hpp>
 #include <language/node_processor/ExecutionPolicy.hpp>
@@ -19,97 +20,39 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod
 template <>
 inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
 
+template <typename T>
+class MeshTransformation;
 template <typename OutputType, typename... InputType>
-class FunctionAdapter
+class MeshTransformation<OutputType(InputType...)> : public PugsFunctionAdapter<OutputType(InputType...)>
 {
-  static constexpr size_t OutputDimension = OutputType::Dimension;
-
- private:
-  template <typename T, typename... Args>
-  static void
-  _convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context)
-  {
-    context[sizeof...(args)] = t;
-    if constexpr (sizeof...(args) > 0) {
-      _convertArgs(std::forward<Args>(args)..., context);
-    }
-  }
-
-  template <typename... Args>
-  static void
-  convertArgs(ExecutionPolicy::Context& context, const Args&... args)
-  {
-    static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type");
-    _convertArgs(args..., context);
-  }
-
-  static std::function<OutputType(DataVariant&& result)>
-  _get_result_converter(ASTNodeDataType data_type)
-  {
-    switch (data_type) {
-    case ASTNodeDataType::list_t: {
-      return [](DataVariant&& result) -> OutputType {
-        AggregateDataVariant& v = std::get<AggregateDataVariant>(result);
-        OutputType x;
-        for (size_t i = 0; i < x.dimension(); ++i) {
-          x[i] = std::get<double>(v[i]);
-        }
-        return x;
-      };
-    }
-    case ASTNodeDataType::vector_t: {
-      return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
-    }
-    case ASTNodeDataType::double_t: {
-      if constexpr (OutputDimension == 1) {
-        return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
-      } else {
-        throw UnexpectedError("unexpected data_type");
-      }
-    }
-    default: {
-      throw UnexpectedError("unexpected data_type");
-    }
-    }
-  }
+  static constexpr size_t Dimension = OutputType::Dimension;
+  using Adapter                     = PugsFunctionAdapter<OutputType(InputType...)>;
 
  public:
-  template <size_t Dimension>
-  static inline std::shared_ptr<Mesh<Connectivity<OutputDimension>>>
+  static inline std::shared_ptr<Mesh<Connectivity<Dimension>>>
   transform(FunctionSymbolId function_symbol_id, std::shared_ptr<const IMesh> p_mesh)
   {
-    auto& symbol_table        = function_symbol_id.symbolTable();
-    auto& function_expression = *symbol_table.functionTable()[function_symbol_id.id()].definitionNode().children[1];
-    auto& function_context    = function_expression.m_symbol_table->context();
-
-    ASTNodeDataType t   = function_expression.m_data_type;
-    auto convert_result = _get_result_converter(t);
-
-    const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size();
-    Array<ExecutionPolicy> context_list(number_of_threads);
-    for (size_t i = 0; i < context_list.size(); ++i) {
-      context_list[i] = ExecutionPolicy(ExecutionPolicy{},
-                                        {function_context.id(),
-                                         std::make_shared<ExecutionPolicy::Context::Values>(function_context.size())});
-    }
+    using MeshType             = Mesh<Connectivity<Dimension>>;
+    const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh);
 
-    using MeshType                                  = Mesh<Connectivity<Dimension>>;
-    const MeshType& given_mesh                      = dynamic_cast<const MeshType&>(*p_mesh);
-    NodeValue<const TinyVector<Dimension>> given_xr = given_mesh.xr();
+    auto& expression                    = Adapter::getFunctionExpression(function_symbol_id);
+    auto convert_result                 = Adapter::getResultConverter(expression.m_data_type);
+    Array<ExecutionPolicy> context_list = Adapter::getContextList(expression);
 
-    NodeValue<TinyVector<Dimension>> xr(given_mesh.connectivity());
+    NodeValue<const OutputType> given_xr = given_mesh.xr();
+    NodeValue<OutputType> xr(given_mesh.connectivity());
 
     using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
     Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens;
 
-    parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, &tokens](NodeId r) {
+    parallel_for(given_mesh.numberOfNodes(), [=, &expression, &tokens](NodeId r) {
       const int32_t t = tokens.acquire();
 
       auto& execution_policy = context_list[t];
 
-      convertArgs(execution_policy.currentContext(), given_xr[r]);
+      Adapter::convertArgs(execution_policy.currentContext(), given_xr[r]);
+      xr[r] = convert_result(expression.execute(execution_policy));
 
-      xr[r] = convert_result(function_expression.execute(execution_policy));
       tokens.release(t);
     });
 
@@ -140,19 +83,19 @@ MeshModule::MeshModule()
                                    FunctionSymbolId function_id) -> std::shared_ptr<IMesh> {
                                   switch (p_mesh->dimension()) {
                                   case 1: {
-                                    return FunctionAdapter<TinyVector<1>, TinyVector<1>>::transform<1>(function_id,
-                                                                                                       p_mesh);
+                                    using TransformT = TinyVector<1>(TinyVector<1>);
+                                    return MeshTransformation<TransformT>::transform(function_id, p_mesh);
                                   }
                                   case 2: {
-                                    return FunctionAdapter<TinyVector<2>, TinyVector<2>>::transform<2>(function_id,
-                                                                                                       p_mesh);
+                                    using TransformT = TinyVector<2>(TinyVector<2>);
+                                    return MeshTransformation<TransformT>::transform(function_id, p_mesh);
                                   }
                                   case 3: {
-                                    return FunctionAdapter<TinyVector<3>, TinyVector<3>>::transform<3>(function_id,
-                                                                                                       p_mesh);
+                                    using TransformT = TinyVector<3>(TinyVector<3>);
+                                    return MeshTransformation<TransformT>::transform(function_id, p_mesh);
                                   }
                                   default: {
-                                    throw NormalError("invalid dimension");
+                                    throw NormalError("invalid mesh dimension");
                                   }
                                   }
                                 }}
diff --git a/src/language/PugsFunctionAdapter.hpp b/src/language/PugsFunctionAdapter.hpp
new file mode 100644
index 000000000..262299728
--- /dev/null
+++ b/src/language/PugsFunctionAdapter.hpp
@@ -0,0 +1,90 @@
+#ifndef PUGS_FUNCTION_ADAPTER_HPP
+#define PUGS_FUNCTION_ADAPTER_HPP
+
+#include <language/ASTNode.hpp>
+#include <language/SymbolTable.hpp>
+#include <language/node_processor/ExecutionPolicy.hpp>
+#include <utils/Array.hpp>
+#include <utils/Exceptions.hpp>
+#include <utils/PugsMacros.hpp>
+
+#include <Kokkos_Core.hpp>
+
+template <typename T>
+class PugsFunctionAdapter;
+template <typename OutputType, typename... InputType>
+class PugsFunctionAdapter<OutputType(InputType...)>
+{
+ private:
+  template <typename T, typename... Args>
+  PUGS_INLINE static void
+  _convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context)
+  {
+    context[sizeof...(args)] = t;
+    if constexpr (sizeof...(args) > 0) {
+      _convertArgs(std::forward<Args>(args)..., context);
+    }
+  }
+
+ protected:
+  PUGS_INLINE static auto&
+  getFunctionExpression(FunctionSymbolId function_symbol_id)
+  {
+    return *function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()].definitionNode().children[1];
+  }
+
+  PUGS_INLINE static auto
+  getContextList(const ASTNode& expression)
+  {
+    Array<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
+    auto& context = expression.m_symbol_table->context();
+
+    for (size_t i = 0; i < context_list.size(); ++i) {
+      context_list[i] =
+        ExecutionPolicy(ExecutionPolicy{},
+                        {context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())});
+    }
+
+    return context_list;
+  }
+
+  template <typename... Args>
+  PUGS_INLINE static void
+  convertArgs(ExecutionPolicy::Context& context, const Args&... args)
+  {
+    static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type");
+    _convertArgs(args..., context);
+  }
+
+  PUGS_INLINE static std::function<OutputType(DataVariant&& result)>
+  getResultConverter(ASTNodeDataType data_type)
+  {
+    switch (data_type) {
+    case ASTNodeDataType::list_t: {
+      return [](DataVariant&& result) -> OutputType {
+        AggregateDataVariant& v = std::get<AggregateDataVariant>(result);
+        OutputType x;
+        for (size_t i = 0; i < x.dimension(); ++i) {
+          x[i] = std::get<double>(v[i]);
+        }
+        return x;
+      };
+    }
+    case ASTNodeDataType::vector_t: {
+      return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
+    }
+    case ASTNodeDataType::double_t: {
+      if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
+        return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
+      } else {
+        throw UnexpectedError("unexpected data_type");
+      }
+    }
+    default: {
+      throw UnexpectedError("unexpected data_type");
+    }
+    }
+  }
+};
+
+#endif   // PUGS_FUNCTION_ADAPTER_HPP
-- 
GitLab