From 51eb06379ecdba3b7d976ed7af93298be0589778 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 17 Jun 2020 21:29:44 +0200 Subject: [PATCH] Add parallel::allReduceAnd and parallel::allReduceOr parallel::allReduceMax, parallel::allReduceMin, parallel::allReduceSum can no longer be called on `bool` data, this is checked at compile time. These functions would fail at runtime (according to MPI standard) --- src/utils/Messenger.hpp | 53 ++++++++++++++++++++++++++++++++++++ tests/mpi_test_Messenger.cpp | 12 ++++++++ 2 files changed, 65 insertions(+) diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 31272a89e..437ea62a0 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -300,6 +300,7 @@ class Messenger { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); + static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); @@ -313,12 +314,49 @@ class Messenger #endif // PUGS_HAS_MPI } + template <typename DataType> + DataType + allReduceAnd(const DataType& data) const + { + static_assert(std::is_same_v<DataType, bool>); + +#ifdef PUGS_HAS_MPI + MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); + + DataType max_data = data; + MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, MPI_COMM_WORLD); + + return max_data; +#else // PUGS_HAS_MPI + return data; +#endif // PUGS_HAS_MPI + } + + template <typename DataType> + DataType + allReduceOr(const DataType& data) const + { + static_assert(std::is_same_v<DataType, bool>); + +#ifdef PUGS_HAS_MPI + MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); + + DataType max_data = data; + MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, MPI_COMM_WORLD); + + return max_data; +#else // PUGS_HAS_MPI + return data; +#endif // PUGS_HAS_MPI + } + template <typename DataType> DataType allReduceMax(const DataType& data) const { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); + static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); @@ -338,6 +376,7 @@ class Messenger { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); + static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI if constexpr (std::is_arithmetic_v<DataType>) { @@ -551,6 +590,20 @@ barrier() messenger().barrier(); } +template <typename DataType> +PUGS_INLINE DataType +allReduceAnd(const DataType& data) +{ + return messenger().allReduceAnd(data); +} + +template <typename DataType> +PUGS_INLINE DataType +allReduceOr(const DataType& data) +{ + return messenger().allReduceOr(data); +} + template <typename DataType> PUGS_INLINE DataType allReduceMax(const DataType& data) diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp index beea7430b..19d89851c 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/mpi_test_Messenger.cpp @@ -130,6 +130,18 @@ TEST_CASE("Messenger", "[mpi]") const int max_value = parallel::allReduceMax(parallel::rank() + 3); REQUIRE(max_value == static_cast<int>((parallel::size() - 1) + 3)); + + const bool and_value = parallel::allReduceAnd(true); + REQUIRE(and_value == true); + + const bool and_value_2 = parallel::allReduceAnd(parallel::rank() > 0); + REQUIRE(and_value_2 == false); + + const bool or_value = parallel::allReduceOr(false); + REQUIRE(or_value == false); + + const bool or_value_2 = parallel::allReduceOr(parallel::rank() > 0); + REQUIRE(or_value_2 == (parallel::size() > 1)); } SECTION("all to all") -- GitLab