Skip to content
Snippets Groups Projects
Commit 29451096 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Displace some code to header file

All communication encapsulations are going to the header file to provide maximal
generality (we want to exchange more than just Array<int>).
parent 97e77bd8
No related branches found
No related tags found
1 merge request!11Feature/mpi
#include <Messenger.hpp> #include <Messenger.hpp>
#include <PastisOStream.hpp> #include <PastisOStream.hpp>
#include <pastis_config.hpp>
#ifdef PASTIS_HAS_MPI
#include <mpi.h>
#endif // PASTIS_HAS_MPI
Messenger* Messenger::m_instance = nullptr; Messenger* Messenger::m_instance = nullptr;
void Messenger::create(int& argc, char* argv[]) void Messenger::create(int& argc, char* argv[])
...@@ -113,62 +107,3 @@ _broadcast(Array<int>& array, int root_rank) const ...@@ -113,62 +107,3 @@ _broadcast(Array<int>& array, int root_rank) const
#endif // PASTIS_HAS_MPI #endif // PASTIS_HAS_MPI
return array; return array;
} }
template <typename DataType>
void Messenger::
_exchange(const std::vector<Array<DataType>>& sent_array_list,
std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
{
#ifdef PASTIS_HAS_MPI
std::vector<MPI_Request> request_list;
MPI_Datatype type = [&] () -> MPI_Datatype {
if constexpr (std::is_same_v<int,std::remove_const_t<DataType>>) {
return MPI_INT;
} else if constexpr (std::is_same_v<CellType,std::remove_const_t<DataType>>) {
return MPI_SHORT;
}
} ();
for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
const Array<DataType> sent_array = sent_array_list[i_send];
if (sent_array.size()>0) {
MPI_Request request;
MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request);
request_list.push_back(request);
}
}
for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) {
Array<std::remove_const_t<DataType>> recv_array = recv_array_list[i_recv];
if (recv_array.size()>0) {
MPI_Request request;
MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request);
request_list.push_back(request);
}
}
std::vector<MPI_Status> status_list(request_list.size());
if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) {
std::cerr << "Communication error!\n";
std::exit(1);
}
#else // PASTIS_HAS_MPI
std::cerr << "NIY\n";
std::exit(1);
#endif // PASTIS_HAS_MPI
}
template
void Messenger::_exchange(const std::vector<Array<int>>& sent_array_list,
std::vector<Array<int>>& recv_array_list) const;
template
void Messenger::_exchange(const std::vector<Array<const int>>& sent_array_list,
std::vector<Array<int>>& recv_array_list) const;
template
void Messenger::_exchange(const std::vector<Array<CellType>>& sent_array_list,
std::vector<Array<CellType>>& recv_array_list) const;
template
void Messenger::_exchange(const std::vector<Array<const CellType>>& sent_array_list,
std::vector<Array<CellType>>& recv_array_list) const;
...@@ -6,12 +6,39 @@ ...@@ -6,12 +6,39 @@
#include <Array.hpp> #include <Array.hpp>
#include <type_traits>
#include <pastis_config.hpp>
#ifdef PASTIS_HAS_MPI
#include <mpi.h>
#endif // PASTIS_HAS_MPI
#warning REMOVE #warning REMOVE
enum class CellType : unsigned short; enum class CellType : unsigned short;
class Messenger class Messenger
{ {
private: private:
#ifdef PASTIS_HAS_MPI
struct Helper
{
template<typename DataType>
static PASTIS_INLINE
MPI_Datatype mpiType()
{
if constexpr (std::is_const_v<DataType>) {
return mpiType<std::remove_const_t<DataType>>();
} else {
static_assert(std::is_arithmetic_v<DataType>,
"Unexpected arithmetic type! Should not occur!");
static_assert(not std::is_arithmetic_v<DataType>,
"MPI_Datatype are only defined for arithmetic types!");
return MPI_Datatype();
}
}
};
#endif PASTIS_HAS_MPI
static Messenger* m_instance; static Messenger* m_instance;
Messenger(int& argc, char* argv[]); Messenger(int& argc, char* argv[]);
...@@ -26,7 +53,51 @@ class Messenger ...@@ -26,7 +53,51 @@ class Messenger
template <typename DataType> template <typename DataType>
void _exchange(const std::vector<Array<DataType>>& sent_array_list, void _exchange(const std::vector<Array<DataType>>& sent_array_list,
std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const; std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
{
#ifdef PASTIS_HAS_MPI
std::vector<MPI_Request> request_list;
#warning clean-up
MPI_Datatype type = [&] () -> MPI_Datatype {
if constexpr (std::is_same_v<CellType,std::remove_const_t<DataType>>) {
return MPI_SHORT;
} else {
return Messenger::Helper::mpiType<DataType>();
}
} ();
for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
const Array<DataType> sent_array = sent_array_list[i_send];
if (sent_array.size()>0) {
MPI_Request request;
MPI_Isend(&(sent_array[0]), sent_array.size(), type, i_send, 0, MPI_COMM_WORLD, &request);
request_list.push_back(request);
}
}
for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) {
Array<std::remove_const_t<DataType>> recv_array = recv_array_list[i_recv];
if (recv_array.size()>0) {
MPI_Request request;
MPI_Irecv(&(recv_array[0]), recv_array.size(), type, i_recv, 0, MPI_COMM_WORLD, &request);
request_list.push_back(request);
}
}
std::vector<MPI_Status> status_list(request_list.size());
if (MPI_SUCCESS != MPI_Waitall(request_list.size(), &(request_list[0]), &(status_list[0]))) {
std::cerr << "Communication error!\n";
std::exit(1);
}
#else // PASTIS_HAS_MPI
std::cerr << "NIY\n";
std::exit(1);
#endif // PASTIS_HAS_MPI
}
public: public:
static void create(int& argc, char* argv[]); static void create(int& argc, char* argv[]);
...@@ -113,7 +184,7 @@ class Messenger ...@@ -113,7 +184,7 @@ class Messenger
static_assert(not std::is_const_v<RecvDataType>, static_assert(not std::is_const_v<RecvDataType>,
"receive data type cannot be const"); "receive data type cannot be const");
if constexpr(std::is_same<RecvDataType, int>()) { if constexpr(std::is_arithmetic_v<RecvDataType>) {
_exchange(sent_array_list, recv_array_list); _exchange(sent_array_list, recv_array_list);
} else if constexpr(std::is_same<RecvDataType, CellType>()) { } else if constexpr(std::is_same<RecvDataType, CellType>()) {
_exchange(sent_array_list, recv_array_list); _exchange(sent_array_list, recv_array_list);
...@@ -194,4 +265,56 @@ void exchange(const std::vector<Array<SendDataType>>& sent_array_list, ...@@ -194,4 +265,56 @@ void exchange(const std::vector<Array<SendDataType>>& sent_array_list,
messenger().exchange(sent_array_list, recv_array_list); messenger().exchange(sent_array_list, recv_array_list);
} }
#ifdef PASTIS_HAS_MPI
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<char>() {return MPI_CHAR; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<int8_t>() {return MPI_INT8_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<int16_t>() {return MPI_INT16_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<int32_t>() {return MPI_INT32_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<int64_t>() {return MPI_INT64_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<uint8_t>() {return MPI_UINT8_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<uint16_t>() {return MPI_UINT16_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<uint32_t>() {return MPI_UINT32_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<uint64_t>() {return MPI_UINT64_T; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<signed long long int>() {return MPI_LONG_LONG_INT; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<unsigned long long int>() {return MPI_UNSIGNED_LONG_LONG; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<float>() {return MPI_FLOAT; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<double>() {return MPI_DOUBLE; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<long double>() {return MPI_LONG_DOUBLE; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<wchar_t>() {return MPI_WCHAR; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::Helper::mpiType<bool>() {return MPI_CXX_BOOL; }
#endif // PASTIS_HAS_MPI
#endif // MESSENGER_HPP #endif // MESSENGER_HPP
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment