#include <Messenger.hpp>
#include <PastisOStream.hpp>
#include <pastis_config.hpp>

#include <cstdlib>
#include <iostream>
#include <type_traits>
#include <vector>

#ifdef PASTIS_HAS_MPI
#include <mpi.h>
#endif // PASTIS_HAS_MPI

// Storage for the unique instance, managed by Messenger::create() and
// Messenger::destroy().
Messenger* Messenger::m_instance = nullptr;

/// Creates the unique Messenger instance.
/// Calling it while an instance already exists is a fatal error: a message is
/// written to std::cerr and the program exits with status 1.
/// @param argc argument count, forwarded to the constructor
/// @param argv argument vector, forwarded to the constructor
void Messenger::create(int& argc, char* argv[])
{
  if (Messenger::m_instance == nullptr) {
    Messenger::m_instance = new Messenger(argc, argv);
  } else {
    std::cerr << "Messenger already created\n";
    std::exit(1);
  }
}

/// Destroys the unique Messenger instance, if any.
/// Multiple calls are allowed (no-op once destroyed) to handle unexpected
/// code exit paths.
void Messenger::destroy()
{
  if (Messenger::m_instance != nullptr) {
    delete Messenger::m_instance;
    Messenger::m_instance = nullptr;
  }
}

/// With MPI enabled: calls MPI_Init, stores this process's rank and the
/// communicator size, and silences pout/perr on every rank but 0.
Messenger::
Messenger(int& argc, char* argv[])
{
#ifdef PASTIS_HAS_MPI
  MPI_Init(&argc, &argv);

  MPI_Comm_rank(MPI_COMM_WORLD, &m_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &m_size);

  // Only rank 0 writes to the standard project streams.
  if (m_rank != 0) {
    pout.setOutput(null_stream);
    perr.setOutput(null_stream);
  }
#endif // PASTIS_HAS_MPI
}

/// With MPI enabled: calls MPI_Finalize.
Messenger::
~Messenger()
{
#ifdef PASTIS_HAS_MPI
  MPI_Finalize();
#endif // PASTIS_HAS_MPI
}

/// Blocks until every rank of MPI_COMM_WORLD reaches this call
/// (no-op without MPI).
void Messenger::barrier() const
{
#ifdef PASTIS_HAS_MPI
  MPI_Barrier(MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
}

/// Gathers one int from every rank.
/// @param data this rank's contribution
/// @return an array of m_size values; entry i holds rank i's contribution
Array<int> Messenger::
_allGather(int& data) const
{
  Array<int> gather(m_size);

#ifdef PASTIS_HAS_MPI
  MPI_Allgather(&data, 1, MPI_INT, &(gather[0]), 1, MPI_INT, MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
  gather[0] = data;
#endif // PASTIS_HAS_MPI

  return gather;
}

/// All-to-all exchange of one int per rank pair: sent_array[i] goes to rank i,
/// and recv_array[i] receives the value sent by rank i.
/// @param sent_array values to send (size m_size with MPI)
/// @param recv_array receive buffer (size m_size with MPI); filled in place
/// @return recv_array
Array<int> Messenger::
_allToAll(const Array<int>& sent_array, Array<int>& recv_array) const
{
#ifdef PASTIS_HAS_MPI
  Assert(sent_array.size() == m_size);
  Assert(recv_array.size() == m_size);

  MPI_Alltoall(&(sent_array[0]), 1, MPI_INT,
               &(recv_array[0]), 1, MPI_INT,
               MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
  recv_array = copy(sent_array);
#endif // PASTIS_HAS_MPI
  return recv_array;
}

/// Broadcasts one int from root_rank to every rank.
/// @param data value to send on root_rank; overwritten on the other ranks
/// @param root_rank rank owning the value
/// @return the broadcast value
int Messenger::
_broadcast(int& data, int root_rank) const
{
#ifdef PASTIS_HAS_MPI
  MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
  return data;
}

/// Broadcasts an int array from root_rank to every rank.
/// Non-root ranks first receive the size and reallocate `array` accordingly.
/// @param array data to send on root_rank; replaced on the other ranks
/// @param root_rank rank owning the data
/// @return array
Array<int> Messenger::
_broadcast(Array<int>& array, int root_rank) const
{
#ifdef PASTIS_HAS_MPI
  int size = array.size();
  _broadcast(size, root_rank);
  if (commRank() != root_rank) {
    array = Array<int>(size);
  }
  // `size` is identical on every rank, so either all ranks take part in the
  // collective or none do; skipping when empty avoids addressing element 0
  // of an empty array.
  if (size > 0) {
    MPI_Bcast(&(array[0]), size, MPI_INT, root_rank, MPI_COMM_WORLD);
  }
#endif // PASTIS_HAS_MPI
  return array;
}

/// Point-to-point exchange of arrays between ranks using non-blocking MPI.
/// sent_array_list[i] is sent to rank i, and recv_array_list[i] is filled
/// with data from rank i. Empty arrays are skipped. Each receive buffer must
/// already have the size of the incoming message: its size() is used as the
/// receive count. Returns once all transfers have completed. A communication
/// failure, or any use without MPI (not implemented), exits with status 1.
template <typename DataType>
void Messenger::
_exchange(const std::vector<Array<DataType>>& sent_array_list,
          std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
{
  using RecvDataType = std::remove_const_t<DataType>;
  static_assert(std::is_same_v<int, RecvDataType> || std::is_same_v<CellType, RecvDataType>,
                "_exchange only supports int and CellType data");

#ifdef PASTIS_HAS_MPI
  // Every instantiation yields a valid datatype (no fall-through path).
  const MPI_Datatype type = [] () -> MPI_Datatype {
    if constexpr (std::is_same_v<int, RecvDataType>) {
      return MPI_INT;
    } else {
      // CellType values are transferred as MPI_SHORT: sizes must agree.
      static_assert(sizeof(CellType) == sizeof(short),
                    "CellType data is transferred as MPI_SHORT");
      return MPI_SHORT;
    }
  } ();

  std::vector<MPI_Request> request_list;
  request_list.reserve(sent_array_list.size() + recv_array_list.size());

  for (size_t i_send = 0; i_send < sent_array_list.size(); ++i_send) {
    const Array<DataType>& sent_array = sent_array_list[i_send];
    if (sent_array.size() > 0) {
      MPI_Request request;
      MPI_Isend(&(sent_array[0]), static_cast<int>(sent_array.size()), type,
                static_cast<int>(i_send), 0, MPI_COMM_WORLD, &request);
      request_list.push_back(request);
    }
  }

  for (size_t i_recv = 0; i_recv < recv_array_list.size(); ++i_recv) {
    // Bound by reference so data lands directly in the caller's array.
    Array<RecvDataType>& recv_array = recv_array_list[i_recv];
    if (recv_array.size() > 0) {
      MPI_Request request;
      MPI_Irecv(&(recv_array[0]), static_cast<int>(recv_array.size()), type,
                static_cast<int>(i_recv), 0, MPI_COMM_WORLD, &request);
      request_list.push_back(request);
    }
  }

  // request_list may be empty; data() is valid there, whereas &v[0] is UB.
  if (MPI_SUCCESS != MPI_Waitall(static_cast<int>(request_list.size()),
                                 request_list.data(), MPI_STATUSES_IGNORE)) {
    std::cerr << "Communication error!\n";
    std::exit(1);
  }

#else // PASTIS_HAS_MPI
  std::cerr << "NIY\n";
  std::exit(1);
#endif // PASTIS_HAS_MPI
}

template void Messenger::_exchange(const std::vector<Array<int>>& sent_array_list,
                                   std::vector<Array<int>>& recv_array_list) const;

template void Messenger::_exchange(const std::vector<Array<const int>>& sent_array_list,
                                   std::vector<Array<int>>& recv_array_list) const;

template void Messenger::_exchange(const std::vector<Array<CellType>>& sent_array_list,
                                   std::vector<Array<CellType>>& recv_array_list) const;

template void Messenger::_exchange(const std::vector<Array<const CellType>>& sent_array_list,
                                   std::vector<Array<CellType>>& recv_array_list) const;