diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 31272a89e4779020d3018b0b745c8f151adfec6d..437ea62a0d734274900c8af1b474749813e1aeaa 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -300,6 +300,7 @@ class Messenger { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); + static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); @@ -313,12 +314,49 @@ class Messenger #endif // PUGS_HAS_MPI } + template <typename DataType> + DataType + allReduceAnd(const DataType& data) const + { + static_assert(std::is_same_v<DataType, bool>); + +#ifdef PUGS_HAS_MPI + MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); + + DataType max_data = data; + MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, MPI_COMM_WORLD); + + return max_data; +#else // PUGS_HAS_MPI + return data; +#endif // PUGS_HAS_MPI + } + + template <typename DataType> + DataType + allReduceOr(const DataType& data) const + { + static_assert(std::is_same_v<DataType, bool>); + +#ifdef PUGS_HAS_MPI + MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); + + DataType max_data = data; + MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, MPI_COMM_WORLD); + + return max_data; +#else // PUGS_HAS_MPI + return data; +#endif // PUGS_HAS_MPI + } + template <typename DataType> DataType allReduceMax(const DataType& data) const { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); + static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); @@ -338,6 +376,7 @@ class Messenger { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); + static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI if constexpr (std::is_arithmetic_v<DataType>) { @@ -551,6 +590,20 @@ barrier() messenger().barrier(); } +template <typename DataType> +PUGS_INLINE DataType +allReduceAnd(const DataType& data) +{ + return messenger().allReduceAnd(data); +} + +template <typename DataType> +PUGS_INLINE DataType +allReduceOr(const DataType& data) +{ + return messenger().allReduceOr(data); +} + template <typename DataType> PUGS_INLINE DataType allReduceMax(const DataType& data) diff --git a/tests/mpi_test_Messenger.cpp b/tests/mpi_test_Messenger.cpp index beea7430bc83b17b11dc6ab2375f70d5a13e785a..19d89851c82020c0f10e4eaeac919595059e8fab 100644 --- a/tests/mpi_test_Messenger.cpp +++ b/tests/mpi_test_Messenger.cpp @@ -130,6 +130,18 @@ TEST_CASE("Messenger", "[mpi]") const int max_value = parallel::allReduceMax(parallel::rank() + 3); REQUIRE(max_value == static_cast<int>((parallel::size() - 1) + 3)); + + const bool and_value = parallel::allReduceAnd(true); + REQUIRE(and_value == true); + + const bool and_value_2 = parallel::allReduceAnd(parallel::rank() > 0); + REQUIRE(and_value_2 == false); + + const bool or_value = parallel::allReduceOr(false); + REQUIRE(or_value == false); + + const bool or_value_2 = parallel::allReduceOr(parallel::rank() > 0); + REQUIRE(or_value_2 == (parallel::size() > 1)); } SECTION("all to all")