diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index d631fcf3c1f612d154423cc9595248c153987059..44d353e2b97857aff48b0d27a2907dc1d636c8f7 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -51,7 +51,6 @@ Messenger:: #endif // PASTIS_HAS_MPI } - void Messenger::barrier() const { #ifdef PASTIS_HAS_MPI @@ -59,11 +58,19 @@ void Messenger::barrier() const #endif // PASTIS_HAS_MPI } +int Messenger:: +_broadcast(int& data, int root_rank) const +{ + MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD); + return data; +} + + Array<int> Messenger:: _broadcast(Array<int>& array, int root_rank) const { int size = array.size(); - MPI_Bcast(&size, 1, MPI_INT, root_rank, MPI_COMM_WORLD); + _broadcast(size, root_rank); if (commRank() != root_rank) { array = Array<int>(size); } diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index a0bd59328cbfa9f3486972072fc615eb9f5fd71c..517be0d1c311fd11475fe0f1879aea7af753d8e6 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -15,6 +15,7 @@ class Messenger int m_rank{0}; int m_size{1}; + int _broadcast(int& data, int root_rank) const; Array<int> _broadcast(Array<int>& array, int root_rank) const; public: @@ -42,6 +43,18 @@ class Messenger void barrier() const; + template <typename DataType> + PASTIS_INLINE + DataType broadcast(const DataType& data, int root_rank) const + { + if constexpr(std::is_same<DataType, int>()) { + int int_data = data; + return _broadcast(int_data, root_rank); + } else { + static_assert(std::is_same<DataType, int>(), "unexpected type of data"); + } + } + template <typename DataType> PASTIS_INLINE Array<DataType> broadcast(const Array<DataType>& array, int root_rank) const @@ -91,4 +104,11 @@ Array<DataType> broadcast(const Array<DataType>& array, int root_rank) return messenger().broadcast(array, root_rank); } +template <typename DataType> +PASTIS_INLINE +DataType broadcast(const DataType& data, int root_rank) +{ + return messenger().broadcast(data, root_rank); +} + #endif // MESSENGER_HPP