diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index a605d3e828e138eb0c9f279f762d013a84c1ab81..ee6cfb5de3355fdaf6f33b92e94d52d2346eec56 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -113,3 +113,49 @@ _broadcast(Array<int>& array, int root_rank) const #endif // PASTIS_HAS_MPI return array; } + + +void Messenger:: +_exchange(const std::vector<Array<const int>>& sent_array_list, + std::vector<Array<int>>& recv_array_list) const +{ +#ifdef PASTIS_HAS_MPI + std::vector<MPI_Request> request_list; + + for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { + const Array<const int> sent_array = sent_array_list[i_send]; + if (sent_array.size()>0) { + MPI_Request request; + MPI_Isend(&(sent_array[0]), sent_array.size(), MPI_INT, i_send, 0, MPI_COMM_WORLD, &request); + request_list.push_back(request); + } + } + + for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { + Array<int> recv_array = recv_array_list[i_recv]; + if (recv_array.size()>0) { + MPI_Request request; + MPI_Irecv(&(recv_array[0]), recv_array.size(), MPI_INT, i_recv, 0, MPI_COMM_WORLD, &request); + request_list.push_back(request); + } + } + + std::vector<MPI_Status> status_list(request_list.size()); + if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) { + std::cerr << "Communication error!\n"; + std::exit(1); + } + +#else // PASTIS_HAS_MPI + std::cerr << "NIY\n"; + std::exit(1); +#endif // PASTIS_HAS_MPI +} + +void Messenger:: +_exchange(const std::vector<Array<int>>& sent_array_list, + std::vector<Array<int>>& recv_array_list) const +{ + std::cerr << "NIY\n"; + std::exit(1); +} diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 90c4248af77b917b86ede50973a1ea7900453be2..796b08894ebdfc798cd5f0e70b2ab0bdf730fae9 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -21,6 +21,10 @@ class Messenger Array<int> _broadcast(Array<int>& array, int root_rank) const; Array<int> _allToAll(const Array<int>& sent_array, Array<int>& recv_array) const; + void _exchange(const std::vector<Array<const int>>& sent_array_list, + std::vector<Array<int>>& recv_array_list) const; + void _exchange(const std::vector<Array<int>>& sent_array_list, + std::vector<Array<int>>& recv_array_list) const; public: static void create(int& argc, char* argv[]); static void destroy(); @@ -95,6 +99,24 @@ class Messenger } } + template <typename SendDataType, + typename RecvDataType> + PASTIS_INLINE + void exchange(const std::vector<Array<SendDataType>>& sent_array_list, + std::vector<Array<RecvDataType>>& recv_array_list) const + { + static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>, + "send and receive data type do not match"); + static_assert(not std::is_const_v<RecvDataType>, + "receive data type cannot be const"); + + if constexpr(std::is_same<RecvDataType, int>()) { + _exchange(sent_array_list, recv_array_list); + } else { + static_assert(std::is_same<RecvDataType, int>(), "unexpected type of data"); + } + } + Messenger(const Messenger&) = delete; ~Messenger(); }; @@ -153,4 +175,18 @@ Array<DataType> broadcast(const Array<DataType>& array, int root_rank) return messenger().broadcast(array, root_rank); } +template <typename SendDataType, + typename RecvDataType> +PASTIS_INLINE +void exchange(const std::vector<Array<SendDataType>>& sent_array_list, + std::vector<Array<RecvDataType>>& recv_array_list) +{ + static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>, + "send and receive data type do not match"); + static_assert(not std::is_const_v<RecvDataType>, + "receive data type cannot be const"); + + messenger().exchange(sent_array_list, recv_array_list); +} + #endif // MESSENGER_HPP