diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index ff2674da949d3ff92696309ad119d8aeecfaf57c..bc2ed9498dfb85601113f134a92a142389f32734 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -99,14 +99,15 @@ class Messenger #endif // PASTIS_HAS_MPI } - - template <typename SendArrayType, - typename RecvArrayType> - RecvArrayType _allToAll(const SendArrayType& sent_array, RecvArrayType& recv_array) const + template <template <typename ...SendT> typename SendArrayType, + template <typename ...RecvT> typename RecvArrayType, + typename ...SendT, typename ...RecvT> + RecvArrayType<RecvT...> _allToAll(const SendArrayType<SendT...>& sent_array, + RecvArrayType<RecvT...>& recv_array) const { #ifdef PASTIS_HAS_MPI - using SendDataType = typename SendArrayType::data_type; - using RecvDataType = typename RecvArrayType::data_type; + using SendDataType = typename SendArrayType<SendT...>::data_type; + using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); @@ -128,13 +129,14 @@ class Messenger return recv_array; } - template <typename SendArrayType, - typename RecvArrayType> - void _exchange(const std::vector<SendArrayType>& sent_array_list, - std::vector<RecvArrayType>& recv_array_list) const + template <template <typename ...SendT> typename SendArrayType, + template <typename ...RecvT> typename RecvArrayType, + typename ...SendT, typename ...RecvT> + void _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list, + std::vector<RecvArrayType<RecvT...>>& recv_array_list) const { - using SendDataType = typename SendArrayType::data_type; - using RecvDataType = typename RecvArrayType::data_type; + using SendDataType = typename SendArrayType<SendT...>::data_type; + using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); @@ -146,7 +148,7 @@ class Messenger = Messenger::helper::mpiType<SendDataType>(); for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { - const SendArrayType sent_array = sent_array_list[i_send]; + const auto sent_array = sent_array_list[i_send]; if (sent_array.size()>0) { MPI_Request request; MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request); @@ -155,7 +157,7 @@ class Messenger } for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { - RecvArrayType recv_array = recv_array_list[i_recv]; + auto recv_array = recv_array_list[i_recv]; if (recv_array.size()>0) { MPI_Request request; MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);