From 51eb06379ecdba3b7d976ed7af93298be0589778 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Wed, 17 Jun 2020 21:29:44 +0200
Subject: [PATCH] Add parallel::allReduceAnd and parallel::allReduceOr

parallel::allReduceMax, parallel::allReduceMin, parallel::allReduceSum
can no longer be called on `bool` data, this is checked at compile
time.

These functions would fail at runtime (according to MPI standard)
---
 src/utils/Messenger.hpp      | 53 ++++++++++++++++++++++++++++++++++++
 tests/mpi_test_Messenger.cpp | 12 ++++++++
 2 files changed, 65 insertions(+)

diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp
index 31272a89e..437ea62a0 100644
--- a/src/utils/Messenger.hpp
+++ b/src/utils/Messenger.hpp
@@ -300,6 +300,7 @@ class Messenger
   {
     static_assert(not std::is_const_v<DataType>);
     static_assert(std::is_arithmetic_v<DataType>);
+    static_assert(not std::is_same_v<DataType, bool>);
 
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
@@ -313,12 +314,49 @@ class Messenger
 #endif   // PUGS_HAS_MPI
   }
 
+  template <typename DataType>
+  DataType
+  allReduceAnd(const DataType& data) const
+  {
+    static_assert(std::is_same_v<DataType, bool>);
+
+#ifdef PUGS_HAS_MPI
+    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
+
+    DataType max_data = data;
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, MPI_COMM_WORLD);
+
+    return max_data;
+#else    // PUGS_HAS_MPI
+    return data;
+#endif   // PUGS_HAS_MPI
+  }
+
+  template <typename DataType>
+  DataType
+  allReduceOr(const DataType& data) const
+  {
+    static_assert(std::is_same_v<DataType, bool>);
+
+#ifdef PUGS_HAS_MPI
+    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
+
+    DataType max_data = data;
+    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, MPI_COMM_WORLD);
+
+    return max_data;
+#else    // PUGS_HAS_MPI
+    return data;
+#endif   // PUGS_HAS_MPI
+  }
+
   template <typename DataType>
   DataType
   allReduceMax(const DataType& data) const
   {
     static_assert(not std::is_const_v<DataType>);
     static_assert(std::is_arithmetic_v<DataType>);
+    static_assert(not std::is_same_v<DataType, bool>);
 
 #ifdef PUGS_HAS_MPI
     MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();
@@ -338,6 +376,7 @@ class Messenger
   {
     static_assert(not std::is_const_v<DataType>);
     static_assert(std::is_arithmetic_v<DataType>);
+    static_assert(not std::is_same_v<DataType, bool>);
 
 #ifdef PUGS_HAS_MPI
     if constexpr (std::is_arithmetic_v<DataType>) {
@@ -551,6 +590,20 @@ barrier()
   messenger().barrier();
 }
 
+template <typename DataType>
+PUGS_INLINE DataType
+allReduceAnd(const DataType& data)
+{
+  return messenger().allReduceAnd(data);
+}
+
+template <typename DataType>
+PUGS_INLINE DataType
+allReduceOr(const DataType& data)
+{
+  return messenger().allReduceOr(data);
+}
+
 template <typename DataType>
 PUGS_INLINE DataType
 allReduceMax(const DataType& data)
diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp
index beea7430b..19d89851c 100644
--- a/tests/mpi_test_Messenger.cpp
+++ b/tests/mpi_test_Messenger.cpp
@@ -130,6 +130,18 @@ TEST_CASE("Messenger", "[mpi]")
 
     const int max_value = parallel::allReduceMax(parallel::rank() + 3);
     REQUIRE(max_value == static_cast<int>((parallel::size() - 1) + 3));
+
+    const bool and_value = parallel::allReduceAnd(true);
+    REQUIRE(and_value == true);
+
+    const bool and_value_2 = parallel::allReduceAnd(parallel::rank() > 0);
+    REQUIRE(and_value_2 == false);
+
+    const bool or_value = parallel::allReduceOr(false);
+    REQUIRE(or_value == false);
+
+    const bool or_value_2 = parallel::allReduceOr(parallel::rank() > 0);
+    REQUIRE(or_value_2 == (parallel::size() > 1));
   }
 
   SECTION("all to all")
-- 
GitLab