Commit 99a79a31 authored by chantrait

Add an internal MPI communicator to allow MPMD coupling

parent 743a0d8f
1 merge request: !172 Reproducible summation of floating point arrays
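
For readers skimming the diff, here is a minimal, self-contained sketch of the mechanism this commit wires into Messenger. It is not pugs code: only the PUGS_COUPLED_COLOR environment variable name is taken from the change below, every other identifier is illustrative. Each executable of an MPMD run exports its own color, and MPI_Comm_split carves a private communicator out of MPI_COMM_WORLD that all later collectives use.

#include <mpi.h>

#include <cstdlib>
#include <iostream>

int
main(int argc, char* argv[])
{
  MPI_Init(&argc, &argv);

  // Default: behave exactly as before, everything happens on MPI_COMM_WORLD.
  MPI_Comm comm_world_pugs = MPI_COMM_WORLD;

  // Coupled mode: processes exporting the same PUGS_COUPLED_COLOR value end up
  // in the same sub-communicator; key = 0 keeps the MPI_COMM_WORLD ordering.
  if (const char* coupled_color = std::getenv("PUGS_COUPLED_COLOR")) {
    const int color = std::atoi(coupled_color);
    if (MPI_Comm_split(MPI_COMM_WORLD, color, 0, &comm_world_pugs) != MPI_SUCCESS) {
      MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
    }
  }

  int local_rank = 0;
  int local_size = 0;
  MPI_Comm_rank(comm_world_pugs, &local_rank);
  MPI_Comm_size(comm_world_pugs, &local_size);
  std::cout << "process " << local_rank + 1 << "/" << local_size << " in its coupled group\n";

  // From here on, collectives stay inside the internal communicator.
  MPI_Barrier(comm_world_pugs);

  if (comm_world_pugs != MPI_COMM_WORLD) {
    MPI_Comm_free(&comm_world_pugs);
  }
  MPI_Finalize();
  return 0;
}

With such a split, two coupled codes can be started side by side (for instance with mpirun's MPMD colon syntax, each binary given a different PUGS_COUPLED_COLOR), and the barriers, gathers and reductions modified below remain confined to each code's own processes.
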
@@ -31,18 +31,54 @@ Messenger::Messenger([[maybe_unused]] int& argc, [[maybe_unused]] char* argv[])
 #ifdef PUGS_HAS_MPI
   MPI_Init(&argc, &argv);
-  m_rank = []() {
+  const char* coupled_color = std::getenv("PUGS_COUPLED_COLOR");
+  if (coupled_color == NULL) {
+    m_comm_world_pugs = MPI_COMM_WORLD;
+  } else {
+    int color = atoi(coupled_color);
+    int global_rank;
+    int global_size;
+    int key = 0;
+    MPI_Comm_rank(MPI_COMM_WORLD, &global_rank);
+    MPI_Comm_size(MPI_COMM_WORLD, &global_size);
+    auto res = MPI_Comm_split(MPI_COMM_WORLD, color, key, &m_comm_world_pugs);
+    if (res) {
+      MPI_Abort(MPI_COMM_WORLD, res);
+    }
+    int local_rank;
+    int local_size;
+    MPI_Comm_rank(m_comm_world_pugs, &local_rank);
+    MPI_Comm_size(m_comm_world_pugs, &local_size);
+    std::cout << "----------------- " << rang::fg::green << "pugs coupled info " << rang::fg::reset
+              << " ----------------------" << '\n';
+    std::cout << "Coupling mode activated";
+    std::cout << "\n\t Global size: " << global_size;
+    std::cout << "\n\t Global rank: " << global_rank + 1;
+    std::cout << "\n\t Coupled color: " << color;
+    std::cout << "\n\t local size: " << local_size;
+    std::cout << "\n\t local rank: " << local_rank + 1 << std::endl;
+    std::cout << "---------------------------------------" << '\n';
+  }
+  m_rank = [&]() {
     int rank;
-    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+    MPI_Comm_rank(m_comm_world_pugs, &rank);
     return rank;
   }();
-  m_size = []() {
+  m_size = [&]() {
     int size = 0;
-    MPI_Comm_size(MPI_COMM_WORLD, &size);
+    MPI_Comm_size(m_comm_world_pugs, &size);
     return size;
   }();
   std::cout << "pugs process " << m_rank + 1 << "/" << m_size << std::endl;
   if (m_rank != 0) {
     // LCOV_EXCL_START
     std::cout.setstate(std::ios::badbit);
@@ -64,7 +100,7 @@ void
 Messenger::barrier() const
 {
 #ifdef PUGS_HAS_MPI
-  MPI_Barrier(MPI_COMM_WORLD);
+  MPI_Barrier(m_comm_world_pugs);
 #endif // PUGS_HAS_MPI
 }
@@ -102,7 +102,7 @@ class Messenger
     auto gather_address = (gather.size() > 0) ? &(gather[0]) : nullptr;
-    MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, MPI_COMM_WORLD);
+    MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, m_comm_world_pugs);
 #else // PUGS_HAS_MPI
     gather[0] = data;
 #endif // PUGS_HAS_MPI
@@ -131,7 +131,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
     MPI_Gather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype, rank,
-               MPI_COMM_WORLD);
+               m_comm_world_pugs);
 #else // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif // PUGS_HAS_MPI
@@ -172,7 +172,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
     MPI_Gatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
-                mpi_datatype, rank, MPI_COMM_WORLD);
+                mpi_datatype, rank, m_comm_world_pugs);
 #else // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif // PUGS_HAS_MPI
@@ -188,7 +188,7 @@ class Messenger
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
-    MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, MPI_COMM_WORLD);
+    MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, m_comm_world_pugs);
 #else // PUGS_HAS_MPI
     gather[0] = data;
 #endif // PUGS_HAS_MPI
@@ -217,7 +217,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
     MPI_Allgather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype,
-                  MPI_COMM_WORLD);
+                  m_comm_world_pugs);
 #else // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif // PUGS_HAS_MPI
@@ -257,7 +257,7 @@ class Messenger
     auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;
     MPI_Allgatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
-                   mpi_datatype, MPI_COMM_WORLD);
+                   mpi_datatype, m_comm_world_pugs);
 #else // PUGS_HAS_MPI
     copy_to(data_array, gather_array);
 #endif // PUGS_HAS_MPI
@@ -273,7 +273,7 @@ class Messenger
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
-    MPI_Bcast(&data, 1, mpi_datatype, root_rank, MPI_COMM_WORLD);
+    MPI_Bcast(&data, 1, mpi_datatype, root_rank, m_comm_world_pugs);
 #endif // PUGS_HAS_MPI
   }
@@ -289,7 +289,7 @@ class Messenger
     auto array_address = (array.size() > 0) ? &(array[0]) : nullptr;
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
-    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD);
+    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, m_comm_world_pugs);
 #endif // PUGS_HAS_MPI
   }
@@ -318,7 +318,7 @@ class Messenger
     auto sent_address = (sent_array.size() > 0) ? &(sent_array[0]) : nullptr;
     auto recv_address = (recv_array.size() > 0) ? &(recv_array[0]) : nullptr;
-    MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, MPI_COMM_WORLD);
+    MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, m_comm_world_pugs);
 #else // PUGS_HAS_MPI
     copy_to(sent_array, recv_array);
 #endif // PUGS_HAS_MPI
@@ -348,7 +348,7 @@ class Messenger
       const auto sent_array = sent_array_list[i_send];
       if (sent_array.size() > 0) {
         MPI_Request request;
-        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request);
+        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, m_comm_world_pugs, &request);
         request_list.push_back(request);
       }
     }
@@ -357,7 +357,7 @@ class Messenger
      auto recv_array = recv_array_list[i_recv];
      if (recv_array.size() > 0) {
        MPI_Request request;
-       MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);
+       MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, m_comm_world_pugs, &request);
        request_list.push_back(request);
      }
    }
@@ -399,6 +399,9 @@ class Messenger
   }
  public:
+#ifdef PUGS_HAS_MPI
+  MPI_Comm m_comm_world_pugs = MPI_COMM_NULL;
+#endif // PUGS_HAS_MPI
   static void create(int& argc, char* argv[]);
   static void destroy();
@@ -417,6 +420,15 @@ class Messenger
     return m_rank;
   }
+#ifdef PUGS_HAS_MPI
+  PUGS_INLINE
+  const MPI_Comm&
+  comm() const
+  {
+    return m_comm_world_pugs;
+  }
+#endif // PUGS_HAS_MPI
   PUGS_INLINE
   const size_t&
   size() const
@@ -438,7 +450,7 @@ class Messenger
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
     DataType min_data = data;
-    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, m_comm_world_pugs);
     return min_data;
 #else // PUGS_HAS_MPI
@@ -456,7 +468,7 @@
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, m_comm_world_pugs);
     return max_data;
 #else // PUGS_HAS_MPI
@@ -474,7 +486,7 @@
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, m_comm_world_pugs);
     return max_data;
 #else // PUGS_HAS_MPI
@@ -494,7 +506,7 @@
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
     DataType max_data = data;
-    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD);
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, m_comm_world_pugs);
     return max_data;
 #else // PUGS_HAS_MPI
@@ -514,7 +526,7 @@
       MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
       DataType data_sum = data;
-      MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, MPI_COMM_WORLD);
+      MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, m_comm_world_pugs);
       return data_sum;
     } else if (is_trivially_castable<DataType>) {
@@ -523,7 +535,7 @@
      MPI_Datatype mpi_datatype = Messenger::helper::mpiType<InnerDataType>();
      const int size = sizeof(DataType) / sizeof(InnerDataType);
      DataType data_sum = data;
-     MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, MPI_COMM_WORLD);
+     MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, m_comm_world_pugs);
      return data_sum;
    }
@@ -31,7 +31,8 @@ Partitioner::partition(const CRSGraph& graph)
   Array<int> part(0);
   MPI_Group world_group;
-  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
+  MPI_Comm_group(parallel::Messenger::getInstance().comm(), &world_group);
   MPI_Group mesh_group;
   std::vector<int> group_ranks = [&]() {
@@ -49,7 +50,7 @@ Partitioner::partition(const CRSGraph& graph)
   MPI_Group_incl(world_group, group_ranks.size(), &(group_ranks[0]), &mesh_group);
   MPI_Comm parmetis_comm;
-  MPI_Comm_create_group(MPI_COMM_WORLD, mesh_group, 1, &parmetis_comm);
+  MPI_Comm_create_group(parallel::Messenger::getInstance().comm(), mesh_group, 1, &parmetis_comm);
   int local_number_of_nodes = graph.numberOfNodes();
@@ -11,6 +11,9 @@
 #include <utils/SignalManager.hpp>
 #include <utils/pugs_build_info.hpp>
+#ifdef PUGS_HAS_PETSC
+#include <petsc.h>
+#endif // PUGS_HAS_PETSC
 #include <rang.hpp>
 #include <Kokkos_Core.hpp>
@@ -131,6 +134,10 @@ initialize(int& argc, char* argv[])
     SignalManager::setPauseForDebug(pause_on_error);
   }
+#ifdef PUGS_HAS_PETSC
+  PETSC_COMM_WORLD = parallel::Messenger::getInstance().comm();
+#endif // PUGS_HAS_PETSC
   PETScWrapper::initialize(argc, argv);
   SLEPcWrapper::initialize(argc, argv);
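
One ordering detail the last hunk depends on: PETSc only honours a user-supplied communicator if PETSC_COMM_WORLD is assigned before PetscInitialize runs, which is why the assignment is placed just before PETScWrapper::initialize above. A minimal sketch of that constraint, again outside pugs (the parity-based split merely stands in for the coupled-color split; all other identifiers are illustrative):

#include <mpi.h>
#include <petsc.h>

int
main(int argc, char* argv[])
{
  MPI_Init(&argc, &argv);

  // Build the sub-communicator this code should be confined to.
  int world_rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
  MPI_Comm sub_comm;
  MPI_Comm_split(MPI_COMM_WORLD, world_rank % 2, 0, &sub_comm);

  // Assign PETSC_COMM_WORLD *before* PetscInitialize; afterwards it is too
  // late and PETSc would span every process of the MPMD run.
  PETSC_COMM_WORLD = sub_comm;
  PetscInitialize(&argc, &argv, nullptr, nullptr);

  // ... solvers now operate on sub_comm only ...

  PetscFinalize();
  MPI_Comm_free(&sub_comm);
  MPI_Finalize();
  return 0;
}
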