From cc7c4fcae2da71bab1762dd0fce9391861a259f2 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Mon, 4 May 2020 18:00:47 +0200
Subject: [PATCH] Clean function evaluation code

- use Kokkos::UniqueToken manager to use proper context in
multi-thread context
- still works only for R^3->R^3 functions

Remains to
- check input/output types
- treat the case when the function returns an R^3 and not a list of 3 R
---
 src/language/MeshModule.cpp | 85 +++++++++----------------------------
 1 file changed, 21 insertions(+), 64 deletions(-)

diff --git a/src/language/MeshModule.cpp b/src/language/MeshModule.cpp
index fe0c75d9c..0a0b55941 100644
--- a/src/language/MeshModule.cpp
+++ b/src/language/MeshModule.cpp
@@ -19,35 +19,6 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod
 template <>
 inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
 
-struct TagA
-{
-};
-struct TagC
-{
-};
-
-struct Foo
-{
-  static std::vector<int> m_ids;
-  KOKKOS_INLINE_FUNCTION
-  void
-  operator()(const TagA, const Kokkos::TeamPolicy<>::member_type& team) const
-  {
-    m_ids[team.league_rank()] = team.league_rank() * 3;
-  }
-
-  KOKKOS_INLINE_FUNCTION
-  void
-  operator()(const TagC, const Kokkos::TeamPolicy<>::member_type& team) const
-  {
-    m_ids[team.league_rank()] = team.league_rank() * 2;
-  }
-
-  Foo() {}
-};
-
-inline std::vector<int> Foo::m_ids;
-
 MeshModule::MeshModule()
 {
   this->_addTypeDescriptor(
@@ -67,7 +38,6 @@ MeshModule::MeshModule()
                             std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>,
                                                                      FunctionSymbolId>>(
                               std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{
-
                                 [](std::shared_ptr<IMesh> p_mesh,
                                    FunctionSymbolId function_id) -> std::shared_ptr<IMesh> {
                                   auto& symbol_table = function_id.symbolTable();
@@ -75,31 +45,15 @@ MeshModule::MeshModule()
                                     *symbol_table.functionTable()[function_id.id()].definitionNode().children[1];
                                   auto& function_context = function_expression.m_symbol_table->context();
 
-                                  // Foo foo;
-                                  // Foo::m_ids.resize(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
-                                  // Kokkos::parallel_for(Kokkos::TeamPolicy<
-                                  //                        TagA>(Kokkos::DefaultExecutionSpace::impl_thread_pool_size(),
-                                  //                        1),
-                                  //                      foo);
-                                  // std::cout << "--------\n" << std::endl;
-
-                                  // for (int i = 0; i < Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); ++i) {
-                                  //   std::cout << "m_ids[" << i << "] = " << Foo::m_ids[i] << " | " << i * 3 <<
-                                  //   std::endl;
-                                  // }
-
-                                  // std::cout << "********\n" << std::endl;
-
-                                  // Kokkos::parallel_for(Kokkos::TeamPolicy<
-                                  //                        TagC>(Kokkos::DefaultExecutionSpace::impl_thread_pool_size(),
-                                  //                        1),
-                                  //                      foo);
-                                  // std::cout << "--------\n" << std::endl;
-
-                                  // for (int i = 0; i < Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); ++i) {
-                                  //   std::cout << "m_ids[" << i << "] = " << Foo::m_ids[i] << " | " << i * 2 <<
-                                  //   std::endl;
-                                  // }
+                                  const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size();
+                                  Array<ExecutionPolicy> context_list(number_of_threads);
+                                  for (size_t i = 0; i < context_list.size(); ++i) {
+                                    context_list[i] =
+                                      ExecutionPolicy(ExecutionPolicy{},
+                                                      {function_context.id(),
+                                                       std::make_shared<ExecutionPolicy::Context::Values>(
+                                                         function_context.size())});
+                                  }
 
                                   switch (p_mesh->dimension()) {
                                   case 1: {
@@ -117,21 +71,24 @@ MeshModule::MeshModule()
 
                                     NodeValue<TinyVector<3>> xr(given_mesh.connectivity());
 
-                                    ExecutionPolicy::Context context{function_context.id(),
-                                                                     std::make_shared<ExecutionPolicy::Context::Values>(
-                                                                       function_context.size())};
-
-                                    ExecutionPolicy execution_policy;
-                                    ExecutionPolicy context_execution_policy{execution_policy, context};
+                                    using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
+                                    Kokkos::Experimental::UniqueToken<execution_space,
+                                                                      Kokkos::Experimental::UniqueTokenScope::Global>
+                                      tokens;
 
                                     parallel_for(given_mesh.numberOfNodes(), [=, &function_expression,
-                                                                              &context_execution_policy](NodeId r) {
-                                      context_execution_policy.currentContext()[0] = given_xr[r];
+                                                                              &tokens](NodeId r) {
+                                      const int32_t t = tokens.acquire();
+
+                                      auto& execution_policy               = context_list[t];
+                                      execution_policy.currentContext()[0] = given_xr[r];
 
-                                      auto&& value = function_expression.execute(context_execution_policy);
+                                      auto&& value = function_expression.execute(execution_policy);
 
                                       AggregateDataVariant& v = std::get<AggregateDataVariant>(value);
                                       xr[r] = {std::get<double>(v[0]), std::get<double>(v[1]), std::get<double>(v[2])};
+
+                                      tokens.release(t);
                                     });
 
                                     return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr);
-- 
GitLab