From 99a79a311798b64e37f335ebc0cb25e6b8296c3c Mon Sep 17 00:00:00 2001
From: chantrait <teddy.chantrait@cea.fr>
Date: Wed, 27 Jul 2022 14:16:35 +0200
Subject: [PATCH] Add an internal MPI communicator to allow MPMD coupling

---
 src/utils/Messenger.cpp   | 46 ++++++++++++++++++++++++++++++++++-----
 src/utils/Messenger.hpp   | 46 ++++++++++++++++++++---------------
 src/utils/Partitioner.cpp |  5 +++--
 src/utils/PugsUtils.cpp   |  7 +++++++
 4 files changed, 80 insertions(+), 24 deletions(-)

diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp
index e3fff49d0..66f178d3e 100644
--- a/src/utils/Messenger.cpp
+++ b/src/utils/Messenger.cpp
@@ -31,18 +31,54 @@ Messenger::Messenger([[maybe_unused]] int& argc, [[maybe_unused]] char* argv[])
 #ifdef PUGS_HAS_MPI
   MPI_Init(&argc, &argv);
 
-  m_rank = []() {
+  const char* coupled_color = std::getenv("PUGS_COUPLED_COLOR");
+  if (coupled_color == NULL) {
+    m_comm_world_pugs = MPI_COMM_WORLD;
+  } else {
+    int color = atoi(coupled_color);
+    int global_rank;
+    int global_size;
+    int key = 0;
+
+    MPI_Comm_rank(MPI_COMM_WORLD, &global_rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &global_size);
+
+    auto res = MPI_Comm_split(MPI_COMM_WORLD, color, key, &m_comm_world_pugs);
+    if (res) {
+      MPI_Abort(MPI_COMM_WORLD, res);
+    }
+
+    int local_rank;
+    int local_size;
+
+    MPI_Comm_rank(m_comm_world_pugs, &local_rank);
+    MPI_Comm_size(m_comm_world_pugs, &local_size);
+
+    std::cout << "----------------- " << rang::fg::green << "pugs coupled info " << rang::fg::reset
+              << " ----------------------" << '\n';
+    std::cout << "Coupling mode activated";
+    std::cout << "\n\t Global size: " << global_size;
+    std::cout << "\n\t Global rank: " << global_rank + 1;
+    std::cout << "\n\t Coupled color: " << color;
+    std::cout << "\n\t local size: " << local_size;
+    std::cout << "\n\t local rank: " << local_rank + 1 << std::endl;
+    std::cout << "---------------------------------------" << '\n';
+  }
+
+  m_rank = [&]() {
     int rank;
-    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_rank(m_comm_world_pugs, &rank);
     return rank;
   }();
 
-  m_size = []() {
+  m_size = [&]() {
     int size = 0;
-    MPI_Comm_size(MPI_COMM_WORLD, &size);
+    MPI_Comm_size(m_comm_world_pugs, &size);
     return size;
   }();
 
+  std::cout << "pugs process " << m_rank + 1 << "/" << m_size << std::endl;
+
   if (m_rank != 0) {
     // LCOV_EXCL_START
     std::cout.setstate(std::ios::badbit);
@@ -64,7 +100,7 @@ void
 Messenger::barrier() const
 {
 #ifdef PUGS_HAS_MPI
-  MPI_Barrier(MPI_COMM_WORLD);
+  MPI_Barrier(m_comm_world_pugs);
 #endif   // PUGS_HAS_MPI
 }
 
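The constructor change above is the core of the patch: when PUGS_COUPLED_COLOR is set in the environment, MPI_COMM_WORLD is split by color, every pugs process that exported the same value lands in the same sub-communicator, and all subsequent pugs communication goes through m_comm_world_pugs while MPI_COMM_WORLD stays usable for exchanges with the coupled code. A minimal standalone sketch of that color-split technique, outside of pugs and with illustrative output only:

#include <cstdlib>
#include <iostream>
#include <mpi.h>

int
main(int argc, char* argv[])
{
  MPI_Init(&argc, &argv);

  // Default behaviour: no coupling, everything runs on MPI_COMM_WORLD.
  MPI_Comm comm_world_pugs = MPI_COMM_WORLD;

  if (const char* coupled_color = std::getenv("PUGS_COUPLED_COLOR")) {
    const int color = std::atoi(coupled_color);
    // key = 0 keeps the relative MPI_COMM_WORLD ordering inside each color group.
    if (MPI_Comm_split(MPI_COMM_WORLD, color, 0, &comm_world_pugs) != MPI_SUCCESS) {
      MPI_Abort(MPI_COMM_WORLD, 1);
    }
  }

  int local_rank = 0;
  int local_size = 0;
  MPI_Comm_rank(comm_world_pugs, &local_rank);
  MPI_Comm_size(comm_world_pugs, &local_size);
  std::cout << "process " << local_rank + 1 << "/" << local_size << " in its color group\n";

  MPI_Finalize();
  return 0;
}

In an MPMD run each coupled executable would be launched with its own PUGS_COUPLED_COLOR value; how a per-application environment variable is passed depends on the MPI launcher being used.
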
diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp
index 416dbdf1d..692ffed84 100644
--- a/src/utils/Messenger.hpp
+++ b/src/utils/Messenger.hpp
@@ -102,7 +102,7 @@ class Messenger
 
     auto gather_address = (gather.size() > 0) ? &(gather[0]) : nullptr;
 
-    MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, MPI_COMM_WORLD);
+    MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     gather[0] = data;
 #endif   // PUGS_HAS_MPI
@@ -131,7 +131,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Gather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype, rank,
-               MPI_COMM_WORLD);
+               m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -172,7 +172,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Gatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
-                mpi_datatype, rank, MPI_COMM_WORLD);
+                mpi_datatype, rank, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -188,7 +188,7 @@ class Messenger
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
-    MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, MPI_COMM_WORLD);
+    MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     gather[0] = data;
 #endif   // PUGS_HAS_MPI
@@ -217,7 +217,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Allgather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype,
-                  MPI_COMM_WORLD);
+                  m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -257,7 +257,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
 
     MPI_Allgatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
-                   mpi_datatype, MPI_COMM_WORLD);
+                   mpi_datatype, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif   // PUGS_HAS_MPI
@@ -273,7 +273,7 @@ class Messenger
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
-    MPI_Bcast(&data, 1, mpi_datatype, root_rank, MPI_COMM_WORLD);
+    MPI_Bcast(&data, 1, mpi_datatype, root_rank, m_comm_world_pugs);
 #endif   // PUGS_HAS_MPI
   }
 
@@ -289,7 +289,7 @@ class Messenger
     auto array_address = (array.size() > 0) ? &(array[0]) : nullptr;
 
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
-    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD);
+    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, m_comm_world_pugs);
 #endif   // PUGS_HAS_MPI
   }
 
@@ -318,7 +318,7 @@ class Messenger
     auto sent_address = (sent_array.size() > 0) ? &(sent_array[0]) : nullptr;
     auto recv_address = (recv_array.size() > 0) ? &(recv_array[0]) : nullptr;
 
-    MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, MPI_COMM_WORLD);
+    MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, m_comm_world_pugs);
 #else    // PUGS_HAS_MPI
     copy_to(sent_array, recv_array);
 #endif   // PUGS_HAS_MPI
@@ -348,7 +348,7 @@ class Messenger
       const auto sent_array = sent_array_list[i_send];
       if (sent_array.size() > 0) {
         MPI_Request request;
-        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request);
+        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, m_comm_world_pugs, &request);
         request_list.push_back(request);
       }
     }
@@ -357,7 +357,7 @@ class Messenger
       auto recv_array = recv_array_list[i_recv];
       if (recv_array.size() > 0) {
         MPI_Request request;
-        MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);
+        MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, m_comm_world_pugs, &request);
         request_list.push_back(request);
       }
     }
@@ -399,6 +399,9 @@ class Messenger
   }
 
  public:
+#ifdef PUGS_HAS_MPI
+  MPI_Comm m_comm_world_pugs = MPI_COMM_NULL;
+#endif   // PUGS_HAS_MPI
   static void create(int& argc, char* argv[]);
   static void destroy();
 
@@ -417,6 +420,15 @@ class Messenger
     return m_rank;
   }
 
+#ifdef PUGS_HAS_MPI
+  PUGS_INLINE
+  const MPI_Comm&
+  comm() const
+  {
+    return m_comm_world_pugs;
+  }
+#endif   // PUGS_HAS_MPI
+
   PUGS_INLINE
   const size_t&
   size() const
@@ -438,7 +450,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType min_data = data;
-    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, m_comm_world_pugs);
 
     return min_data;
 #else    // PUGS_HAS_MPI
@@ -456,7 +468,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, m_comm_world_pugs);
 
     return max_data;
 #else    // PUGS_HAS_MPI
@@ -474,7 +486,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, m_comm_world_pugs);
 
     return max_data;
 #else    // PUGS_HAS_MPI
@@ -494,7 +506,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, m_comm_world_pugs);
 
     return max_data;
 #else    // PUGS_HAS_MPI
@@ -514,7 +526,7 @@ class Messenger
       MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
 
       DataType data_sum = data;
-      MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, MPI_COMM_WORLD);
+      MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, m_comm_world_pugs);
 
       return data_sum;
     } else if (is_trivially_castable<DataType>) {
@@ -523,7 +535,7 @@ class Messenger
       MPI_Datatype mpi_datatype = Messenger::helper::mpiType<InnerDataType>();
       const int size = sizeof(DataType) / sizeof(InnerDataType);
       DataType data_sum = data;
-      MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, MPI_COMM_WORLD);
+      MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, m_comm_world_pugs);
 
       return data_sum;
     }
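The comm() accessor added above is what lets the rest of the code base stop hard-coding MPI_COMM_WORLD; the Partitioner change below uses it exactly that way. A hypothetical caller-side sketch (the helper function is illustrative, the include path and namespace follow the pugs sources):

#include <utils/Messenger.hpp>

#ifdef PUGS_HAS_MPI
// Rank of this process inside the (possibly split) pugs communicator,
// which may differ from its MPI_COMM_WORLD rank in a coupled run.
inline int
pugsCommRank()
{
  int rank = 0;
  MPI_Comm_rank(parallel::Messenger::getInstance().comm(), &rank);
  return rank;
}
#endif   // PUGS_HAS_MPI
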
diff --git a/src/utils/Partitioner.cpp b/src/utils/Partitioner.cpp
index 347b07da6..e63a4b48c 100644
--- a/src/utils/Partitioner.cpp
+++ b/src/utils/Partitioner.cpp
@@ -31,7 +31,8 @@ Partitioner::partition(const CRSGraph& graph)
   Array<int> part(0);
 
   MPI_Group world_group;
-  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
+
+  MPI_Comm_group(parallel::Messenger::getInstance().comm(), &world_group);
 
   MPI_Group mesh_group;
   std::vector<int> group_ranks = [&]() {
@@ -49,7 +50,7 @@ Partitioner::partition(const CRSGraph& graph)
   MPI_Group_incl(world_group, group_ranks.size(), &(group_ranks[0]), &mesh_group);
 
   MPI_Comm parmetis_comm;
-  MPI_Comm_create_group(MPI_COMM_WORLD, mesh_group, 1, &parmetis_comm);
+  MPI_Comm_create_group(parallel::Messenger::getInstance().comm(), mesh_group, 1, &parmetis_comm);
 
   int local_number_of_nodes = graph.numberOfNodes();
 
diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp
index 705fbd5e0..569095a81 100644
--- a/src/utils/PugsUtils.cpp
+++ b/src/utils/PugsUtils.cpp
@@ -11,6 +11,9 @@
 #include <utils/SignalManager.hpp>
 #include <utils/pugs_build_info.hpp>
 
+#ifdef PUGS_HAS_PETSC
+#include <petsc.h>
+#endif   // PUGS_HAS_PETSC
 #include <rang.hpp>
 
 #include <Kokkos_Core.hpp>
@@ -131,6 +134,10 @@ initialize(int& argc, char* argv[])
     SignalManager::setPauseForDebug(pause_on_error);
   }
 
+#ifdef PUGS_HAS_PETSC
+  PETSC_COMM_WORLD = parallel::Messenger::getInstance().comm();
+#endif   // PUGS_HAS_PETSC
+
   PETScWrapper::initialize(argc, argv);
   SLEPcWrapper::initialize(argc, argv);
 
-- 
GitLab
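One ordering detail in the PugsUtils.cpp hunk: PETSc captures PETSC_COMM_WORLD when PetscInitialize is called, so the assignment has to come before PETScWrapper::initialize(argc, argv). A standalone sketch of that constraint, assuming a PETSc and MPI build and omitting error checking:

#include <mpi.h>
#include <petsc.h>

int
main(int argc, char* argv[])
{
  MPI_Init(&argc, &argv);

  // Split off a sub-communicator; the color value 0 is arbitrary here.
  MPI_Comm sub_comm;
  MPI_Comm_split(MPI_COMM_WORLD, 0, 0, &sub_comm);

  // Must be assigned *before* PetscInitialize, otherwise PETSc keeps
  // MPI_COMM_WORLD as its default communicator.
  PETSC_COMM_WORLD = sub_comm;
  PetscInitialize(&argc, &argv, nullptr, nullptr);

  // ... PETSc objects created on PETSC_COMM_WORLD now live on sub_comm ...

  PetscFinalize();
  MPI_Comm_free(&sub_comm);
  MPI_Finalize();
  return 0;
}

The same reasoning applies to any other library that captures its world communicator at initialization time.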