diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index f627f016bb21995f89857b12f7f81379071f48e9..6c9a940e9a840e4ca4f16b574c215537e5a0a8f6 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -51,18 +51,3 @@ void Messenger::barrier() const MPI_Barrier(MPI_COMM_WORLD); #endif // PASTIS_HAS_MPI } - -Array<int> Messenger:: -_allGather(int& data) const -{ - Array<int> gather(m_size); - -#ifdef PASTIS_HAS_MPI - MPI_Allgather(&data, 1, MPI_INT, &(gather[0]), 1, MPI_INT, - MPI_COMM_WORLD); -#else // PASTIS_HAS_MPI - gather[0] = data; -#endif // PASTIS_HAS_MPI - - return gather; -} diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index c316a4887ef281a6a614068558864de7ff4adc3f..878e67ed90a942e05cecde900748b323d98aae0f 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -82,15 +82,60 @@ class Messenger int m_rank{0}; int m_size{1}; - Array<int> _allGather(int& data) const; + template <typename DataType> + void _allGather(const DataType& data, + Array<DataType> gather) const + { + static_assert(std::is_arithmetic_v<DataType>); + Assert(gather.size() == m_size); + +#ifdef PASTIS_HAS_MPI + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<DataType>(); + + MPI_Allgather(&data, 1, mpi_datatype, + &(gather[0]), 1, mpi_datatype, + MPI_COMM_WORLD); +#else // PASTIS_HAS_MPI + gather[0] = data; +#endif // PASTIS_HAS_MPI + } + + + + template <template <typename ...SendT> typename SendArrayType, + template <typename ...RecvT> typename RecvArrayType, + typename ...SendT, typename ...RecvT> + void _allGather(const SendArrayType<SendT...>& data_array, + RecvArrayType<RecvT...> gather_array) const + { + Assert(gather_array.size() == data_array.size()*m_size); + + using SendDataType = typename SendArrayType<SendT...>::data_type; + using RecvDataType = typename RecvArrayType<RecvT...>::data_type; + + static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); + static_assert(std::is_arithmetic_v<SendDataType>); + +#ifdef PASTIS_HAS_MPI + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<RecvDataType>(); + + MPI_Allgather(&(data_array[0]), data_array.size(), mpi_datatype, + &(gather_array[0]), data_array.size(), mpi_datatype, + MPI_COMM_WORLD); +#else // PASTIS_HAS_MPI + static_assert(false, "NIY"); +#endif // PASTIS_HAS_MPI + } template <typename DataType> void _broadcast_value(DataType& data, int root_rank) const { -#ifdef PASTIS_HAS_MPI static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); +#ifdef PASTIS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); @@ -236,14 +281,50 @@ class Messenger template <typename DataType> PASTIS_INLINE - Array<DataType> allGather(const DataType& data) const + Array<DataType> + allGather(const DataType& data) const { - if constexpr(std::is_same<DataType, int>()) { - int int_data = data; - return _allGather(int_data); + static_assert(not std::is_const_v<DataType>); + + Array<DataType> gather_array(m_size); + + if constexpr(std::is_arithmetic_v<DataType>) { + _allGather(data, gather_array); + } else if constexpr(std::is_trivial_v<DataType>) { + using CastType = helper::split_cast_t<DataType>; + + CastArray cast_value_array = cast_value_to<const CastType>::from(data); + CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array); + + _allGather(cast_value_array, cast_gather_array); } else { - static_assert(std::is_same<DataType, int>(), "unexpected type of data"); + static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); } + return gather_array; + } + + template <typename DataType> + PASTIS_INLINE + Array<std::remove_const_t<DataType>> + allGather(const Array<DataType>& array) const + { + using MutableDataType = std::remove_const_t<DataType>; + Array<MutableDataType> gather_array(m_size*array.size()); + + if constexpr(std::is_arithmetic_v<DataType>) { + _allGather(array, gather_array); + } else if constexpr(std::is_trivial_v<DataType>) { + using CastType = helper::split_cast_t<DataType>; + using MutableCastType = helper::split_cast_t<MutableDataType>; + + CastArray cast_array = cast_array_to<CastType>::from(array); + CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array); + + _allGather(cast_array, cast_gather_array); + } else { + static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); + } + return gather_array; } template <typename SendDataType> @@ -264,7 +345,7 @@ class Messenger auto recv_cast_array = cast_array_to<CastType>::from(recv_array); _allToAll(send_cast_array, recv_cast_array); } else { - static_assert(std::is_same<DataType, int>(), "unexpected type of data"); + static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); } return recv_array; } @@ -387,11 +468,20 @@ void broadcast(DataType& data, int root_rank) template <typename DataType> PASTIS_INLINE -Array<DataType> allGather(const DataType& data) +Array<DataType> +allGather(const DataType& data) { return messenger().allGather(data); } +template <typename DataType> +PASTIS_INLINE +Array<std::remove_const_t<DataType>> +allGather(const Array<DataType>& array) +{ + return messenger().allGather(array); +} + template <typename DataType> PASTIS_INLINE Array<std::remove_const_t<DataType>>