From 3409f456d7ba86533db3eb128c1348e704c0c15e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Wed, 19 Feb 2025 14:05:21 +0100
Subject: [PATCH] Add load_balance:(Vh)->(Vh) function

Load balancing is performed on a list of discrete functions that live
on a common mesh. Their values are transferred according to the new
partitioning of the mesh.
---
 src/language/modules/SchemeModule.cpp | 11 ++++
 src/mesh/ConnectivityDispatcher.cpp   |  2 -
 src/mesh/ConnectivityDispatcher.hpp   | 64 +++++++++++++++++++++--
 src/mesh/MeshBuilderBase.cpp          | 37 +++++++------
 src/mesh/MeshBuilderBase.hpp          |  6 +++
 src/scheme/CMakeLists.txt             |  3 +-
 src/scheme/LoadBalancer.cpp           | 75 +++++++++++++++++++++++++++
 src/scheme/LoadBalancer.hpp           | 21 ++++++++
 8 files changed, 195 insertions(+), 24 deletions(-)
 create mode 100644 src/scheme/LoadBalancer.cpp
 create mode 100644 src/scheme/LoadBalancer.hpp

diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp
index 87b48504b..b9b789cfa 100644
--- a/src/language/modules/SchemeModule.cpp
+++ b/src/language/modules/SchemeModule.cpp
@@ -46,6 +46,7 @@
 #include <scheme/IDiscreteFunctionDescriptor.hpp>
 #include <scheme/InflowBoundaryConditionDescriptor.hpp>
 #include <scheme/InflowListBoundaryConditionDescriptor.hpp>
+#include <scheme/LoadBalancer.hpp>
 #include <scheme/NeumannBoundaryConditionDescriptor.hpp>
 #include <scheme/OutflowBoundaryConditionDescriptor.hpp>
 #include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
@@ -745,6 +746,16 @@ SchemeModule::SchemeModule()
 
                               ));
 
+  this->_addBuiltinFunction("load_balance", std::function(
+
+                                              [](const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>&
+                                                   discrete_function_variant_list)
+                                                -> std::vector<std::shared_ptr<const DiscreteFunctionVariant>> {
+                                                return LoadBalancer{}.balance(discrete_function_variant_list);
+                                              }
+
+                                              ));
+
   MathFunctionRegisterForVh{*this};
 }
 
diff --git a/src/mesh/ConnectivityDispatcher.cpp b/src/mesh/ConnectivityDispatcher.cpp
index d78d400e2..1270dd78f 100644
--- a/src/mesh/ConnectivityDispatcher.cpp
+++ b/src/mesh/ConnectivityDispatcher.cpp
@@ -600,8 +600,6 @@ ConnectivityDispatcher<Dimension>::_buildItemReferenceList()
       return i_rank;
     }();
 
-    Assert(number_of_item_list_sender < parallel::size());
-
     // sending is boundary property
     Array<RefItemListBase::Type> ref_item_list_type{number_of_item_ref_list_per_proc[sender_rank]};
     if (parallel::rank() == sender_rank) {
diff --git a/src/mesh/ConnectivityDispatcher.hpp b/src/mesh/ConnectivityDispatcher.hpp
index 15c83a66c..46932004c 100644
--- a/src/mesh/ConnectivityDispatcher.hpp
+++ b/src/mesh/ConnectivityDispatcher.hpp
@@ -253,19 +253,19 @@ class ConnectivityDispatcher
 
     using MutableDataType = std::remove_const_t<DataType>;
     std::vector<Array<DataType>> item_value_to_send_by_proc(parallel::size());
-    for (size_t i = 0; i < parallel::size(); ++i) {
-      const Array<const ItemId>& item_list = item_list_to_send_by_proc[i];
+    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
+      const Array<const ItemId>& item_list = item_list_to_send_by_proc[i_rank];
       Array<MutableDataType> item_value_list(item_list.size());
       parallel_for(
         item_list.size(),
         PUGS_LAMBDA(const ItemId& item_id) { item_value_list[item_id] = item_value[item_list[item_id]]; });
-      item_value_to_send_by_proc[i] = item_value_list;
+      item_value_to_send_by_proc[i_rank] = item_value_list;
     }
 
     const auto& item_list_to_recv_size_by_proc = this->_dispatchedInfo<item_type>().m_list_to_recv_size_by_proc;
     std::vector<Array<MutableDataType>> recv_item_value_by_proc(parallel::size());
-    for (size_t i = 0; i < parallel::size(); ++i) {
-      recv_item_value_by_proc[i] = Array<MutableDataType>(item_list_to_recv_size_by_proc[i]);
+    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
+      recv_item_value_by_proc[i_rank] = Array<MutableDataType>(item_list_to_recv_size_by_proc[i_rank]);
     }
 
     parallel::exchange(item_value_to_send_by_proc, recv_item_value_by_proc);
@@ -285,6 +285,60 @@ class ConnectivityDispatcher
     return new_item_value;
   }
 
+  template <typename DataType, ItemType item_type, typename ConnectivityPtr>
+  ItemArray<std::remove_const_t<DataType>, item_type, ConnectivityPtr>
+  dispatch(ItemArray<DataType, item_type, ConnectivityPtr> item_array) const
+  {   // packs per-item arrays by destination rank, exchanges them, then unpacks onto the dispatched connectivity
+    using ItemId = ItemIdT<item_type>;
+
+    Assert(m_dispatched_connectivity.use_count() > 0, "cannot dispatch quantity before connectivity");
+
+    const auto& item_list_to_send_by_proc = this->_dispatchedInfo<item_type>().m_list_to_send_by_proc;
+
+    const size_t size_of_arrays = item_array.sizeOfArrays();
+
+    using MutableDataType = std::remove_const_t<DataType>;
+    std::vector<Array<DataType>> item_array_to_send_by_proc(parallel::size());
+    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
+      const Array<const ItemId>& item_list = item_list_to_send_by_proc[i_rank];
+      Array<MutableDataType> item_array_list(item_list.size() * size_of_arrays);   // flat: size_of_arrays entries per item
+      parallel_for(
+        item_list.size(), PUGS_LAMBDA(const size_t i) {   // i is a position in the send list, not an item id
+          const size_t j   = i * size_of_arrays;   // offset of item i in the flat send buffer
+          const auto array = item_array[item_list[i]];
+          for (size_t k = 0; k < size_of_arrays; ++k) {
+            item_array_list[j + k] = array[k];
+          }
+        });
+      item_array_to_send_by_proc[i_rank] = item_array_list;
+    }
+
+    const auto& item_list_to_recv_size_by_proc = this->_dispatchedInfo<item_type>().m_list_to_recv_size_by_proc;
+    std::vector<Array<MutableDataType>> recv_item_array_by_proc(parallel::size());
+    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {   // receive buffers use the same flat layout
+      recv_item_array_by_proc[i_rank] = Array<MutableDataType>(item_list_to_recv_size_by_proc[i_rank] * size_of_arrays);
+    }
+
+    parallel::exchange(item_array_to_send_by_proc, recv_item_array_by_proc);
+
+    const auto& recv_item_id_correspondance_by_proc =
+      this->_dispatchedInfo<item_type>().m_recv_id_correspondance_by_proc;
+    ItemArray<MutableDataType, item_type> new_item_array(*m_dispatched_connectivity, size_of_arrays);
+    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
+      const auto& recv_item_id_correspondance = recv_item_id_correspondance_by_proc[i_rank];
+      const auto& recv_item_array             = recv_item_array_by_proc[i_rank];
+      parallel_for(
+        recv_item_id_correspondance.size(), PUGS_LAMBDA(size_t i) {
+          const size_t j = i * size_of_arrays;   // offset of the i-th received item in the flat buffer
+          auto array     = new_item_array[recv_item_id_correspondance[i]];   // destination row on the new connectivity
+          for (size_t k = 0; k < size_of_arrays; ++k) {
+            array[k] = recv_item_array[j + k];
+          }
+        });
+    }
+    return new_item_array;
+  }
+
   ConnectivityDispatcher(const ConnectivityType& mesh);
   ConnectivityDispatcher(const ConnectivityDispatcher&) = delete;
   ~ConnectivityDispatcher()                             = default;
diff --git a/src/mesh/MeshBuilderBase.cpp b/src/mesh/MeshBuilderBase.cpp
index 861050ed7..4a0a010ce 100644
--- a/src/mesh/MeshBuilderBase.cpp
+++ b/src/mesh/MeshBuilderBase.cpp
@@ -18,31 +18,36 @@ template <size_t Dimension>
 void
 MeshBuilderBase::_dispatch()
 {
-  if (parallel::size() == 1) {
-    return;
-  }
-
   using ConnectivityType = Connectivity<Dimension>;
   using Rd               = TinyVector<Dimension>;
   using MeshType         = Mesh<Dimension>;
 
-  if (not m_mesh) {
-    ConnectivityDescriptor descriptor;
-    std::shared_ptr connectivity = ConnectivityType::build(descriptor);
-    NodeValue<Rd> xr;
-    m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(connectivity, xr));
-  }
+  if (parallel::size() == 1) {
+    const MeshType& mesh = *(m_mesh->get<const MeshType>());
+
+    // force "creation" of a new mesh to avoid different
+    // parallel/sequential behaviors: if the mesh should change in
+    // parallel, it also changes in sequential.
+    m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(mesh.shared_connectivity(), mesh.xr()));
+  } else {
+    if (not m_mesh) {
+      ConnectivityDescriptor descriptor;
+      std::shared_ptr connectivity = ConnectivityType::build(descriptor);
+      NodeValue<Rd> xr;
+      m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(connectivity, xr));
+    }
 
-  const MeshType& mesh = *(m_mesh->get<const MeshType>());
+    const MeshType& mesh = *(m_mesh->get<const MeshType>());
 
-  auto p_dispatcher = std::make_shared<const ConnectivityDispatcher<Dimension>>(mesh.connectivity());
+    auto p_dispatcher = std::make_shared<const ConnectivityDispatcher<Dimension>>(mesh.connectivity());
 
-  m_connectivity_dispatcher = std::make_shared<ConnectivityDispatcherVariant>(p_dispatcher);
+    m_connectivity_dispatcher = std::make_shared<ConnectivityDispatcherVariant>(p_dispatcher);
 
-  std::shared_ptr dispatched_connectivity = p_dispatcher->dispatchedConnectivity();
-  NodeValue<Rd> dispatched_xr             = p_dispatcher->dispatch(mesh.xr());
+    std::shared_ptr dispatched_connectivity = p_dispatcher->dispatchedConnectivity();
+    NodeValue<Rd> dispatched_xr             = p_dispatcher->dispatch(mesh.xr());
 
-  m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(dispatched_connectivity, dispatched_xr));
+    m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(dispatched_connectivity, dispatched_xr));
+  }
 }
 
 template void MeshBuilderBase::_dispatch<1>();
diff --git a/src/mesh/MeshBuilderBase.hpp b/src/mesh/MeshBuilderBase.hpp
index fb811acf4..7b5f34d92 100644
--- a/src/mesh/MeshBuilderBase.hpp
+++ b/src/mesh/MeshBuilderBase.hpp
@@ -19,6 +19,12 @@ class MeshBuilderBase
   void _checkMesh() const;
 
  public:
+  std::shared_ptr<const ConnectivityDispatcherVariant>
+  connectivityDispatcher() const
+  {
+    return m_connectivity_dispatcher;   // NOTE(review): presumably null unless _dispatch() ran in parallel -- confirm
+  }
+
   std::shared_ptr<const MeshVariant>
   mesh() const
   {
diff --git a/src/scheme/CMakeLists.txt b/src/scheme/CMakeLists.txt
index dadbe73bb..a4642f38f 100644
--- a/src/scheme/CMakeLists.txt
+++ b/src/scheme/CMakeLists.txt
@@ -3,13 +3,14 @@
 add_library(
   PugsScheme
   AcousticSolver.cpp
-  HyperelasticSolver.cpp
   DiscreteFunctionIntegrator.cpp
   DiscreteFunctionInterpoler.cpp
   DiscreteFunctionUtils.cpp
   DiscreteFunctionVectorIntegrator.cpp
   DiscreteFunctionVectorInterpoler.cpp
   FluxingAdvectionSolver.cpp
+  HyperelasticSolver.cpp
+  LoadBalancer.cpp
 )
 
 target_link_libraries(
diff --git a/src/scheme/LoadBalancer.cpp b/src/scheme/LoadBalancer.cpp
new file mode 100644
index 000000000..0eefbeff2
--- /dev/null
+++ b/src/scheme/LoadBalancer.cpp
@@ -0,0 +1,75 @@
+#include <scheme/LoadBalancer.hpp>
+
+#include <mesh/ConnectivityDispatcher.hpp>
+#include <mesh/ConnectivityDispatcherVariant.hpp>
+#include <mesh/MeshBalancer.hpp>
+#include <mesh/MeshVariant.hpp>
+#include <scheme/DiscreteFunctionUtils.hpp>
+#include <scheme/DiscreteFunctionVariant.hpp>
+#include <utils/Messenger.hpp>
+
+std::vector<std::shared_ptr<const DiscreteFunctionVariant>>
+LoadBalancer::balance(const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_list)
+{   // rebalances the common mesh, then rebuilds every input function on the balanced mesh (input order is kept)
+  if (not hasSameMesh(discrete_function_list)) {
+    throw NormalError("discrete functions are not defined on the same mesh");
+  }
+  auto mesh_v = getCommonMesh(discrete_function_list);
+  Assert(mesh_v.use_count() > 0);
+
+  std::vector<std::shared_ptr<const DiscreteFunctionVariant>> balanced_discrete_function_list;
+
+  MeshBalancer mesh_balancer(mesh_v);
+
+  auto p_balanced_mesh_v = mesh_balancer.mesh();
+
+  if (parallel::size() == 1) {   // sequential: no exchange, the existing cell data is re-attached to the new mesh
+    std::visit(
+      [&balanced_discrete_function_list, &discrete_function_list](auto&& p_balanced_mesh) {
+        for (size_t i_discrete_function = 0; i_discrete_function < discrete_function_list.size();
+             ++i_discrete_function) {
+          std::visit(
+            [&balanced_discrete_function_list, p_balanced_mesh](auto&& discrete_function) {
+              using DiscreteFunctionT = std::decay_t<decltype(discrete_function)>;
+              if constexpr (is_discrete_function_P0_v<DiscreteFunctionT>) {   // P0: one value per cell
+                balanced_discrete_function_list.push_back(std::make_shared<DiscreteFunctionVariant>(
+                  DiscreteFunctionT{p_balanced_mesh, discrete_function.cellValues()}));
+              } else {   // other discrete function kinds expose cellArrays()
+                balanced_discrete_function_list.push_back(std::make_shared<DiscreteFunctionVariant>(
+                  DiscreteFunctionT{p_balanced_mesh, discrete_function.cellArrays()}));
+              }
+            },
+            discrete_function_list[i_discrete_function]->discreteFunction());
+        }
+      },
+      p_balanced_mesh_v->variant());
+  } else {   // parallel: cell data is moved with the dispatcher exposed by the mesh balancer
+    std::visit(
+      [&mesh_balancer, &balanced_discrete_function_list, &discrete_function_list](auto&& p_balanced_mesh) {
+        using MeshType             = mesh_type_t<decltype(p_balanced_mesh)>;
+        constexpr size_t Dimension = MeshType::Dimension;
+        const auto& dispatcher     = mesh_balancer.connectivityDispatcher()->get<Dimension>();
+
+        for (size_t i_discrete_function = 0; i_discrete_function < discrete_function_list.size();
+             ++i_discrete_function) {
+          std::visit(
+            [&balanced_discrete_function_list, &dispatcher, &p_balanced_mesh](auto&& discrete_function) {
+              using DiscreteFunctionT = std::decay_t<decltype(discrete_function)>;
+              if constexpr (is_discrete_function_P0_v<DiscreteFunctionT>) {
+                const auto& dispatched_cell_value = dispatcher->dispatch(discrete_function.cellValues());   // args deduced
+                balanced_discrete_function_list.push_back(
+                  std::make_shared<DiscreteFunctionVariant>(DiscreteFunctionT{p_balanced_mesh, dispatched_cell_value}));
+              } else {
+                const auto& dispatched_cell_array = dispatcher->dispatch(discrete_function.cellArrays());   // args deduced
+                balanced_discrete_function_list.push_back(
+                  std::make_shared<DiscreteFunctionVariant>(DiscreteFunctionT{p_balanced_mesh, dispatched_cell_array}));
+              }
+            },
+            discrete_function_list[i_discrete_function]->discreteFunction());
+        }
+      },
+      p_balanced_mesh_v->variant());
+  }
+
+  return balanced_discrete_function_list;
+}
diff --git a/src/scheme/LoadBalancer.hpp b/src/scheme/LoadBalancer.hpp
new file mode 100644
index 000000000..465db1bb6
--- /dev/null
+++ b/src/scheme/LoadBalancer.hpp
@@ -0,0 +1,21 @@
+#ifndef LOAD_BALANCER_HPP
+#define LOAD_BALANCER_HPP
+
+#include <memory>
+#include <tuple>
+#include <vector>
+
+class MeshVariant;
+class DiscreteFunctionVariant;
+
+class LoadBalancer   // helper with no data members: load balances discrete functions living on a common mesh
+{
+ public:
+  std::vector<std::shared_ptr<const DiscreteFunctionVariant>> balance(
+    const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_list);   // one output per input
+
+  LoadBalancer()  = default;
+  ~LoadBalancer() = default;
+};
+
+#endif   // LOAD_BALANCER_HPP
-- 
GitLab