diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 63519598ad775976ae6a5102fc940812634900a7..12c2378939c2512d55815fc8dd57e27ead44de83 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -16,6 +16,8 @@ #include <mpi.h> #endif // PASTIS_HAS_MPI +#include <PastisTraits.hpp> + namespace parallel { @@ -337,7 +339,7 @@ class Messenger if constexpr(std::is_arithmetic_v<DataType>) { _allGather(data, gather_array); - } else if constexpr(std::is_trivial_v<DataType>) { + } else if constexpr(is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; CastArray cast_value_array = cast_value_to<const CastType>::from(data); @@ -345,7 +347,7 @@ class Messenger _allGather(cast_value_array, cast_gather_array); } else { - static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); + static_assert(is_trivially_castable<DataType>, "unexpected type of data"); } return gather_array; } @@ -360,7 +362,7 @@ class Messenger if constexpr(std::is_arithmetic_v<DataType>) { _allGather(array, gather_array); - } else if constexpr(std::is_trivial_v<DataType>) { + } else if constexpr(is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; using MutableCastType = helper::split_cast_t<MutableDataType>; @@ -369,7 +371,7 @@ class Messenger _allGather(cast_array, cast_gather_array); } else { - static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); + static_assert(is_trivially_castable<DataType>, "unexpected type of data"); } return gather_array; } @@ -391,14 +393,14 @@ class Messenger if constexpr(std::is_arithmetic_v<DataType>) { _allToAll(sent_array, recv_array); - } else if constexpr(std::is_trivial_v<DataType>) { + } else if constexpr(is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; auto send_cast_array = cast_array_to<const CastType>::from(sent_array); auto recv_cast_array = cast_array_to<CastType>::from(recv_array); _allToAll(send_cast_array, recv_cast_array); } else { - static_assert(std::is_trivial_v<DataType>, "unexpected type of data"); + static_assert(is_trivially_castable<DataType>, "unexpected type of data"); } return recv_array; } @@ -411,7 +413,7 @@ class Messenger "cannot broadcast const data"); if constexpr(std::is_arithmetic_v<DataType>) { _broadcast_value(data, root_rank); - } else if constexpr(std::is_trivial_v<DataType>) { + } else if constexpr(is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; if constexpr(sizeof(CastType) == sizeof(DataType)) { CastType& cast_data = reinterpret_cast<CastType&>(data); @@ -421,7 +423,7 @@ class Messenger _broadcast_array(cast_array, root_rank); } } else { - static_assert(std::is_trivial_v<DataType>, + static_assert(is_trivially_castable<DataType>, "unexpected non trivial type of data"); } } @@ -440,7 +442,7 @@ class Messenger array = Array<DataType>(size); // LCOV_EXCL_LINE } _broadcast_array(array, root_rank); - } else if constexpr(std::is_trivial_v<DataType>) { + } else if constexpr(is_trivially_castable<DataType>) { size_t size = array.size(); _broadcast_value(size, root_rank); if (m_rank != root_rank) { @@ -451,7 +453,7 @@ class Messenger auto cast_array = cast_array_to<CastType>::from(array); _broadcast_array(cast_array, root_rank); } else{ - static_assert(std::is_trivial_v<DataType>, + static_assert(is_trivially_castable<DataType>, "unexpected non trivial type of data"); } } @@ -485,11 +487,11 @@ class Messenger if constexpr(std::is_arithmetic_v<DataType>) { _exchange(send_array_list, recv_array_list); - } else if constexpr(std::is_trivial_v<DataType>) { + } else if constexpr(is_trivially_castable<DataType>) { using CastType = helper::split_cast_t<DataType>; _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list); } else { - static_assert(std::is_trivial_v<RecvDataType>, + static_assert(is_trivially_castable<RecvDataType>, "unexpected non trivial type of data"); } } diff --git a/src/utils/PastisTraits.hpp b/src/utils/PastisTraits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5214fabe4ed54ea6f3f04fe790ac744e04c616b3 --- /dev/null +++ b/src/utils/PastisTraits.hpp @@ -0,0 +1,18 @@ +#ifndef PASTIS_TRAITS_HPP +#define PASTIS_TRAITS_HPP + +#include <type_traits> + +template <size_t N, typename T> class TinyVector; +template <size_t N, typename T> class TinyMatrix; + +template <typename T> +inline constexpr bool is_trivially_castable = std::is_trivial_v<T>; + +template <size_t N, typename T> +inline constexpr bool is_trivially_castable<TinyVector<N,T>> = is_trivially_castable<T>; + +template <size_t N, typename T> +inline constexpr bool is_trivially_castable<TinyMatrix<N,T>> = is_trivially_castable<T>; + +#endif // PASTIS_TRAITS_HPP