From f92a97b5e7e82b0c47f5a0851e0379816161e5a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Wed, 26 Feb 2025 17:16:36 +0100
Subject: [PATCH] Force connectivity reconstruction on load balancing for
 sequential

This is intended to ensure that the code behaves the same for parallel
and sequential calculations
---
 src/mesh/CartesianMeshBuilder.cpp |  2 +-
 src/mesh/GmshReader.cpp           |  6 +++---
 src/mesh/MeshBalancer.cpp         |  2 +-
 src/mesh/MeshBuilderBase.cpp      | 12 +++++++-----
 src/mesh/MeshBuilderBase.hpp      |  9 ++++++++-
 5 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/src/mesh/CartesianMeshBuilder.cpp b/src/mesh/CartesianMeshBuilder.cpp
index 4ae7410ce..d0b34bb12 100644
--- a/src/mesh/CartesianMeshBuilder.cpp
+++ b/src/mesh/CartesianMeshBuilder.cpp
@@ -144,7 +144,7 @@ CartesianMeshBuilder::CartesianMeshBuilder(const TinyVector<Dimension>& a,
 
     this->_buildCartesianMesh(corner0, corner1, size);
   }
-  this->_dispatch<Dimension>();
+  this->_dispatch<Dimension>(DispatchType::initial);
 }
 
 template CartesianMeshBuilder::CartesianMeshBuilder(const TinyVector<1>&,
diff --git a/src/mesh/GmshReader.cpp b/src/mesh/GmshReader.cpp
index bb409dd28..a2d147da4 100644
--- a/src/mesh/GmshReader.cpp
+++ b/src/mesh/GmshReader.cpp
@@ -1086,15 +1086,15 @@ GmshReader::GmshReader(const std::string& filename) : m_filename(filename)
 
     switch (mesh_dimension) {
     case 1: {
-      this->_dispatch<1>();
+      this->_dispatch<1>(DispatchType::initial);
       break;
     }
     case 2: {
-      this->_dispatch<2>();
+      this->_dispatch<2>(DispatchType::initial);
       break;
     }
     case 3: {
-      this->_dispatch<3>();
+      this->_dispatch<3>(DispatchType::initial);
       break;
     }
     default: {
diff --git a/src/mesh/MeshBalancer.cpp b/src/mesh/MeshBalancer.cpp
index d076c7d48..23e4759b0 100644
--- a/src/mesh/MeshBalancer.cpp
+++ b/src/mesh/MeshBalancer.cpp
@@ -11,7 +11,7 @@ MeshBalancer::MeshBalancer(const std::shared_ptr<const MeshVariant>& initial_mes
     [this](auto&& mesh) {
       using MeshType = mesh_type_t<decltype(mesh)>;
 
-      this->_dispatch<MeshType::Dimension>();
+      this->_dispatch<MeshType::Dimension>(DispatchType::balance);
     },
     initial_mesh->variant());
 }
diff --git a/src/mesh/MeshBuilderBase.cpp b/src/mesh/MeshBuilderBase.cpp
index 4a0a010ce..b8efb24e9 100644
--- a/src/mesh/MeshBuilderBase.cpp
+++ b/src/mesh/MeshBuilderBase.cpp
@@ -16,13 +16,13 @@
 
 template <size_t Dimension>
 void
-MeshBuilderBase::_dispatch()
+MeshBuilderBase::_dispatch(const DispatchType dispatch_type)
 {
   using ConnectivityType = Connectivity<Dimension>;
   using Rd               = TinyVector<Dimension>;
   using MeshType         = Mesh<Dimension>;
 
-  if (parallel::size() == 1) {
+  if ((parallel::size() == 1) and (dispatch_type == DispatchType::initial)) {
     const MeshType& mesh = *(m_mesh->get<const MeshType>());
 
     // force "creation" of a new mesh to avoid different
@@ -30,6 +30,8 @@ MeshBuilderBase::_dispatch()
     // parallel, is also changes in sequential.
     m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(mesh.shared_connectivity(), mesh.xr()));
   } else {
+    Assert((parallel::size() >= 1) or (dispatch_type == DispatchType::balance));
+
     if (not m_mesh) {
       ConnectivityDescriptor descriptor;
       std::shared_ptr connectivity = ConnectivityType::build(descriptor);
@@ -50,9 +52,9 @@ MeshBuilderBase::_dispatch()
   }
 }
 
-template void MeshBuilderBase::_dispatch<1>();
-template void MeshBuilderBase::_dispatch<2>();
-template void MeshBuilderBase::_dispatch<3>();
+template void MeshBuilderBase::_dispatch<1>(const DispatchType);
+template void MeshBuilderBase::_dispatch<2>(const DispatchType);
+template void MeshBuilderBase::_dispatch<3>(const DispatchType);
 
 template <size_t Dimension>
 void
diff --git a/src/mesh/MeshBuilderBase.hpp b/src/mesh/MeshBuilderBase.hpp
index 7b5f34d92..a64105f3a 100644
--- a/src/mesh/MeshBuilderBase.hpp
+++ b/src/mesh/MeshBuilderBase.hpp
@@ -8,12 +8,19 @@ class ConnectivityDispatcherVariant;
 
 class MeshBuilderBase
 {
+ public:
+  enum class DispatchType : uint8_t
+  {
+    initial,
+    balance
+  };
+
  protected:
   std::shared_ptr<const MeshVariant> m_mesh;
   std::shared_ptr<const ConnectivityDispatcherVariant> m_connectivity_dispatcher;
 
   template <size_t Dimension>
-  void _dispatch();
+  void _dispatch(const DispatchType balance);
 
   template <size_t Dimension>
   void _checkMesh() const;
-- 
GitLab