diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 09dc1f5ddef2a7be4f47e26dc295ce066a6f5b9d..32dd632bce9392c8a74353c0c7ffe749345f3154 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -87,7 +87,7 @@ class Messenger Array<DataType> gather) const { static_assert(std::is_arithmetic_v<DataType>); - Assert(gather.size() == m_size); + Assert(gather.size() == m_size); // LCOV_EXCL_LINE #ifdef PASTIS_HAS_MPI MPI_Datatype mpi_datatype @@ -109,7 +109,7 @@ class Messenger void _allGather(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array) const { - Assert(gather_array.size() == data_array.size()*m_size); + Assert(gather_array.size() == data_array.size()*m_size); // LCOV_EXCL_LINE using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; @@ -170,8 +170,8 @@ class Messenger static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); - Assert((sent_array.size() % m_size) == 0); - Assert(sent_array.size() == recv_array.size()); + Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE + Assert(sent_array.size() == recv_array.size()); // LCOV_EXCL_LINE const int count = sent_array.size()/m_size; @@ -261,7 +261,7 @@ class Messenger PASTIS_INLINE static Messenger& getInstance() { - Assert(m_instance != nullptr); + Assert(m_instance != nullptr); // LCOV_EXCL_LINE return *m_instance; } @@ -279,6 +279,44 @@ class Messenger void barrier() const; + template <typename DataType> + DataType allReduceMin(const DataType& data) const + { +#ifdef PASTIS_HAS_MPI + static_assert(not std::is_const_v<DataType>); + static_assert(std::is_arithmetic_v<DataType>); + + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<DataType>(); + + DataType min_data = data; + MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD); + + return min_data; +#else // PASTIS_HAS_MPI + return data; +#endif // PASTIS_HAS_MPI + } + + template <typename DataType> + DataType allReduceMax(const DataType& data) const + { +#ifdef PASTIS_HAS_MPI + static_assert(not std::is_const_v<DataType>); + static_assert(std::is_arithmetic_v<DataType>); + + MPI_Datatype mpi_datatype + = Messenger::helper::mpiType<DataType>(); + + DataType max_data = data; + MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD); + + return max_data; +#else // PASTIS_HAS_MPI + return data; +#endif // PASTIS_HAS_MPI + } + template <typename DataType> PASTIS_INLINE Array<DataType> @@ -332,10 +370,16 @@ class Messenger Array<std::remove_const_t<SendDataType>> allToAll(const Array<SendDataType>& sent_array) const { - Assert((sent_array.size() % m_size) == 0); - using DataType = std::remove_const_t<SendDataType>; +#ifndef NDEBUG + const size_t min_size = allReduceMin(sent_array.size()); + const size_t max_size = allReduceMax(sent_array.size()); + Assert(max_size == min_size); // LCOV_EXCL_LINE +#endif // NDEBUG + Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE + using DataType = std::remove_const_t<SendDataType>; Array<DataType> recv_array(sent_array.size()); + if constexpr(std::is_arithmetic_v<DataType>) { _allToAll(sent_array, recv_array); } else if constexpr(std::is_trivial_v<DataType>) { @@ -383,14 +427,14 @@ class Messenger int size = array.size(); _broadcast_value(size, root_rank); if (m_rank != root_rank) { - array = Array<DataType>(size); + array = Array<DataType>(size); // LCOV_EXCL_LINE } _broadcast_array(array, root_rank); } else if constexpr(std::is_trivial_v<DataType>) { int size = array.size(); _broadcast_value(size, root_rank); if (m_rank != root_rank) { - array = Array<DataType>(size); + array = Array<DataType>(size); // LCOV_EXCL_LINE } using CastType = helper::split_cast_t<DataType>; @@ -455,6 +499,20 @@ void barrier() return messenger().barrier(); } +template <typename DataType> +PASTIS_INLINE +DataType allReduceMax(const DataType& data) +{ + return messenger().allReduceMax(data); +} + +template <typename DataType> +PASTIS_INLINE +DataType allReduceMin(const DataType& data) +{ + return messenger().allReduceMin(data); +} + template <typename DataType> PASTIS_INLINE Array<DataType>