From 29451096054a38a3f593b04d1671bd8069b68dca Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Thu, 11 Oct 2018 19:20:27 +0200 Subject: [PATCH] Displace some code to header file All communication encapsulations are going to the header file to provide maximal generality (we want to exchange more than just Array<int>). --- src/utils/Messenger.cpp | 65 -------------------- src/utils/Messenger.hpp | 127 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 67 deletions(-) diff --git a/src/utils/Messenger.cpp b/src/utils/Messenger.cpp index fbc8a4ead..657bdb595 100644 --- a/src/utils/Messenger.cpp +++ b/src/utils/Messenger.cpp @@ -1,12 +1,6 @@ #include <Messenger.hpp> #include <PastisOStream.hpp> -#include <pastis_config.hpp> - -#ifdef PASTIS_HAS_MPI -#include <mpi.h> -#endif // PASTIS_HAS_MPI - Messenger* Messenger::m_instance = nullptr; void Messenger::create(int& argc, char* argv[]) @@ -113,62 +107,3 @@ _broadcast(Array<int>& array, int root_rank) const #endif // PASTIS_HAS_MPI return array; } - -template <typename DataType> -void Messenger:: -_exchange(const std::vector<Array<DataType>>& sent_array_list, - std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const -{ -#ifdef PASTIS_HAS_MPI - std::vector<MPI_Request> request_list; - - MPI_Datatype type = [&] () -> MPI_Datatype { - if constexpr (std::is_same_v<int,std::remove_const_t<DataType>>) { - return MPI_INT; - } else if constexpr (std::is_same_v<CellType,std::remove_const_t<DataType>>) { - return MPI_SHORT; - } - } (); - for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { - const Array<DataType> sent_array = sent_array_list[i_send]; - if (sent_array.size()>0) { - MPI_Request request; - MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request); - request_list.push_back(request); - } - } - - for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { - Array<std::remove_const_t<DataType>> recv_array = recv_array_list[i_recv]; - if (recv_array.size()>0) { - MPI_Request request; - MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request); - request_list.push_back(request); - } - } - - std::vector<MPI_Status> status_list(request_list.size()); - if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) { - std::cerr << "Communication error!\n"; - std::exit(1); - } - -#else // PASTIS_HAS_MPI - std::cerr << "NIY\n"; - std::exit(1); -#endif // PASTIS_HAS_MPI -} - -template -void Messenger::_exchange(const std::vector<Array<int>>& sent_array_list, - std::vector<Array<int>>& recv_array_list) const; -template -void Messenger::_exchange(const std::vector<Array<const int>>& sent_array_list, - std::vector<Array<int>>& recv_array_list) const; - -template -void Messenger::_exchange(const std::vector<Array<CellType>>& sent_array_list, - std::vector<Array<CellType>>& recv_array_list) const; -template -void Messenger::_exchange(const std::vector<Array<const CellType>>& sent_array_list, - std::vector<Array<CellType>>& recv_array_list) const; diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index f12b54e02..aba291a53 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -6,12 +6,39 @@ #include <Array.hpp> +#include <type_traits> + +#include <pastis_config.hpp> +#ifdef PASTIS_HAS_MPI +#include <mpi.h> +#endif // PASTIS_HAS_MPI + #warning REMOVE enum class CellType : unsigned short; class Messenger { private: +#ifdef PASTIS_HAS_MPI + struct Helper + { + template<typename DataType> + static PASTIS_INLINE + MPI_Datatype mpiType() + { + if constexpr (std::is_const_v<DataType>) { + return mpiType<std::remove_const_t<DataType>>(); + } else { + static_assert(std::is_arithmetic_v<DataType>, + "Unexpected arithmetic type! Should not occur!"); + static_assert(not std::is_arithmetic_v<DataType>, + "MPI_Datatype are only defined for arithmetic types!"); + return MPI_Datatype(); + } + } + }; +#endif PASTIS_HAS_MPI + static Messenger* m_instance; Messenger(int& argc, char* argv[]); @@ -26,7 +53,51 @@ class Messenger template <typename DataType> void _exchange(const std::vector<Array<DataType>>& sent_array_list, - std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const; + std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const + { +#ifdef PASTIS_HAS_MPI + std::vector<MPI_Request> request_list; + +#warning clean-up + MPI_Datatype type = [&] () -> MPI_Datatype { + if constexpr (std::is_same_v<CellType,std::remove_const_t<DataType>>) { + return MPI_SHORT; + } else { + return Messenger::Helper::mpiType<DataType>(); + } + } (); + + + for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { + const Array<DataType> sent_array = sent_array_list[i_send]; + if (sent_array.size()>0) { + MPI_Request request; + MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request); + request_list.push_back(request); + } + } + + for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { + Array<std::remove_const_t<DataType>> recv_array = recv_array_list[i_recv]; + if (recv_array.size()>0) { + MPI_Request request; + MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request); + request_list.push_back(request); + } + } + + std::vector<MPI_Status> status_list(request_list.size()); + if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) { + std::cerr << "Communication error!\n"; + std::exit(1); + } + +#else // PASTIS_HAS_MPI + std::cerr << "NIY\n"; + std::exit(1); +#endif // PASTIS_HAS_MPI + } + public: static void create(int& argc, char* argv[]); @@ -113,7 +184,7 @@ class Messenger static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const"); - if constexpr(std::is_same<RecvDataType, int>()) { + if constexpr(std::is_arithmetic_v<RecvDataType>) { _exchange(sent_array_list, recv_array_list); } else if constexpr(std::is_same<RecvDataType, CellType>()) { _exchange(sent_array_list, recv_array_list); @@ -194,4 +265,56 @@ void exchange(const std::vector<Array<SendDataType>>& sent_array_list, messenger().exchange(sent_array_list, recv_array_list); } +#ifdef PASTIS_HAS_MPI + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<char>() {return MPI_CHAR; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<int8_t>() {return MPI_INT8_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<int16_t>() {return MPI_INT16_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<int32_t>() {return MPI_INT32_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<int64_t>() {return MPI_INT64_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<uint8_t>() {return MPI_UINT8_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<uint16_t>() {return MPI_UINT16_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<uint32_t>() {return MPI_UINT32_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<uint64_t>() {return MPI_UINT64_T; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<signed long long int>() {return MPI_LONG_LONG_INT; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<unsigned long long int>() {return MPI_UNSIGNED_LONG_LONG; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<float>() {return MPI_FLOAT; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<double>() {return MPI_DOUBLE; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<long double>() {return MPI_LONG_DOUBLE; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<wchar_t>() {return MPI_WCHAR; } + +template<> PASTIS_INLINE MPI_Datatype +Messenger::Helper::mpiType<bool>() {return MPI_CXX_BOOL; } + +#endif // PASTIS_HAS_MPI + #endif // MESSENGER_HPP -- GitLab