From 99a79a311798b64e37f335ebc0cb25e6b8296c3c Mon Sep 17 00:00:00 2001
From: chantrait <teddy.chantrait@cea.fr>
Date: Wed, 27 Jul 2022 14:16:35 +0200
Subject: [PATCH] Add an internal MPI communicator to allow MPMD coupling

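When the environment variable PUGS_COUPLED_COLOR is set, MPI_COMM_WORLD is
split at startup (MPI_Comm_split) and every internal communication of pugs
(gathers, broadcasts, reductions, point-to-point exchanges, graph
partitioning and PETSc) goes through the resulting sub-communicator
m_comm_world_pugs. Several codes launched in MPMD mode can thus run side by
side, while MPI_COMM_WORLD remains available for inter-code coupling. When
the variable is not set, m_comm_world_pugs is MPI_COMM_WORLD and the
behavior is unchanged.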
---
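Usage note (a sketch, not part of the patch): a possible MPMD launch,
assuming Open MPI's mpirun and its -x option for exporting environment
variables; executable and input names are placeholders:

    mpirun -n 2 -x PUGS_COUPLED_COLOR=0 ./pugs case_a.pgs \
         : -n 3 -x PUGS_COUPLED_COLOR=1 ./peer_code peer_input

Each group then performs its internal communications on its own
sub-communicator (of size 2 and 3 respectively), while MPI_COMM_WORLD
(size 5) remains available for inter-code exchanges.
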
 src/utils/Messenger.cpp   | 49 ++++++++++++++++++++++++++++++++++++++++++++-----
 src/utils/Messenger.hpp   | 46 ++++++++++++++++++++++++---------------
 src/utils/Partitioner.cpp |  4 ++--
 src/utils/PugsUtils.cpp   |  8 ++++++++
 4 files changed, 83 insertions(+), 24 deletions(-)

diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp
index e3fff49d0..66f178d3e 100644
--- a/src/utils/Messenger.cpp
+++ b/src/utils/Messenger.cpp
@@ -31,18 +31,57 @@ Messenger::Messenger([[maybe_unused]] int& argc, [[maybe_unused]] char* argv[])
 #ifdef PUGS_HAS_MPI
   MPI_Init(&argc, &argv);
 
-  m_rank = []() {
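+  // When PUGS_COUPLED_COLOR is set, MPI_COMM_WORLD is split so that each
+  // coupled code performs its internal communications on its own communicator.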
+  const char* coupled_color = std::getenv("PUGS_COUPLED_COLOR");
+  if (coupled_color == nullptr) {
+    m_comm_world_pugs = MPI_COMM_WORLD;
+  } else {
+    const int color = std::atoi(coupled_color);
+    int global_rank;
+    int global_size;
+    int key = 0;   // keep the MPI_COMM_WORLD rank order within each color
+
+    MPI_Comm_rank(MPI_COMM_WORLD, &global_rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &global_size);
+
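+    // processes sharing the same color are grouped into one sub-communicator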
+    const int res = MPI_Comm_split(MPI_COMM_WORLD, color, key, &m_comm_world_pugs);
+    if (res != MPI_SUCCESS) {
+      MPI_Abort(MPI_COMM_WORLD, res);
+    }
+
+    int local_rank;
+    int local_size;
+
+    MPI_Comm_rank(m_comm_world_pugs, &local_rank);
+    MPI_Comm_size(m_comm_world_pugs, &local_size);
+
+    std::cout << "----------------- " << rang::fg::green << "pugs coupled info" << rang::fg::reset
+              << " -----------------" << '\n';
+    std::cout << "Coupling mode activated";
+    std::cout << "\n\t Global size:   " << global_size;
+    std::cout << "\n\t Global rank:   " << global_rank + 1;
+    std::cout << "\n\t Coupled color: " << color;
+    std::cout << "\n\t Local size:    " << local_size;
+    std::cout << "\n\t Local rank:    " << local_rank + 1 << std::endl;
+    std::cout << "-----------------------------------------------------" << '\n';
+  }
+
+  m_rank = [&]() {
     int rank;
-    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_rank(m_comm_world_pugs, &rank);
     return rank;
   }();
 
-  m_size = []() {
+  m_size = [&]() {
     int size = 0;
-    MPI_Comm_size(MPI_COMM_WORLD, &size);
+    MPI_Comm_size(m_comm_world_pugs, &size);
     return size;
   }();
 
+  std::cout << "pugs process " << m_rank + 1 << "/" << m_size << std::endl;
+
   if (m_rank != 0) {
     // LCOV_EXCL_START
     std::cout.setstate(std::ios::badbit);
@@ -64,7 +103,7 @@ void
 Messenger::barrier() const
 {
 #ifdef PUGS_HAS_MPI
-  MPI_Barrier(MPI_COMM_WORLD);
+  MPI_Barrier(m_comm_world_pugs);
 #endif   // PUGS_HAS_MPI
 }
 
diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp
index 416dbdf1d..692ffed84 100644
--- a/src/utils/Messenger.hpp
+++ b/src/utils/Messenger.hpp
@@ -102,7 +102,7 @@ class Messenger
 
     auto gather_address = (gather.size() > 0) ? &(gather[0]) : nullptr;
 
-    MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, MPI_COMM_WORLD);
+    MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     gather[0] = data;
 #endif   // PUGS_HAS_MPI
@@ -131,7 +131,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Gather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype, rank,
-               MPI_COMM_WORLD);
+               m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -172,7 +172,7 @@ class Messenger
     auto gather_address    = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Gatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
-                mpi_datatype, rank, MPI_COMM_WORLD);
+                mpi_datatype, rank, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -188,7 +188,7 @@ class Messenger
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
-    MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, MPI_COMM_WORLD);
+    MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     gather[0] = data;
 #endif   // PUGS_HAS_MPI
@@ -217,7 +217,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Allgather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype,
-                  MPI_COMM_WORLD);
+                  m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -257,7 +257,7 @@ class Messenger
     auto gather_address    = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Allgatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
-                   mpi_datatype, MPI_COMM_WORLD);
+                   mpi_datatype, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -273,7 +273,7 @@ class Messenger
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
-    MPI_Bcast(&data, 1, mpi_datatype, root_rank, MPI_COMM_WORLD);
+    MPI_Bcast(&data, 1, mpi_datatype, root_rank, m_comm_world_pugs);
 #endif   // PUGS_HAS_MPI
   }
 
@@ -289,7 +289,7 @@ class Messenger
     auto array_address = (array.size() > 0) ? &(array[0]) : nullptr;
 
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
-    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD);
+    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, m_comm_world_pugs);
 #endif   // PUGS_HAS_MPI
   }
 
@@ -318,7 +318,7 @@ class Messenger
     auto sent_address = (sent_array.size() > 0) ? &(sent_array[0]) : nullptr;
     auto recv_address = (recv_array.size() > 0) ? &(recv_array[0]) : nullptr;
 
-    MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, MPI_COMM_WORLD);
+    MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(sent_array, recv_array);
 #endif   // PUGS_HAS_MPI
@@ -348,7 +348,7 @@ class Messenger
       const auto sent_array = sent_array_list[i_send];
       if (sent_array.size() > 0) {
         MPI_Request request;
-        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request);
+        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, m_comm_world_pugs, &request);
         request_list.push_back(request);
       }
     }
@@ -357,7 +357,7 @@ class Messenger
       auto recv_array = recv_array_list[i_recv];
       if (recv_array.size() > 0) {
         MPI_Request request;
-        MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);
+        MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, m_comm_world_pugs, &request);
         request_list.push_back(request);
       }
     }
@@ -399,6 +399,9 @@ class Messenger
   }
 
  public:
+#ifdef PUGS_HAS_MPI
+  MPI_Comm m_comm_world_pugs = MPI_COMM_NULL;   // MPI_COMM_WORLD unless PUGS_COUPLED_COLOR is set
+#endif   // PUGS_HAS_MPI
   static void create(int& argc, char* argv[]);
   static void destroy();
 
@@ -417,6 +420,15 @@ class Messenger
     return m_rank;
   }
 
+#ifdef PUGS_HAS_MPI
+  PUGS_INLINE
+  const MPI_Comm&
+  comm() const
+  {
+    return m_comm_world_pugs;
+  }
+#endif   // PUGS_HAS_MPI
+
   PUGS_INLINE
   const size_t&
   size() const
@@ -438,7 +450,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType min_data = data;
-    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, m_comm_world_pugs);
 
     return min_data;
 #else    // PUGS_HAS_MPI
@@ -456,7 +468,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, m_comm_world_pugs);
 
     return max_data;
 #else    // PUGS_HAS_MPI
@@ -474,7 +486,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, m_comm_world_pugs);
 
     return max_data;
 #else    // PUGS_HAS_MPI
@@ -494,7 +506,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, m_comm_world_pugs);
 
     return max_data;
 #else    // PUGS_HAS_MPI
@@ -514,7 +526,7 @@ class Messenger
       MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
       DataType data_sum = data;
-      MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, MPI_COMM_WORLD);
+      MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, m_comm_world_pugs);
 
       return data_sum;
     } else if (is_trivially_castable<DataType>) {
@@ -523,7 +535,7 @@ class Messenger
       MPI_Datatype mpi_datatype = Messenger::helper::mpiType<InnerDataType>();
       const int size            = sizeof(DataType) / sizeof(InnerDataType);
       DataType data_sum         = data;
-      MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, MPI_COMM_WORLD);
+      MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, m_comm_world_pugs);
 
       return data_sum;
     }
diff --git a/src/utils/Partitioner.cpp b/src/utils/Partitioner.cpp
index 347b07da6..e63a4b48c 100644
--- a/src/utils/Partitioner.cpp
+++ b/src/utils/Partitioner.cpp
@@ -31,7 +31,7 @@ Partitioner::partition(const CRSGraph& graph)
   Array<int> part(0);
 
   MPI_Group world_group;
-  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
+  MPI_Comm_group(parallel::Messenger::getInstance().comm(), &world_group);
 
   MPI_Group mesh_group;
   std::vector<int> group_ranks = [&]() {
@@ -49,7 +49,7 @@
   MPI_Group_incl(world_group, group_ranks.size(), &(group_ranks[0]), &mesh_group);
 
   MPI_Comm parmetis_comm;
-  MPI_Comm_create_group(MPI_COMM_WORLD, mesh_group, 1, &parmetis_comm);
+  MPI_Comm_create_group(parallel::Messenger::getInstance().comm(), mesh_group, 1, &parmetis_comm);
 
   int local_number_of_nodes = graph.numberOfNodes();
 
diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp
index 705fbd5e0..569095a81 100644
--- a/src/utils/PugsUtils.cpp
+++ b/src/utils/PugsUtils.cpp
@@ -11,6 +11,9 @@
 #include <utils/SignalManager.hpp>
 #include <utils/pugs_build_info.hpp>
 
+#ifdef PUGS_HAS_PETSC
+#include <petsc.h>
+#endif   // PUGS_HAS_PETSC
 #include <rang.hpp>
 
 #include <Kokkos_Core.hpp>
@@ -131,6 +134,11 @@ initialize(int& argc, char* argv[])
     SignalManager::setPauseForDebug(pause_on_error);
   }
 
+#ifdef PUGS_HAS_PETSC
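+  // PETSc reads PETSC_COMM_WORLD in PetscInitialize(), so it must be set
+  // before PETScWrapper::initialize() below.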
+  PETSC_COMM_WORLD = parallel::Messenger::getInstance().comm();
+#endif   // PUGS_HAS_PETSC
+
   PETScWrapper::initialize(argc, argv);
   SLEPcWrapper::initialize(argc, argv);
 
-- 
GitLab