diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index 657bdb5958f73963eb4fdeeb5f83ffd7fde82de0..fec10ce853654bc64400371dc2398e19443c5558 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -67,23 +67,6 @@ _allGather(int& data) const return gather; } -Array<int> Messenger:: -_allToAll(const Array<int>& sent_array, Array<int>& recv_array) const -{ -#ifdef PASTIS_HAS_MPI - Assert(sent_array.size() == m_size); - Assert(recv_array.size() == m_size); - - MPI_Alltoall(&(sent_array[0]), 1, MPI_INT, - &(recv_array[0]), 1, MPI_INT, - MPI_COMM_WORLD); -#else // PASTIS_HAS_MPI - recv_array = copy(sent_array); -#endif // PASTIS_HAS_MPI - return recv_array; -} - - int Messenger:: _broadcast(int& data, int root_rank) const { diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index f1c74ffb87df83bc804cf4692d67b15c027cc585..bc49cf63f43213260c63a24e2fa48339bb68dd58 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -90,7 +90,33 @@ class Messenger int _broadcast(int& data, int root_rank) const; Array<int> _broadcast(Array<int>& array, int root_rank) const; - Array<int> _allToAll(const Array<int>& sent_array, Array<int>& recv_array) const; + + template <typename SendArrayType, + typename RecvArrayType> + RecvArrayType _allToAll(const SendArrayType& sent_array, RecvArrayType& recv_array) const + { +#ifdef PASTIS_HAS_MPI + using SendDataType = typename SendArrayType::data_type; + using RecvDataType = typename RecvArrayType::data_type; + + static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); + static_assert(std::is_arithmetic_v<SendDataType>); + + Assert((sent_array.size() % m_size) == 0); + Assert(recv_array.size() == recv_array.size()); + + const int count = sent_array.size()/m_size; + + MPI_Datatype mpi_datatype = Messenger::Helper::mpiType<SendDataType>(); + + MPI_Alltoall(&(sent_array[0]), count, mpi_datatype, + &(recv_array[0]), count, mpi_datatype, + MPI_COMM_WORLD); +#else // PASTIS_HAS_MPI + recv_array = copy(sent_array); +#endif // PASTIS_HAS_MPI + return recv_array; + } template <typename SendArrayType, typename RecvArrayType> @@ -106,13 +132,13 @@ class Messenger #ifdef PASTIS_HAS_MPI std::vector<MPI_Request> request_list; - MPI_Datatype type = Messenger::Helper::mpiType<SendDataType>(); + MPI_Datatype mpi_datatype = Messenger::Helper::mpiType<SendDataType>(); for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { const SendArrayType sent_array = sent_array_list[i_send]; if (sent_array.size()>0) { MPI_Request request; - MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request); + MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request); request_list.push_back(request); } } @@ -121,7 +147,7 @@ class Messenger RecvArrayType recv_array = recv_array_list[i_recv]; if (recv_array.size()>0) { MPI_Request request; - MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request); + MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request); request_list.push_back(request); } } @@ -194,17 +220,27 @@ class Messenger } } - template <typename DataType> + template <typename SendDataType> PASTIS_INLINE - Array<DataType> allToAll(const Array<DataType>& sent_array) const + Array<std::remove_const_t<SendDataType>> + allToAll(const Array<SendDataType>& sent_array) const { - Assert(sent_array.size() == static_cast<size_t>(m_size)); - if constexpr(std::is_same<DataType, int>()) { - Array<int> recv_array(m_size); - return _allToAll(sent_array, recv_array); + Assert((sent_array.size() % m_size) == 0); + using DataType = std::remove_const_t<SendDataType>; + + Array<DataType> recv_array(sent_array.size()); + if constexpr(std::is_arithmetic_v<DataType>) { + _allToAll(sent_array, recv_array); + } else if constexpr(std::is_trivial_v<DataType>) { + using CastType = Helper::split_cast_t<DataType>; + + auto send_cast_array = cast_array_to<const CastType>::from(sent_array); + auto recv_cast_array = cast_array_to<CastType>::from(recv_array); + _allToAll(send_cast_array, recv_cast_array); } else { static_assert(std::is_same<DataType, int>(), "unexpected type of data"); } + return recv_array; } template <typename DataType> @@ -300,7 +336,8 @@ Array<DataType> allGather(const DataType& data) template <typename DataType> PASTIS_INLINE -Array<DataType> allToAll(const Array<DataType>& array) +Array<std::remove_const_t<DataType>> +allToAll(const Array<DataType>& array) { return messenger().allToAll(array); }