From 55a62ac0e846a4d8944b365aea68e4c92b683627 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Tue, 2 Oct 2018 11:40:10 +0200
Subject: [PATCH] Add allGather function for single value (ie. not array)

---
 src/utils/Messenger.cpp | 21 +++++++++++++++++++++
 src/utils/Messenger.hpp | 29 +++++++++++++++++++++++++----
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp
index 44d353e2b..5817a2a20 100644
--- a/src/utils/Messenger.cpp
+++ b/src/utils/Messenger.cpp
@@ -58,10 +58,29 @@ void Messenger::barrier() const
 #endif // PASTIS_HAS_MPI
 }
 
+Array<int> Messenger::
+_allGather(int& data) const
+{
+  Array<int> gather(m_size);
+
+#ifdef PASTIS_HAS_MPI
+  MPI_Allgather(&data, 1,  MPI_INT, &(gather[0]), 1,  MPI_INT,
+                MPI_COMM_WORLD);
+#else // PASTIS_HAS_MPI
+  gather[0] = data;
+#endif // PASTIS_HAS_MPI
+
+  return gather;
+}
+
+
+
 int Messenger::
 _broadcast(int& data, int root_rank) const
 {
+#ifdef PASTIS_HAS_MPI
   MPI_Bcast(&data, 1,  MPI_INT, root_rank, MPI_COMM_WORLD);
+#endif // PASTIS_HAS_MPI
   return data;
 }
 
@@ -69,11 +88,13 @@ _broadcast(int& data, int root_rank) const
 Array<int> Messenger::
 _broadcast(Array<int>& array, int root_rank) const
 {
+#ifdef PASTIS_HAS_MPI
   int size = array.size();
   _broadcast(size, root_rank);
   if (commRank() != root_rank) {
     array = Array<int>(size);
   }
   MPI_Bcast(&(array[0]), array.size(), MPI_INT, root_rank, MPI_COMM_WORLD);
+#endif // PASTIS_HAS_MPI
   return array;
 }
diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp
index 517be0d1c..6728b057f 100644
--- a/src/utils/Messenger.hpp
+++ b/src/utils/Messenger.hpp
@@ -15,6 +15,8 @@ class Messenger
   int m_rank{0};
   int m_size{1};
 
+  Array<int> _allGather(int& data) const;
+
   int _broadcast(int& data, int root_rank) const;
   Array<int> _broadcast(Array<int>& array, int root_rank) const;
 
@@ -43,6 +45,18 @@ class Messenger
 
   void barrier() const;
 
+  template <typename DataType>
+  PASTIS_INLINE
+  Array<DataType> allGather(const DataType& data) const
+  {
+    if constexpr(std::is_same<DataType, int>()) {
+      int int_data = data;
+      return _allGather(int_data);
+    } else {
+      static_assert(std::is_same<DataType, int>(), "unexpected type of data");
+    }
+  }
+
   template <typename DataType>
   PASTIS_INLINE
   DataType broadcast(const DataType& data, int root_rank) const
@@ -99,16 +113,23 @@ void barrier()
 
 template <typename DataType>
 PASTIS_INLINE
-Array<DataType> broadcast(const Array<DataType>& array, int root_rank)
+DataType broadcast(const DataType& data, int root_rank)
 {
-  return messenger().broadcast(array, root_rank);
+  return messenger().broadcast(data, root_rank);
 }
 
 template <typename DataType>
 PASTIS_INLINE
-DataType broadcast(const DataType& data, int root_rank)
+Array<DataType> allGather(const DataType& data)
 {
-  return messenger().broadcast(data, root_rank);
+  return messenger().allGather(data);
+}
+
+template <typename DataType>
+PASTIS_INLINE
+Array<DataType> broadcast(const Array<DataType>& array, int root_rank)
+{
+  return messenger().broadcast(array, root_rank);
 }
 
 #endif // MESSENGER_HPP
-- 
GitLab