#ifndef MESSENGER_HPP #define MESSENGER_HPP #include <PastisMacros.hpp> #include <PastisAssert.hpp> #include <Array.hpp> #include <CastArray.hpp> #include <ArrayUtils.hpp> #include <type_traits> #include <vector> #include <pastis_config.hpp> #ifdef PASTIS_HAS_MPI #include <mpi.h> #endif // PASTIS_HAS_MPI namespace parallel { class Messenger { private: struct helper { #ifdef PASTIS_HAS_MPI template<typename DataType> static PASTIS_INLINE MPI_Datatype mpiType() { if constexpr (std::is_const_v<DataType>) { return mpiType<std::remove_const_t<DataType>>(); } else { static_assert(std::is_arithmetic_v<DataType>, "Unexpected arithmetic type! Should not occur!"); static_assert(not std::is_arithmetic_v<DataType>, "MPI_Datatype are only defined for arithmetic types!"); return MPI_Datatype(); } } #endif // PASTIS_HAS_MPI private: template <typename T, typename Allowed = void> struct split_cast {}; template <typename T> struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int64_t))>> { using type = int64_t; static_assert(not(sizeof(T) % sizeof(int64_t))); }; template <typename T> struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int32_t)) and(sizeof(T) % sizeof(int64_t))>> { using type = int32_t; static_assert(not(sizeof(T) % sizeof(int32_t))); }; template <typename T> struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int16_t)) and(sizeof(T) % sizeof(int32_t)) and(sizeof(T) % sizeof(int64_t))>> { using type = int16_t; static_assert(not(sizeof(T) % sizeof(int16_t))); }; template <typename T> struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int8_t)) and(sizeof(T) % sizeof(int16_t)) and(sizeof(T) % sizeof(int32_t)) and(sizeof(T) % sizeof(int64_t))>> { using type = int8_t; static_assert(not(sizeof(T) % sizeof(int8_t))); }; public: template <typename T> using split_cast_t = typename split_cast<T>::type; }; static Messenger* m_instance; Messenger(int& argc, char* argv[]); size_t m_rank{0}; size_t m_size{1}; template <typename DataType> void _allGather(const DataType& data, Array<DataType> gather) const { static_assert(std::is_arithmetic_v<DataType>); Assert(gather.size() == m_size); // LCOV_EXCL_LINE #ifdef PASTIS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, MPI_COMM_WORLD); #else // PASTIS_HAS_MPI gather[0] = data; #endif // PASTIS_HAS_MPI } template <template <typename ...SendT> typename SendArrayType, template <typename ...RecvT> typename RecvArrayType, typename ...SendT, typename ...RecvT> void _allGather(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array) const { Assert(gather_array.size() == data_array.size()*m_size); // LCOV_EXCL_LINE using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); #ifdef PASTIS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>(); MPI_Allgather(&(data_array[0]), data_array.size(), mpi_datatype, &(gather_array[0]), data_array.size(), mpi_datatype, MPI_COMM_WORLD); #else // PASTIS_HAS_MPI value_copy(data_array, gather_array); #endif // PASTIS_HAS_MPI } template <typename DataType> void _broadcast_value(DataType& data, const size_t& root_rank) const { static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); #ifdef PASTIS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); MPI_Bcast(&data, 1, mpi_datatype, root_rank, MPI_COMM_WORLD); #endif // PASTIS_HAS_MPI } template <typename ArrayType> void _broadcast_array(ArrayType& array, const size_t& root_rank) const { using DataType = typename ArrayType::data_type; static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); #ifdef PASTIS_HAS_MPI MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); MPI_Bcast(&(array[0]), array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD); #endif // PASTIS_HAS_MPI } template <template <typename ...SendT> typename SendArrayType, template <typename ...RecvT> typename RecvArrayType, typename ...SendT, typename ...RecvT> RecvArrayType<RecvT...> _allToAll(const SendArrayType<SendT...>& sent_array, RecvArrayType<RecvT...>& recv_array) const { #ifdef PASTIS_HAS_MPI using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE Assert(sent_array.size() == recv_array.size()); // LCOV_EXCL_LINE const size_t count = sent_array.size()/m_size; MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); MPI_Alltoall(&(sent_array[0]), count, mpi_datatype, &(recv_array[0]), count, mpi_datatype, MPI_COMM_WORLD); #else // PASTIS_HAS_MPI value_copy(sent_array, recv_array); #endif // PASTIS_HAS_MPI return recv_array; } template <template <typename ...SendT> typename SendArrayType, template <typename ...RecvT> typename RecvArrayType, typename ...SendT, typename ...RecvT> void _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list, std::vector<RecvArrayType<RecvT...>>& recv_array_list) const { using SendDataType = typename SendArrayType<SendT...>::data_type; using RecvDataType = typename RecvArrayType<RecvT...>::data_type; static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); static_assert(std::is_arithmetic_v<SendDataType>); #ifdef PASTIS_HAS_MPI std::vector<MPI_Request> request_list; MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { const auto sent_array = sent_array_list[i_send]; if (sent_array.size()>0) { MPI_Request request; MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request); request_list.push_back(request); } } for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { auto recv_array = recv_array_list[i_recv]; if (recv_array.size()>0) { MPI_Request request; MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request); request_list.push_back(request); } } std::vector<MPI_Status> status_list(request_list.size()); if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) { // LCOV_EXCL_START std::cerr << "Communication error!\n"; std::exit(1); // LCOV_EXCL_STOP } #else // PASTIS_HAS_MPI Assert(sent_array_list.size() == 1); Assert(recv_array_list.size() == 1); value_copy(sent_array_list[0], recv_array_list[0]); #endif // PASTIS_HAS_MPI } template <typename DataType, typename CastDataType> void _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list, std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const { std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list; for (size_t i=0; i<sent_array_list.size(); ++i) { sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array_list[i])); } using MutableDataType = std::remove_const_t<DataType>; std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list; for (size_t i=0; i<sent_array_list.size(); ++i) { recv_cast_array_list.emplace_back(recv_array_list[i]); } _exchange(sent_cast_array_list, recv_cast_array_list); } public: static void create(int& argc, char* argv[]); static void destroy(); PASTIS_INLINE static Messenger& getInstance() { Assert(m_instance != nullptr); // LCOV_EXCL_LINE return *m_instance; } PASTIS_INLINE const size_t& rank() const { return m_rank; } PASTIS_INLINE const size_t& size() const { return m_size; } void barrier() const; template <typename DataType> DataType allReduceMin(const DataType& data) const { #ifdef PASTIS_HAS_MPI static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); DataType min_data = data; MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD); return min_data; #else // PASTIS_HAS_MPI return data; #endif // PASTIS_HAS_MPI } template <typename DataType> DataType allReduceMax(const DataType& data) const { #ifdef PASTIS_HAS_MPI static_assert(not std::is_const_v<DataType>); static_assert(std::is_arithmetic_v<DataType>); MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>(); DataType max_data = data; MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD); return max_data; #else // PASTIS_HAS_MPI return data; #endif // PASTIS_HAS_MPI } template <typename DataType> PASTIS_INLINE Array<DataType> allGather(const DataType& data) const { static_assert(not std::is_const_v<DataType>); Array<DataType> gather_array(m_size); if constexpr(std::is_arithmetic_v<DataType>) { _allGather(data, gather_array); } else if constexpr(std::is_trivial_v<DataType>) { using CastType = helper::split_cast_t<DataType>; CastArray cast_value_array = cast_value_to<const CastType>::from(data); CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array); _allGather(cast_value_array, cast_gather_array); } else { static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename DataType> PASTIS_INLINE Array<std::remove_const_t<DataType>> allGather(const Array<DataType>& array) const { using MutableDataType = std::remove_const_t<DataType>; Array<MutableDataType> gather_array(m_size*array.size()); if constexpr(std::is_arithmetic_v<DataType>) { _allGather(array, gather_array); } else if constexpr(std::is_trivial_v<DataType>) { using CastType = helper::split_cast_t<DataType>; using MutableCastType = helper::split_cast_t<MutableDataType>; CastArray cast_array = cast_array_to<CastType>::from(array); CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array); _allGather(cast_array, cast_gather_array); } else { static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); } return gather_array; } template <typename SendDataType> PASTIS_INLINE Array<std::remove_const_t<SendDataType>> allToAll(const Array<SendDataType>& sent_array) const { #ifndef NDEBUG const size_t min_size = allReduceMin(sent_array.size()); const size_t max_size = allReduceMax(sent_array.size()); Assert(max_size == min_size); // LCOV_EXCL_LINE #endif // NDEBUG Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE using DataType = std::remove_const_t<SendDataType>; Array<DataType> recv_array(sent_array.size()); if constexpr(std::is_arithmetic_v<DataType>) { _allToAll(sent_array, recv_array); } else if constexpr(std::is_trivial_v<DataType>) { using CastType = helper::split_cast_t<DataType>; auto send_cast_array = cast_array_to<const CastType>::from(sent_array); auto recv_cast_array = cast_array_to<CastType>::from(recv_array); _allToAll(send_cast_array, recv_cast_array); } else { static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); } return recv_array; } template <typename DataType> PASTIS_INLINE void broadcast(DataType& data, const size_t& root_rank) const { static_assert(not std::is_const_v<DataType>, "cannot broadcast const data"); if constexpr(std::is_arithmetic_v<DataType>) { _broadcast_value(data, root_rank); } else if constexpr(std::is_trivial_v<DataType>) { using CastType = helper::split_cast_t<DataType>; if constexpr(sizeof(CastType) == sizeof(DataType)) { CastType& cast_data = reinterpret_cast<CastType&>(data); _broadcast_value(cast_data, root_rank); } else { CastArray cast_array = cast_value_to<CastType>::from(data); _broadcast_array(cast_array, root_rank); } } else { static_assert(std::is_trivial_v<DataType>, "unexpected non trivial type of data"); } } template <typename DataType> PASTIS_INLINE void broadcast(Array<DataType>& array, const size_t& root_rank) const { static_assert(not std::is_const_v<DataType>, "cannot broadcast array of const"); if constexpr(std::is_arithmetic_v<DataType>) { size_t size = array.size(); _broadcast_value(size, root_rank); if (m_rank != root_rank) { array = Array<DataType>(size); // LCOV_EXCL_LINE } _broadcast_array(array, root_rank); } else if constexpr(std::is_trivial_v<DataType>) { size_t size = array.size(); _broadcast_value(size, root_rank); if (m_rank != root_rank) { array = Array<DataType>(size); // LCOV_EXCL_LINE } using CastType = helper::split_cast_t<DataType>; auto cast_array = cast_array_to<CastType>::from(array); _broadcast_array(cast_array, root_rank); } else{ static_assert(std::is_trivial_v<DataType>, "unexpected non trivial type of data"); } } template <typename SendDataType, typename RecvDataType> PASTIS_INLINE void exchange(const std::vector<Array<SendDataType>>& send_array_list, std::vector<Array<RecvDataType>>& recv_array_list) const { static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>, "send and receive data type do not match"); static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const"); using DataType = std::remove_const_t<SendDataType>; Assert(send_array_list.size() == m_size); // LCOV_EXCL_LINE Assert(recv_array_list.size() == m_size); // LCOV_EXCL_LINE #ifndef NDEBUG Array<size_t> send_size(m_size); for (size_t i=0; i<m_size; ++i) { send_size[i] = send_array_list[i].size(); } Array<size_t> recv_size = allToAll(send_size); bool correct_sizes = true; for (size_t i=0; i<m_size; ++i) { correct_sizes &= (recv_size[i] == recv_array_list[i].size()); } Assert(correct_sizes); // LCOV_EXCL_LINE #endif // NDEBUG if constexpr(std::is_arithmetic_v<DataType>) { _exchange(send_array_list, recv_array_list); } else if constexpr(std::is_trivial_v<DataType>) { using CastType = helper::split_cast_t<DataType>; _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list); } else { static_assert(std::is_trivial_v<RecvDataType>, "unexpected non trivial type of data"); } } Messenger(const Messenger&) = delete; ~Messenger(); }; PASTIS_INLINE const Messenger& messenger() { return Messenger::getInstance(); } PASTIS_INLINE const size_t& rank() { return messenger().rank(); } PASTIS_INLINE const size_t& size() { return messenger().size(); } PASTIS_INLINE void barrier() { return messenger().barrier(); } template <typename DataType> PASTIS_INLINE DataType allReduceMax(const DataType& data) { return messenger().allReduceMax(data); } template <typename DataType> PASTIS_INLINE DataType allReduceMin(const DataType& data) { return messenger().allReduceMin(data); } template <typename DataType> PASTIS_INLINE Array<DataType> allGather(const DataType& data) { return messenger().allGather(data); } template <typename DataType> PASTIS_INLINE Array<std::remove_const_t<DataType>> allGather(const Array<DataType>& array) { return messenger().allGather(array); } template <typename DataType> PASTIS_INLINE Array<std::remove_const_t<DataType>> allToAll(const Array<DataType>& array) { return messenger().allToAll(array); } template <typename DataType> PASTIS_INLINE void broadcast(DataType& data, const size_t& root_rank) { messenger().broadcast(data, root_rank); } template <typename DataType> PASTIS_INLINE void broadcast(Array<DataType>& array, const size_t& root_rank) { messenger().broadcast(array, root_rank); } template <typename SendDataType, typename RecvDataType> PASTIS_INLINE void exchange(const std::vector<Array<SendDataType>>& sent_array_list, std::vector<Array<RecvDataType>>& recv_array_list) { static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>, "send and receive data type do not match"); static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const"); messenger().exchange(sent_array_list, recv_array_list); } #ifdef PASTIS_HAS_MPI template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<char>() {return MPI_CHAR; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int8_t>() {return MPI_INT8_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int16_t>() {return MPI_INT16_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int32_t>() {return MPI_INT32_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int64_t>() {return MPI_INT64_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint8_t>() {return MPI_UINT8_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint16_t>() {return MPI_UINT16_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint32_t>() {return MPI_UINT32_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint64_t>() {return MPI_UINT64_T; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<signed long long int>() {return MPI_LONG_LONG_INT; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<unsigned long long int>() {return MPI_UNSIGNED_LONG_LONG; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<float>() {return MPI_FLOAT; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<double>() {return MPI_DOUBLE; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<long double>() {return MPI_LONG_DOUBLE; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<wchar_t>() {return MPI_WCHAR; } template<> PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<bool>() {return MPI_CXX_BOOL; } #endif // PASTIS_HAS_MPI } // namespace parallel #endif // MESSENGER_HPP