From 67f8366526ad113eb5558d591280c1e093fcb558 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 17 Oct 2018 19:05:39 +0200 Subject: [PATCH] Write broadcast of any trivial type values Remind that TinyVector, TinyMatrix, CellType,... are trivial types --- src/utils/Messenger.cpp | 9 --------- src/utils/Messenger.hpp | 41 +++++++++++++++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index 936c81e20..f627f016b 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -66,12 +66,3 @@ _allGather(int& data) const return gather; } - -int Messenger:: -_broadcast_value(int& data, int root_rank) const -{ -#ifdef PASTIS_HAS_MPI - MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD); -#endif // PASTIS_HAS_MPI - return data; -} diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 0c56c2631..c316a4887 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -84,13 +84,26 @@ class Messenger Array<int> _allGather(int& data) const; - int _broadcast_value(int& data, int root_rank) const; + template <typename DataType> + void _broadcast_value(DataType& data, int root_rank) const + { +#ifdef PASTIS_HAS_MPI + static_assert(not std::is_const_v<DataType>); + static_assert(std::is_arithmetic_v<DataType>); + + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<DataType>(); + + MPI_Bcast(&data, 1, mpi_datatype, root_rank, MPI_COMM_WORLD); +#endif // PASTIS_HAS_MPI + } template <typename ArrayType> void _broadcast_array(ArrayType& array, int root_rank) const { using DataType = typename ArrayType::data_type; static_assert(not std::is_const_v<DataType>); + static_assert(std::is_arithmetic_v<DataType>); #ifdef PASTIS_HAS_MPI MPI_Datatype mpi_datatype @@ -260,10 +273,26 @@ class Messenger PASTIS_INLINE void broadcast(DataType& data, int root_rank) const { - if constexpr(std::is_same<DataType, int>()) { - return _broadcast_value(data, root_rank); + static_assert(not std::is_const_v<DataType>, + "cannot broadcast const data"); + if constexpr(std::is_arithmetic_v<DataType>) { + _broadcast_value(data, root_rank); + } else if constexpr(std::is_trivial_v<DataType>) { + using CastType = helper::split_cast_t<DataType>; + if constexpr(sizeof(CastType) == sizeof(DataType)) { + CastType& cast_data = reinterpret_cast<CastType&>(data); + _broadcast_value(cast_data, root_rank); + } else { +#ifdef PASTIS_HAS_MPI + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<CastType>(); + MPI_Bcast(reinterpret_cast<CastType*>(&data), sizeof(DataType)/sizeof(CastType), + mpi_datatype, root_rank, MPI_COMM_WORLD); +#endif // PASTIS_HAS_MPI + } } else { - static_assert(std::is_same<DataType, int>(), "unexpected type of data"); + static_assert(std::is_trivial_v<DataType>, + "unexpected non trivial type of data"); } } @@ -351,9 +380,9 @@ void barrier() template <typename DataType> PASTIS_INLINE -DataType broadcast(const DataType& data, int root_rank) +void broadcast(DataType& data, int root_rank) { - return messenger().broadcast(data, root_rank); + messenger().broadcast(data, root_rank); } template <typename DataType> -- GitLab