#ifndef MESSENGER_HPP #define MESSENGER_HPP #include <utils/PugsAssert.hpp> #include <utils/PugsMacros.hpp> #include <utils/Array.hpp> #include <utils/CastArray.hpp> #include <type_traits> #include <vector> #include <utils/pugs_config.hpp> #ifdef PUGS_HAS_MPI #include <mpi.h> #endif // PUGS_HAS_MPI #include <utils/Exceptions.hpp> #include <utils/PugsTraits.hpp> namespace parallel { class Messenger { private: struct helper { #ifdef PUGS_HAS_MPI template <typename DataType> static PUGS_INLINE MPI_Datatype mpiType() { if constexpr (std::is_const_v<DataType>) { return mpiType<std::remove_const_t<DataType>>(); } else { static_assert(std::is_arithmetic_v<DataType>, "Unexpected arithmetic type! Should not occur!"); static_assert(is_false_v<DataType>, "MPI_Datatype are only defined for arithmetic types!"); return MPI_Datatype(); } } #endif // PUGS_HAS_MPI private: template <typename T, typename Allowed = void> struct split_cast { }; template <typename T> struct split_cast<T, std::enable_if_t<not(sizeof(T) % sizeof(int64_t))>> { using type = int64_t; static_assert(not(sizeof(T) % sizeof(int64_t))); }; template <typename T> struct split_cast<T, std::enable_if_t<not(sizeof(T) % sizeof(int32_t)) and (sizeof(T) % sizeof(int64_t))>> { using type = int32_t; static_assert(not(sizeof(T) % sizeof(int32_t))); }; template <typename T> struct split_cast<T, std::enable_if_t<not(sizeof(T) % sizeof(int16_t)) and (sizeof(T) % sizeof(int32_t)) and (sizeof(T) % sizeof(int64_t))>> { using type = int16_t; static_assert(not(sizeof(T) % sizeof(int16_t))); }; template <typename T> struct split_cast<T, std::enable_if_t<not(sizeof(T) % sizeof(int8_t)) and (sizeof(T) % sizeof(int16_t)) and (sizeof(T) % sizeof(int32_t)) and (sizeof(T) % sizeof(int64_t))>> { using type = int8_t; static_assert(not(sizeof(T) % sizeof(int8_t))); }; public: template <typename T> using split_cast_t = typename split_cast<T>::type; }; static Messenger* m_instance; Messenger(int& argc, char* argv[]); size_t m_rank{0}; size_t m_size{1}; template <typename DataType> void _gather(const DataType& data, Array<DataType> gather, size_t rank) const { static_assert(std::is_arithmetic_v<DataType>); Assert(gather.size() == m_size * (rank == m_rank)); // LCOV_EXCL_LINE #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); auto gather_address = (gather.size() > 0) ? &(gather[0]) : nullptr; MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, m_comm_world_pugs); #else // PUGS_HAS_MPI gather[0] = data; #endif // PUGS_HAS_MPI } template <template <typename... SendT> typename SendArrayType, template <typename... RecvT> typename RecvArrayType, typename... SendT, typename... RecvT> void _gather(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array, size_t rank) const { Assert(gather_array.size() == data_array.size() * m_size * (rank == m_rank)); // LCOV_EXCL_LINE using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>(); auto data_address = (data_array.size() > 0) ? &(data_array[0]) : nullptr; auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr; MPI_Gather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype, rank, m_comm_world_pugs); #else // PUGS_HAS_MPI copy_to(data_array, gather_array); #endif // PUGS_HAS_MPI } template <template <typename... SendT> typename SendArrayType, template <typename... RecvT> typename RecvArrayType, typename... SendT, typename... RecvT> void _gatherVariable(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array, Array<int> sizes_array, size_t rank) const { Assert(gather_array.size() - sum(sizes_array) * (rank == m_rank) == 0); // LCOV_EXCL_LINE using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>(); Array<int> start_positions{sizes_array.size()}; if (start_positions.size() > 0) { start_positions[0] = 0; for (size_t i = 1; i < start_positions.size(); ++i) { start_positions[i] = start_positions[i - 1] + sizes_array[i - 1]; } } auto data_address = (data_array.size() > 0) ? &(data_array[0]) : nullptr; auto sizes_address = (sizes_array.size() > 0) ? &(sizes_array[0]) : nullptr; auto positions_address = (start_positions.size() > 0) ? &(start_positions[0]) : nullptr; auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr; MPI_Gatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address, mpi_datatype, rank, m_comm_world_pugs); #else // PUGS_HAS_MPI copy_to(data_array, gather_array); #endif // PUGS_HAS_MPI } template <typename DataType> void _allGather(const DataType& data, Array<DataType> gather) const { static_assert(std::is_arithmetic_v<DataType>); Assert(gather.size() == m_size); // LCOV_EXCL_LINE #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, m_comm_world_pugs); #else // PUGS_HAS_MPI gather[0] = data; #endif // PUGS_HAS_MPI } template <template <typename... SendT> typename SendArrayType, template <typename... RecvT> typename RecvArrayType, typename... SendT, typename... RecvT> void _allGather(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array) const { Assert(gather_array.size() == data_array.size() * m_size); // LCOV_EXCL_LINE using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>(); auto data_address = (data_array.size() > 0) ? &(data_array[0]) : nullptr; auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr; MPI_Allgather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype, m_comm_world_pugs); #else // PUGS_HAS_MPI copy_to(data_array, gather_array); #endif // PUGS_HAS_MPI } template <template <typename... SendT> typename SendArrayType, template <typename... RecvT> typename RecvArrayType, typename... SendT, typename... RecvT> void _allGatherVariable(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array, Array<int> sizes_array) const { Assert(gather_array.size() == static_cast<size_t>(sum(sizes_array))); // LCOV_EXCL_LINE using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>(); Array<int> start_positions{sizes_array.size()}; if (start_positions.size() > 0) { start_positions[0] = 0; for (size_t i = 1; i < start_positions.size(); ++i) { start_positions[i] = start_positions[i - 1] + sizes_array[i - 1]; } } auto data_address = (data_array.size() > 0) ? &(data_array[0]) : nullptr; auto sizes_address = (sizes_array.size() > 0) ? &(sizes_array[0]) : nullptr; auto positions_address = (start_positions.size() > 0) ? &(start_positions[0]) : nullptr; auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr; MPI_Allgatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address, mpi_datatype, m_comm_world_pugs); #else // PUGS_HAS_MPI copy_to(data_array, gather_array); #endif // PUGS_HAS_MPI } template <typename DataType> void _broadcast_value([[maybe_unused]] DataType& data, [[maybe_unused]] size_t root_rank) const { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); MPI_Bcast(&data, 1, mpi_datatype, root_rank, m_comm_world_pugs); #endif // PUGS_HAS_MPI } template <typename ArrayType> void _broadcast_array([[maybe_unused]] ArrayType& array, [[maybe_unused]] size_t root_rank) const { using DataType = typename ArrayType::data_type; static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); #ifdef PUGS_HAS_MPI auto array_address = (array.size() > 0) ? &(array[0]) : nullptr; MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, m_comm_world_pugs); #endif // PUGS_HAS_MPI } template <template <typename... SendT> typename SendArrayType, template <typename... RecvT> typename RecvArrayType, typename... SendT, typename... RecvT> void _allToAll(const SendArrayType<SendT...>& sent_array, RecvArrayType<RecvT...>& recv_array) const { #ifdef PUGS_HAS_MPI using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE Assert(sent_array.size() == recv_array.size()); // LCOV_EXCL_LINE const size_t count = sent_array.size() / m_size; MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); auto sent_address = (sent_array.size() > 0) ? &(sent_array[0]) : nullptr; auto recv_address = (recv_array.size() > 0) ? &(recv_array[0]) : nullptr; MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, m_comm_world_pugs); #else // PUGS_HAS_MPI copy_to(sent_array, recv_array); #endif // PUGS_HAS_MPI } template <template <typename... SendT> typename SendArrayType, template <typename... RecvT> typename RecvArrayType, typename... SendT, typename... RecvT> void _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list, std::vector<RecvArrayType<RecvT...>>& recv_array_list) const { using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); #ifdef PUGS_HAS_MPI std::vector<MPI_Request> request_list; MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); for (size_t i_send = 0; i_send < sent_array_list.size(); ++i_send) { const auto sent_array = sent_array_list[i_send]; if (sent_array.size() > 0) { MPI_Request request; MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, m_comm_world_pugs, &request); request_list.push_back(request); } } for (size_t i_recv = 0; i_recv < recv_array_list.size(); ++i_recv) { auto recv_array = recv_array_list[i_recv]; if (recv_array.size() > 0) { MPI_Request request; MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, m_comm_world_pugs, &request); request_list.push_back(request); } } if (request_list.size() > 0) { std::vector<MPI_Status> status_list(request_list.size()); if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) { // LCOV_EXCL_START throw NormalError("Communication error"); // LCOV_EXCL_STOP } } #else // PUGS_HAS_MPI Assert(sent_array_list.size() == 1); Assert(recv_array_list.size() == 1); copy_to(sent_array_list[0], recv_array_list[0]); #endif // PUGS_HAS_MPI } template <typename DataType, typename CastDataType> void _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list, std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const { std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list; for (size_t i = 0; i < sent_array_list.size(); ++i) { sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array_list[i])); } using MutableDataType = std::remove_const_t<DataType>; std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list; for (size_t i = 0; i < sent_array_list.size(); ++i) { recv_cast_array_list.emplace_back(recv_array_list[i]); } _exchange(sent_cast_array_list, recv_cast_array_list); } public: #ifdef PUGS_HAS_MPI MPI_Comm m_comm_world_pugs = MPI_COMM_NULL; #endif // PUGS_HAS_MPI static void create(int& argc, char* argv[]); static void destroy(); PUGS_INLINE static Messenger& getInstance() { Assert(m_instance != nullptr); // LCOV_EXCL_LINE return *m_instance; } PUGS_INLINE const size_t& rank() const { return m_rank; } #ifdef PUGS_HAS_MPI PUGS_INLINE const MPI_Comm& comm() const { return m_comm_world_pugs; } #endif // PUGS_HAS_MPI PUGS_INLINE const size_t& size() const { return m_size; } void barrier() const; template <typename DataType> DataType allReduceMin(const DataType& data) const { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); DataType min_data = data; MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, m_comm_world_pugs); return min_data; #else // PUGS_HAS_MPI return data; #endif // PUGS_HAS_MPI } template <typename DataType> DataType allReduceAnd(const DataType& data) const { static_assert(std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); DataType max_data = data; MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LAND, m_comm_world_pugs); return max_data; #else // PUGS_HAS_MPI return data; #endif // PUGS_HAS_MPI } template <typename DataType> DataType allReduceOr(const DataType& data) const { static_assert(std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); DataType max_data = data; MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_LOR, m_comm_world_pugs); return max_data; #else // PUGS_HAS_MPI return data; #endif // PUGS_HAS_MPI } template <typename DataType> DataType allReduceMax(const DataType& data) const { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); DataType max_data = data; MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, m_comm_world_pugs); return max_data; #else // PUGS_HAS_MPI return data; #endif // PUGS_HAS_MPI } template <typename DataType> DataType allReduceSum(const DataType& data) const { static_assert(not std::is_const_v<DataType>); static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI if constexpr (std::is_arithmetic_v<DataType>) { MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); DataType data_sum = data; MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, m_comm_world_pugs); return data_sum; } else if (is_trivially_castable<DataType>) { using InnerDataType = typename DataType::data_type; MPI_Datatype mpi_datatype = Messenger::helper::mpiType<InnerDataType>(); const int size = sizeof(DataType) / sizeof(InnerDataType); DataType data_sum = data; MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, m_comm_world_pugs); return data_sum; } #else // PUGS_HAS_MPI return data; #endif // PUGS_HAS_MPI } template <typename DataType> PUGS_INLINE Array<DataType> gather(const DataType& data, size_t rank) const { static_assert(not std::is_const_v<DataType>); Array<DataType> gather_array((m_rank == rank) ? m_size : 0); if constexpr (std::is_arithmetic_v<DataType>) { _gather(data, gather_array, rank); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; CastArray cast_value_array = cast_value_to<const CastType>::from(data); CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array); _gather(cast_value_array, cast_gather_array, rank); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> gather(const Array<DataType>& array, size_t rank) const { using MutableDataType = std::remove_const_t<DataType>; Array<MutableDataType> gather_array((m_rank == rank) ? (m_size * array.size()) : 0); if constexpr (std::is_arithmetic_v<DataType>) { _gather(array, gather_array, rank); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; using MutableCastType = helper::split_cast_t<MutableDataType>; CastArray cast_array = cast_array_to<CastType>::from(array); CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array); _gather(cast_array, cast_gather_array, rank); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> gatherVariable(const Array<DataType>& array, size_t rank) const { int send_size = array.size(); Array<int> sizes_array = gather(send_size, rank); using MutableDataType = std::remove_const_t<DataType>; Array<MutableDataType> gather_array(sum(sizes_array)); if constexpr (std::is_arithmetic_v<DataType>) { _gatherVariable(array, gather_array, sizes_array, rank); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; using MutableCastType = helper::split_cast_t<MutableDataType>; int size_ratio = sizeof(DataType) / sizeof(CastType); for (size_t i = 0; i < sizes_array.size(); ++i) { sizes_array[i] *= size_ratio; } CastArray cast_array = cast_array_to<CastType>::from(array); CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array); _gatherVariable(cast_array, cast_gather_array, sizes_array, rank); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename DataType> PUGS_INLINE Array<DataType> allGather(const DataType& data) const { static_assert(not std::is_const_v<DataType>); Array<DataType> gather_array(m_size); if constexpr (std::is_arithmetic_v<DataType>) { _allGather(data, gather_array); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; CastArray cast_value_array = cast_value_to<const CastType>::from(data); CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array); _allGather(cast_value_array, cast_gather_array); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> allGather(const Array<DataType>& array) const { using MutableDataType = std::remove_const_t<DataType>; Array<MutableDataType> gather_array(m_size * array.size()); if constexpr (std::is_arithmetic_v<DataType>) { _allGather(array, gather_array); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; using MutableCastType = helper::split_cast_t<MutableDataType>; CastArray cast_array = cast_array_to<CastType>::from(array); CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array); _allGather(cast_array, cast_gather_array); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> allGatherVariable(const Array<DataType>& array) const { int send_size = array.size(); Array<int> sizes_array = allGather(send_size); using MutableDataType = std::remove_const_t<DataType>; Array<MutableDataType> gather_array(sum(sizes_array)); if constexpr (std::is_arithmetic_v<DataType>) { _allGatherVariable(array, gather_array, sizes_array); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; using MutableCastType = helper::split_cast_t<MutableDataType>; int size_ratio = sizeof(DataType) / sizeof(CastType); for (size_t i = 0; i < sizes_array.size(); ++i) { sizes_array[i] *= size_ratio; } CastArray cast_array = cast_array_to<CastType>::from(array); CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array); _allGatherVariable(cast_array, cast_gather_array, sizes_array); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename SendDataType> PUGS_INLINE Array<std::remove_const_t<SendDataType>> allToAll(const Array<SendDataType>& sent_array) const { #ifndef NDEBUG const size_t min_size = allReduceMin(sent_array.size()); const size_t max_size = allReduceMax(sent_array.size()); Assert(max_size == min_size); // LCOV_EXCL_LINE #endif // NDEBUG Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE using DataType = std::remove_const_t<SendDataType>; Array<DataType> recv_array(sent_array.size()); if constexpr (std::is_arithmetic_v<DataType>) { _allToAll(sent_array, recv_array); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; auto send_cast_array = cast_array_to<const CastType>::from(sent_array); auto recv_cast_array = cast_array_to<CastType>::from(recv_array); _allToAll(send_cast_array, recv_cast_array); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } return recv_array; } template <typename DataType> PUGS_INLINE void broadcast(DataType& data, size_t root_rank) const { static_assert(not std::is_const_v<DataType>, "cannot broadcast const data"); if constexpr (std::is_arithmetic_v<DataType>) { _broadcast_value(data, root_rank); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; if constexpr (sizeof(CastType) == sizeof(DataType)) { CastType& cast_data = reinterpret_cast<CastType&>(data); _broadcast_value(cast_data, root_rank); } else { CastArray cast_array = cast_value_to<CastType>::from(data); _broadcast_array(cast_array, root_rank); } } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } } template <typename DataType> PUGS_INLINE void broadcast(Array<DataType>& array, size_t root_rank) const { static_assert(not std::is_const_v<DataType>, "cannot broadcast array of const"); if constexpr (std::is_arithmetic_v<DataType>) { size_t size = array.size(); _broadcast_value(size, root_rank); if (m_rank != root_rank) { array = Array<DataType>(size); // LCOV_EXCL_LINE } _broadcast_array(array, root_rank); } else if constexpr (is_trivially_castable<DataType>) { size_t size = array.size(); _broadcast_value(size, root_rank); if (m_rank != root_rank) { array = Array<DataType>(size); // LCOV_EXCL_LINE } using CastType = helper::split_cast_t<DataType>; auto cast_array = cast_array_to<CastType>::from(array); _broadcast_array(cast_array, root_rank); } else { static_assert(is_false_v<DataType>, "unexpected type of data"); } } template <typename SendDataType, typename RecvDataType> PUGS_INLINE void exchange(const std::vector<Array<SendDataType>>& send_array_list, std::vector<Array<RecvDataType>>& recv_array_list) const { static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>, "send and receive data type do not match"); static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const"); using DataType = std::remove_const_t<SendDataType>; Assert(send_array_list.size() == m_size); // LCOV_EXCL_LINE Assert(recv_array_list.size() == m_size); // LCOV_EXCL_LINE #ifndef NDEBUG Array<size_t> send_size(m_size); for (size_t i = 0; i < m_size; ++i) { send_size[i] = send_array_list[i].size(); } Array<size_t> recv_size = allToAll(send_size); bool correct_sizes = true; for (size_t i = 0; i < m_size; ++i) { correct_sizes &= (recv_size[i] == recv_array_list[i].size()); } Assert(correct_sizes, "incompatible send/recv messages length"); // LCOV_EXCL_LINE #endif // NDEBUG if constexpr (std::is_arithmetic_v<DataType>) { _exchange(send_array_list, recv_array_list); } else if constexpr (is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list); } else { static_assert(is_false_v<RecvDataType>, "unexpected type of data"); } } Messenger(const Messenger&) = delete; ~Messenger(); }; PUGS_INLINE const Messenger& messenger() { return Messenger::getInstance(); } PUGS_INLINE const size_t& rank() { return messenger().rank(); } PUGS_INLINE const size_t& size() { return messenger().size(); } PUGS_INLINE void barrier() { messenger().barrier(); } template <typename DataType> PUGS_INLINE DataType allReduceAnd(const DataType& data) { return messenger().allReduceAnd(data); } template <typename DataType> PUGS_INLINE DataType allReduceOr(const DataType& data) { return messenger().allReduceOr(data); } template <typename DataType> PUGS_INLINE DataType allReduceMax(const DataType& data) { return messenger().allReduceMax(data); } template <typename DataType> PUGS_INLINE DataType allReduceMin(const DataType& data) { return messenger().allReduceMin(data); } template <typename DataType> PUGS_INLINE DataType allReduceSum(const DataType& data) { return messenger().allReduceSum(data); } template <typename DataType> PUGS_INLINE Array<DataType> gather(const DataType& data, size_t rank) { return messenger().gather(data, rank); } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> gather(const Array<DataType>& array, size_t rank) { return messenger().gather(array, rank); } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> gatherVariable(const Array<DataType>& array, size_t rank) { return messenger().gatherVariable(array, rank); } template <typename DataType> PUGS_INLINE Array<DataType> allGather(const DataType& data) { return messenger().allGather(data); } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> allGather(const Array<DataType>& array) { return messenger().allGather(array); } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> allGatherVariable(const Array<DataType>& array) { return messenger().allGatherVariable(array); } template <typename DataType> PUGS_INLINE Array<std::remove_const_t<DataType>> allToAll(const Array<DataType>& array) { return messenger().allToAll(array); } template <typename DataType> PUGS_INLINE void broadcast(DataType& data, size_t root_rank) { messenger().broadcast(data, root_rank); } template <typename DataType> PUGS_INLINE void broadcast(Array<DataType>& array, size_t root_rank) { messenger().broadcast(array, root_rank); } template <typename SendDataType, typename RecvDataType> PUGS_INLINE void exchange(const std::vector<Array<SendDataType>>& sent_array_list, std::vector<Array<RecvDataType>>& recv_array_list) { static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>, "send and receive data type do not match"); static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const"); messenger().exchange(sent_array_list, recv_array_list); } #ifdef PUGS_HAS_MPI template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<char>() { return MPI_CHAR; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<int8_t>() { return MPI_INT8_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<int16_t>() { return MPI_INT16_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<int32_t>() { return MPI_INT32_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<int64_t>() { return MPI_INT64_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<uint8_t>() { return MPI_UINT8_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<uint16_t>() { return MPI_UINT16_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<uint32_t>() { return MPI_UINT32_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<uint64_t>() { return MPI_UINT64_T; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<signed long long int>() { return MPI_LONG_LONG_INT; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<unsigned long long int>() { return MPI_UNSIGNED_LONG_LONG; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<float>() { return MPI_FLOAT; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<double>() { return MPI_DOUBLE; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<long double>() { return MPI_LONG_DOUBLE; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<wchar_t>() { return MPI_WCHAR; } template <> PUGS_INLINE MPI_Datatype Messenger::helper::mpiType<bool>() { return MPI_CXX_BOOL; } #endif // PUGS_HAS_MPI } // namespace parallel #endif // MESSENGER_HPP