From 55a62ac0e846a4d8944b365aea68e4c92b683627 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Tue, 2 Oct 2018 11:40:10 +0200 Subject: [PATCH] Add allGather function for single value (ie. not array) --- src/utils/Messenger.cpp | 21 +++++++++++++++++++++ src/utils/Messenger.hpp | 29 +++++++++++++++++++++++++---- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index 44d353e2b..5817a2a20 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -58,10 +58,29 @@ void Messenger::barrier() const #endif // PASTIS_HAS_MPI } +Array<int> Messenger:: +_allGather(int& data) const +{ + Array<int> gather(m_size); + +#ifdef PASTIS_HAS_MPI + MPI_Allgather(&data, 1, MPI_INT, &(gather[0]), 1, MPI_INT, + MPI_COMM_WORLD); +#else // PASTIS_HAS_MPI + gather[0] = data; +#endif // PASTIS_HAS_MPI + + return gather; +} + + + int Messenger:: _broadcast(int& data, int root_rank) const { +#ifdef PASTIS_HAS_MPI MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD); +#endif // PASTIS_HAS_MPI return data; } @@ -69,11 +88,13 @@ _broadcast(int& data, int root_rank) const Array<int> Messenger:: _broadcast(Array<int>& array, int root_rank) const { +#ifdef PASTIS_HAS_MPI int size = array.size(); _broadcast(size, root_rank); if (commRank() != root_rank) { array = Array<int>(size); } MPI_Bcast(&(array[0]), array.size(), MPI_INT, root_rank, MPI_COMM_WORLD); +#endif // PASTIS_HAS_MPI return array; } diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 517be0d1c..6728b057f 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -15,6 +15,8 @@ class Messenger int m_rank{0}; int m_size{1}; + Array<int> _allGather(int& data) const; + int _broadcast(int& data, int root_rank) const; Array<int> _broadcast(Array<int>& array, int root_rank) const; @@ -43,6 +45,18 @@ class Messenger void barrier() const; + template <typename DataType> + PASTIS_INLINE + Array<DataType> allGather(const DataType& data) const + { + if constexpr(std::is_same<DataType, int>()) { + int int_data = data; + return _allGather(int_data); + } else { + static_assert(std::is_same<DataType, int>(), "unexpected type of data"); + } + } + template <typename DataType> PASTIS_INLINE DataType broadcast(const DataType& data, int root_rank) const @@ -99,16 +113,23 @@ void barrier() template <typename DataType> PASTIS_INLINE -Array<DataType> broadcast(const Array<DataType>& array, int root_rank) +DataType broadcast(const DataType& data, int root_rank) { - return messenger().broadcast(array, root_rank); + return messenger().broadcast(data, root_rank); } template <typename DataType> PASTIS_INLINE -DataType broadcast(const DataType& data, int root_rank) +Array<DataType> allGather(const DataType& data) { - return messenger().broadcast(data, root_rank); + return messenger().allGather(data); +} + +template <typename DataType> +PASTIS_INLINE +Array<DataType> broadcast(const Array<DataType>& array, int root_rank) +{ + return messenger().broadcast(array, root_rank); } #endif // MESSENGER_HPP -- GitLab