diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index 44d353e2b97857aff48b0d27a2907dc1d636c8f7..5817a2a207a97ab19a8047801ca87fe6521ec591 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -58,10 +58,29 @@ void Messenger::barrier() const #endif // PASTIS_HAS_MPI } +Array<int> Messenger:: +_allGather(int& data) const +{ + Array<int> gather(m_size); + +#ifdef PASTIS_HAS_MPI + MPI_Allgather(&data, 1, MPI_INT, &(gather[0]), 1, MPI_INT, + MPI_COMM_WORLD); +#else // PASTIS_HAS_MPI + gather[0] = data; +#endif // PASTIS_HAS_MPI + + return gather; +} + + + int Messenger:: _broadcast(int& data, int root_rank) const { +#ifdef PASTIS_HAS_MPI MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD); +#endif // PASTIS_HAS_MPI return data; } @@ -69,11 +88,13 @@ _broadcast(int& data, int root_rank) const Array<int> Messenger:: _broadcast(Array<int>& array, int root_rank) const { +#ifdef PASTIS_HAS_MPI int size = array.size(); _broadcast(size, root_rank); if (commRank() != root_rank) { array = Array<int>(size); } MPI_Bcast(&(array[0]), array.size(), MPI_INT, root_rank, MPI_COMM_WORLD); +#endif // PASTIS_HAS_MPI return array; } diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 517be0d1c311fd11475fe0f1879aea7af753d8e6..6728b057fe0e19d74041373f539d9d017e05202e 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -15,6 +15,8 @@ class Messenger int m_rank{0}; int m_size{1}; + Array<int> _allGather(int& data) const; + int _broadcast(int& data, int root_rank) const; Array<int> _broadcast(Array<int>& array, int root_rank) const; @@ -43,6 +45,18 @@ class Messenger void barrier() const; + template <typename DataType> + PASTIS_INLINE + Array<DataType> allGather(const DataType& data) const + { + if constexpr(std::is_same<DataType, int>()) { + int int_data = data; + return _allGather(int_data); + } else { + static_assert(std::is_same<DataType, int>(), "unexpected type of data"); + } + } + template <typename DataType> PASTIS_INLINE DataType broadcast(const DataType& data, int root_rank) const @@ -99,16 +113,23 @@ void barrier() template <typename DataType> PASTIS_INLINE -Array<DataType> broadcast(const Array<DataType>& array, int root_rank) +DataType broadcast(const DataType& data, int root_rank) { - return messenger().broadcast(array, root_rank); + return messenger().broadcast(data, root_rank); } template <typename DataType> PASTIS_INLINE -DataType broadcast(const DataType& data, int root_rank) +Array<DataType> allGather(const DataType& data) { - return messenger().broadcast(data, root_rank); + return messenger().allGather(data); +} + +template <typename DataType> +PASTIS_INLINE +Array<DataType> broadcast(const Array<DataType>& array, int root_rank) +{ + return messenger().broadcast(array, root_rank); } #endif // MESSENGER_HPP