diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index aba291a534136bdde5ccf6385d8f540619d44814..bb8b58d52676251c4364cfe7f04d2eaa31939a68 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -5,6 +5,7 @@ #include <PastisAssert.hpp> #include <Array.hpp> +#include <CastArray.hpp> #include <type_traits> @@ -19,9 +20,9 @@ enum class CellType : unsigned short; class Messenger { private: -#ifdef PASTIS_HAS_MPI struct Helper { +#ifdef PASTIS_HAS_MPI template<typename DataType> static PASTIS_INLINE MPI_Datatype mpiType() @@ -36,9 +37,49 @@ class Messenger return MPI_Datatype(); } } - }; #endif PASTIS_HAS_MPI + struct CompositeType {}; // composite type + + template <typename data_type, + int size = sizeof(data_type)> + struct data_cast + { + using type = CompositeType; + }; + + template <typename data_type> + struct data_cast<data_type,1> + { + using type = int8_t; + static_assert(sizeof(data_type) == sizeof(type)); + }; + + template <typename data_type> + struct data_cast<data_type,2> + { + using type = int16_t; + static_assert(sizeof(data_type) == sizeof(type)); + }; + + template <typename data_type> + struct data_cast<data_type,4> + { + using type = int32_t; + static_assert(sizeof(data_type) == sizeof(type)); + }; + + template <typename data_type> + struct data_cast<data_type,8> + { + using type = int64_t; + static_assert(sizeof(data_type) == sizeof(type)); + }; + + template <typename data_type> + using data_cast_t = typename data_cast<data_type>::type; + }; + static Messenger* m_instance; Messenger(int& argc, char* argv[]); @@ -51,25 +92,24 @@ class Messenger Array<int> _broadcast(Array<int>& array, int root_rank) const; Array<int> _allToAll(const Array<int>& sent_array, Array<int>& recv_array) const; - template <typename DataType> - void _exchange(const std::vector<Array<DataType>>& sent_array_list, - std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const + template <typename SendArrayType, + typename RecvArrayType> + void _exchange(const std::vector<SendArrayType>& sent_array_list, + std::vector<RecvArrayType>& recv_array_list) const { + using SendDataType = typename SendArrayType::data_type; + using RecvDataType = typename RecvArrayType::data_type; + + static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>); + static_assert(std::is_arithmetic_v<SendDataType>); + #ifdef PASTIS_HAS_MPI std::vector<MPI_Request> request_list; -#warning clean-up - MPI_Datatype type = [&] () -> MPI_Datatype { - if constexpr (std::is_same_v<CellType,std::remove_const_t<DataType>>) { - return MPI_SHORT; - } else { - return Messenger::Helper::mpiType<DataType>(); - } - } (); - + MPI_Datatype type = Messenger::Helper::mpiType<SendDataType>(); for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { - const Array<DataType> sent_array = sent_array_list[i_send]; + const SendArrayType sent_array = sent_array_list[i_send]; if (sent_array.size()>0) { MPI_Request request; MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request); @@ -78,7 +118,7 @@ class Messenger } for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { - Array<std::remove_const_t<DataType>> recv_array = recv_array_list[i_recv]; + RecvArrayType recv_array = recv_array_list[i_recv]; if (recv_array.size()>0) { MPI_Request request; MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request); @@ -98,6 +138,24 @@ class Messenger #endif // PASTIS_HAS_MPI } + template <typename DataType, + typename CastDataType> + void _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list, + std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const + { + std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list; + for (size_t i=0; i<sent_array_list.size(); ++i) { + sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array_list[i])); + } + + using MutableDataType = std::remove_const_t<DataType>; + std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list; + for (size_t i=0; i<sent_array_list.size(); ++i) { + recv_cast_array_list.emplace_back(recv_array_list[i]); + } + + _exchange(sent_cast_array_list, recv_cast_array_list); + } public: static void create(int& argc, char* argv[]); @@ -183,13 +241,22 @@ class Messenger "send and receive data type do not match"); static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const"); + using DataType = std::remove_const_t<SendDataType>; - if constexpr(std::is_arithmetic_v<RecvDataType>) { - _exchange(sent_array_list, recv_array_list); - } else if constexpr(std::is_same<RecvDataType, CellType>()) { + if constexpr(std::is_arithmetic_v<DataType>) { _exchange(sent_array_list, recv_array_list); + } else if constexpr(std::is_trivial_v<DataType>) { + using CastType = Helper::data_cast_t<DataType>; + + if constexpr(std::is_same_v<CastType, Helper::CompositeType>) { + static_assert(not std::is_same_v<CastType, Helper::CompositeType>, + "treatment of composite type is not yet implemented!"); + } else { + this->_exchange_through_cast<SendDataType, CastType>(sent_array_list, recv_array_list); + } } else { - static_assert(std::is_same<RecvDataType, int>(), "unexpected type of data"); + static_assert(std::is_trivial_v<RecvDataType>, + "unexpected non trivial type of data"); } }