Skip to content
Snippets Groups Projects
Commit cc9a3487 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Improve internal generic _allToAll and _exchange methods

One use variadic templates to ensure that generic array types are used
parent b27836c9
No related branches found
No related tags found
1 merge request!11Feature/mpi
...@@ -99,14 +99,15 @@ class Messenger ...@@ -99,14 +99,15 @@ class Messenger
#endif // PASTIS_HAS_MPI #endif // PASTIS_HAS_MPI
} }
template <template <typename ...SendT> typename SendArrayType,
template <typename SendArrayType, template <typename ...RecvT> typename RecvArrayType,
typename RecvArrayType> typename ...SendT, typename ...RecvT>
RecvArrayType _allToAll(const SendArrayType& sent_array, RecvArrayType& recv_array) const RecvArrayType<RecvT...> _allToAll(const SendArrayType<SendT...>& sent_array,
RecvArrayType<RecvT...>& recv_array) const
{ {
#ifdef PASTIS_HAS_MPI #ifdef PASTIS_HAS_MPI
using SendDataType = typename SendArrayType::data_type; using SendDataType = typename SendArrayType<SendT...>::data_type;
using RecvDataType = typename RecvArrayType::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
static_assert(std::is_arithmetic_v<SendDataType>); static_assert(std::is_arithmetic_v<SendDataType>);
...@@ -128,13 +129,14 @@ class Messenger ...@@ -128,13 +129,14 @@ class Messenger
return recv_array; return recv_array;
} }
template <typename SendArrayType, template <template <typename ...SendT> typename SendArrayType,
typename RecvArrayType> template <typename ...RecvT> typename RecvArrayType,
void _exchange(const std::vector<SendArrayType>& sent_array_list, typename ...SendT, typename ...RecvT>
std::vector<RecvArrayType>& recv_array_list) const void _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list,
std::vector<RecvArrayType<RecvT...>>& recv_array_list) const
{ {
using SendDataType = typename SendArrayType::data_type; using SendDataType = typename SendArrayType<SendT...>::data_type;
using RecvDataType = typename RecvArrayType::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
static_assert(std::is_arithmetic_v<SendDataType>); static_assert(std::is_arithmetic_v<SendDataType>);
...@@ -146,7 +148,7 @@ class Messenger ...@@ -146,7 +148,7 @@ class Messenger
= Messenger::helper::mpiType<SendDataType>(); = Messenger::helper::mpiType<SendDataType>();
for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
const SendArrayType sent_array = sent_array_list[i_send]; const auto sent_array = sent_array_list[i_send];
if (sent_array.size()>0) { if (sent_array.size()>0) {
MPI_Request request; MPI_Request request;
MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request); MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request);
...@@ -155,7 +157,7 @@ class Messenger ...@@ -155,7 +157,7 @@ class Messenger
} }
for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) {
RecvArrayType recv_array = recv_array_list[i_recv]; auto recv_array = recv_array_list[i_recv];
if (recv_array.size()>0) { if (recv_array.size()>0) {
MPI_Request request; MPI_Request request;
MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request); MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment