From b27836c9165a6626c39ede98ed70dab83581bab4 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Tue, 16 Oct 2018 19:12:59 +0200 Subject: [PATCH] Rewrite array broadcast in a generic way (was only defined for int) --- src/utils/Messenger.cpp | 17 +------------ src/utils/Messenger.hpp | 53 +++++++++++++++++++++++++++++++++-------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index fec10ce85..936c81e20 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -68,25 +68,10 @@ _allGather(int& data) const } int Messenger:: -_broadcast(int& data, int root_rank) const +_broadcast_value(int& data, int root_rank) const { #ifdef PASTIS_HAS_MPI MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD); #endif // PASTIS_HAS_MPI return data; } - - -Array<int> Messenger:: -_broadcast(Array<int>& array, int root_rank) const -{ -#ifdef PASTIS_HAS_MPI - int size = array.size(); - _broadcast(size, root_rank); - if (commRank() != root_rank) { - array = Array<int>(size); - } - MPI_Bcast(&(array[0]), array.size(), MPI_INT, root_rank, MPI_COMM_WORLD); -#endif // PASTIS_HAS_MPI - return array; -} diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 2af951e75..ff2674da9 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -84,8 +84,21 @@ class Messenger Array<int> _allGather(int& data) const; - int _broadcast(int& data, int root_rank) const; - Array<int> _broadcast(Array<int>& array, int root_rank) const; + int _broadcast_value(int& data, int root_rank) const; + + template <typename ArrayType> + void _broadcast_array(ArrayType& array, int root_rank) const + { + using DataType = typename ArrayType::data_type; + static_assert(not std::is_const_v<DataType>); + +#ifdef PASTIS_HAS_MPI + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<DataType>(); + MPI_Bcast(&(array[0]), array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD); +#endif // PASTIS_HAS_MPI + } + template <typename SendArrayType, typename RecvArrayType> @@ -103,7 +116,8 @@ class Messenger const int count = sent_array.size()/m_size; - MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<SendDataType>(); MPI_Alltoall(&(sent_array[0]), count, mpi_datatype, &(recv_array[0]), count, mpi_datatype, @@ -128,7 +142,8 @@ class Messenger #ifdef PASTIS_HAS_MPI std::vector<MPI_Request> request_list; - MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<SendDataType>(); for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { const SendArrayType sent_array = sent_array_list[i_send]; @@ -245,7 +260,7 @@ class Messenger { if constexpr(std::is_same<DataType, int>()) { int int_data = data; - return _broadcast(int_data, root_rank); + return _broadcast_value(int_data, root_rank); } else { static_assert(std::is_same<DataType, int>(), "unexpected type of data"); } @@ -255,12 +270,30 @@ class Messenger PASTIS_INLINE Array<DataType> broadcast(const Array<DataType>& array, int root_rank) const { - if constexpr(std::is_same<DataType, int>()) { - Array<int> int_array = array; - return _broadcast(int_array, root_rank); - } else { - static_assert(std::is_same<DataType, int>(), "unexpected type of data"); + static_assert(not std::is_const_v<DataType>, + "cannot broadcast array of const"); + if constexpr(std::is_arithmetic_v<DataType>) { + int size = array.size(); + _broadcast_value(size, root_rank); + if (m_rank != root_rank) { + array = Array<DataType>(size); + } + _broadcast_array(array, root_rank); + } else if constexpr(std::is_trivial_v<DataType>) { + int size = array.size(); + _broadcast_value(size, root_rank); + if (m_rank != root_rank) { + array = Array<DataType>(size); + } + + using CastType = helper::split_cast_t<DataType>; + auto cast_array = cast_array_to<CastType>::from(array); + _broadcast_array(cast_array, root_rank); + } else{ + static_assert(std::is_trivial_v<DataType>, + "unexpected non trivial type of data"); } + return array; } template <typename SendDataType, -- GitLab