Skip to content
Snippets Groups Projects
Commit ac491b1f authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Implement data exchange for datatype which size of 1,2,4 or 8 bytes

parent 553e23a5
Branches
Tags
1 merge request!11Feature/mpi
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <PastisAssert.hpp> #include <PastisAssert.hpp>
#include <Array.hpp> #include <Array.hpp>
#include <CastArray.hpp>
#include <type_traits> #include <type_traits>
...@@ -19,9 +20,9 @@ enum class CellType : unsigned short; ...@@ -19,9 +20,9 @@ enum class CellType : unsigned short;
class Messenger class Messenger
{ {
private: private:
#ifdef PASTIS_HAS_MPI
struct Helper struct Helper
{ {
#ifdef PASTIS_HAS_MPI
template<typename DataType> template<typename DataType>
static PASTIS_INLINE static PASTIS_INLINE
MPI_Datatype mpiType() MPI_Datatype mpiType()
...@@ -36,9 +37,49 @@ class Messenger ...@@ -36,9 +37,49 @@ class Messenger
return MPI_Datatype(); return MPI_Datatype();
} }
} }
};
#endif PASTIS_HAS_MPI #endif PASTIS_HAS_MPI
struct CompositeType {}; // composite type
template <typename data_type,
int size = sizeof(data_type)>
struct data_cast
{
using type = CompositeType;
};
template <typename data_type>
struct data_cast<data_type,1>
{
using type = int8_t;
static_assert(sizeof(data_type) == sizeof(type));
};
template <typename data_type>
struct data_cast<data_type,2>
{
using type = int16_t;
static_assert(sizeof(data_type) == sizeof(type));
};
template <typename data_type>
struct data_cast<data_type,4>
{
using type = int32_t;
static_assert(sizeof(data_type) == sizeof(type));
};
template <typename data_type>
struct data_cast<data_type,8>
{
using type = int64_t;
static_assert(sizeof(data_type) == sizeof(type));
};
template <typename data_type>
using data_cast_t = typename data_cast<data_type>::type;
};
static Messenger* m_instance; static Messenger* m_instance;
Messenger(int& argc, char* argv[]); Messenger(int& argc, char* argv[]);
...@@ -51,25 +92,24 @@ class Messenger ...@@ -51,25 +92,24 @@ class Messenger
Array<int> _broadcast(Array<int>& array, int root_rank) const; Array<int> _broadcast(Array<int>& array, int root_rank) const;
Array<int> _allToAll(const Array<int>& sent_array, Array<int>& recv_array) const; Array<int> _allToAll(const Array<int>& sent_array, Array<int>& recv_array) const;
template <typename DataType> template <typename SendArrayType,
void _exchange(const std::vector<Array<DataType>>& sent_array_list, typename RecvArrayType>
std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const void _exchange(const std::vector<SendArrayType>& sent_array_list,
std::vector<RecvArrayType>& recv_array_list) const
{ {
using SendDataType = typename SendArrayType::data_type;
using RecvDataType = typename RecvArrayType::data_type;
static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
static_assert(std::is_arithmetic_v<SendDataType>);
#ifdef PASTIS_HAS_MPI #ifdef PASTIS_HAS_MPI
std::vector<MPI_Request> request_list; std::vector<MPI_Request> request_list;
#warning clean-up MPI_Datatype type = Messenger::Helper::mpiType<SendDataType>();
MPI_Datatype type = [&] () -> MPI_Datatype {
if constexpr (std::is_same_v<CellType,std::remove_const_t<DataType>>) {
return MPI_SHORT;
} else {
return Messenger::Helper::mpiType<DataType>();
}
} ();
for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
const Array<DataType> sent_array = sent_array_list[i_send]; const SendArrayType sent_array = sent_array_list[i_send];
if (sent_array.size()>0) { if (sent_array.size()>0) {
MPI_Request request; MPI_Request request;
MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request); MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request);
...@@ -78,7 +118,7 @@ class Messenger ...@@ -78,7 +118,7 @@ class Messenger
} }
for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) { for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) {
Array<std::remove_const_t<DataType>> recv_array = recv_array_list[i_recv]; RecvArrayType recv_array = recv_array_list[i_recv];
if (recv_array.size()>0) { if (recv_array.size()>0) {
MPI_Request request; MPI_Request request;
MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request); MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request);
...@@ -98,6 +138,24 @@ class Messenger ...@@ -98,6 +138,24 @@ class Messenger
#endif // PASTIS_HAS_MPI #endif // PASTIS_HAS_MPI
} }
template <typename DataType,
typename CastDataType>
void _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list,
std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
{
std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list;
for (size_t i=0; i<sent_array_list.size(); ++i) {
sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array_list[i]));
}
using MutableDataType = std::remove_const_t<DataType>;
std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list;
for (size_t i=0; i<sent_array_list.size(); ++i) {
recv_cast_array_list.emplace_back(recv_array_list[i]);
}
_exchange(sent_cast_array_list, recv_cast_array_list);
}
public: public:
static void create(int& argc, char* argv[]); static void create(int& argc, char* argv[]);
...@@ -183,13 +241,22 @@ class Messenger ...@@ -183,13 +241,22 @@ class Messenger
"send and receive data type do not match"); "send and receive data type do not match");
static_assert(not std::is_const_v<RecvDataType>, static_assert(not std::is_const_v<RecvDataType>,
"receive data type cannot be const"); "receive data type cannot be const");
using DataType = std::remove_const_t<SendDataType>;
if constexpr(std::is_arithmetic_v<RecvDataType>) { if constexpr(std::is_arithmetic_v<DataType>) {
_exchange(sent_array_list, recv_array_list);
} else if constexpr(std::is_same<RecvDataType, CellType>()) {
_exchange(sent_array_list, recv_array_list); _exchange(sent_array_list, recv_array_list);
} else if constexpr(std::is_trivial_v<DataType>) {
using CastType = Helper::data_cast_t<DataType>;
if constexpr(std::is_same_v<CastType, Helper::CompositeType>) {
static_assert(not std::is_same_v<CastType, Helper::CompositeType>,
"treatment of composite type is not yet implemented!");
} else {
this->_exchange_through_cast<SendDataType, CastType>(sent_array_list, recv_array_list);
}
} else { } else {
static_assert(std::is_same<RecvDataType, int>(), "unexpected type of data"); static_assert(std::is_trivial_v<RecvDataType>,
"unexpected non trivial type of data");
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment