Select Git revision
StencilBuilder.cpp
Messenger.hpp 19.52 KiB
#ifndef MESSENGER_HPP
#define MESSENGER_HPP
#include <PastisMacros.hpp>
#include <PastisAssert.hpp>
#include <Array.hpp>
#include <CastArray.hpp>
#include <ArrayUtils.hpp>

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <type_traits>
#include <vector>

#include <pastis_config.hpp>
#ifdef PASTIS_HAS_MPI
#include <mpi.h>
#endif // PASTIS_HAS_MPI
namespace parallel
{
/// Process-wide singleton giving access to the parallel communication layer.
///
/// When PASTIS_HAS_MPI is defined, the operations below are forwarded to MPI
/// on MPI_COMM_WORLD; otherwise they fall back to single-process equivalents
/// (plain copies, or no-ops for broadcasts).
///
/// Arithmetic data is communicated with its native MPI datatype. Trivial
/// non-arithmetic data is viewed as an array of the widest signed integer
/// type whose size divides sizeof(DataType) (see helper::split_cast_t) and is
/// communicated as such. Any other type is rejected at compile time.
class Messenger
{
 private:
  /// Compile-time helpers: MPI datatype lookup and integer "split" types.
  struct helper
  {
#ifdef PASTIS_HAS_MPI
    /// Returns the MPI datatype corresponding to DataType.
    ///
    /// A const-qualified type is mapped to its unqualified counterpart. The
    /// supported unqualified types are provided by explicit specializations
    /// of this template; reaching the primary template for any other type is
    /// a compile-time error.
    template<typename DataType>
    static PASTIS_INLINE
    MPI_Datatype mpiType()
    {
      if constexpr (std::is_const_v<DataType>) {
        return mpiType<std::remove_const_t<DataType>>();
      } else {
        // Both assertions depend on DataType, so they are only evaluated when
        // no explicit specialization matches. Exactly one of them fails, and
        // each message describes the situation in which its own condition
        // fails.
        static_assert(std::is_arithmetic_v<DataType>,
                      "MPI_Datatype are only defined for arithmetic types!");
        static_assert(not std::is_arithmetic_v<DataType>,
                      "Unexpected arithmetic type! Should not occur!");
        return MPI_Datatype();
      }
    }
#endif // PASTIS_HAS_MPI

   private:
    // split_cast<T>::type is the widest of int64_t, int32_t, int16_t and
    // int8_t whose size divides sizeof(T). The enable_if conditions of the
    // partial specializations are mutually exclusive, and the int8_t case
    // always applies when no wider one does, so exactly one matches.
    template <typename T,
              typename Allowed = void>
    struct split_cast {};

    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int64_t))>>
    {
      using type = int64_t;
      static_assert(not(sizeof(T) % sizeof(int64_t)));
    };

    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int32_t))
                                         and(sizeof(T) % sizeof(int64_t))>>
    {
      using type = int32_t;
      static_assert(not(sizeof(T) % sizeof(int32_t)));
    };

    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int16_t))
                                         and(sizeof(T) % sizeof(int32_t))
                                         and(sizeof(T) % sizeof(int64_t))>>
    {
      using type = int16_t;
      static_assert(not(sizeof(T) % sizeof(int16_t)));
    };

    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int8_t))
                                         and(sizeof(T) % sizeof(int16_t))
                                         and(sizeof(T) % sizeof(int32_t))
                                         and(sizeof(T) % sizeof(int64_t))>>
    {
      using type = int8_t;
      static_assert(not(sizeof(T) % sizeof(int8_t)));
    };

   public:
    /// Integer type used to reinterpret a T as an array of integers.
    template <typename T>
    using split_cast_t = typename split_cast<T>::type;
  };

  static Messenger* m_instance;

  // Private constructor: the singleton is expected to be set up through
  // create() (defined elsewhere).
  Messenger(int& argc, char* argv[]);

  size_t m_rank{0};  // rank of this process (defaults to 0)
  size_t m_size{1};  // number of processes (defaults to 1)

  /// Gathers one arithmetic value from every process into gather, which
  /// must hold m_size entries; entry i receives the value sent by rank i.
  template <typename DataType>
  void _allGather(const DataType& data,
                  Array<DataType> gather) const
  {
    static_assert(std::is_arithmetic_v<DataType>);
    Assert(gather.size() == m_size); // LCOV_EXCL_LINE

#ifdef PASTIS_HAS_MPI
    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<DataType>();

    MPI_Allgather(&data, 1, mpi_datatype,
                  &(gather[0]), 1, mpi_datatype,
                  MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
    gather[0] = data;
#endif // PASTIS_HAS_MPI
  }

  /// Gathers data_array from every process into gather_array, which must
  /// hold data_array.size()*m_size entries laid out as rank-ordered blocks.
  template <template <typename ...SendT> typename SendArrayType,
            template <typename ...RecvT> typename RecvArrayType,
            typename ...SendT, typename ...RecvT>
  void _allGather(const SendArrayType<SendT...>& data_array,
                  RecvArrayType<RecvT...> gather_array) const
  {
    Assert(gather_array.size() == data_array.size()*m_size); // LCOV_EXCL_LINE

    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PASTIS_HAS_MPI
    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<RecvDataType>();

    MPI_Allgather(&(data_array[0]), data_array.size(), mpi_datatype,
                  &(gather_array[0]), data_array.size(), mpi_datatype,
                  MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
    value_copy(data_array, gather_array);
#endif // PASTIS_HAS_MPI
  }

  /// Broadcasts one arithmetic value from root_rank to every process.
  /// Without MPI this is a no-op.
  template <typename DataType>
  void _broadcast_value(DataType& data, const size_t& root_rank) const
  {
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);

#ifdef PASTIS_HAS_MPI
    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<DataType>();

    MPI_Bcast(&data, 1, mpi_datatype, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
  }

  /// Broadcasts the content of array from root_rank to every process. The
  /// array must already have the same size on every process. Without MPI
  /// this is a no-op.
  template <typename ArrayType>
  void _broadcast_array(ArrayType& array, const size_t& root_rank) const
  {
    using DataType = typename ArrayType::data_type;
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);

#ifdef PASTIS_HAS_MPI
    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<DataType>();

    MPI_Bcast(&(array[0]), array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
  }

  /// All-to-all personalized exchange: the i-th block of sent_array is sent
  /// to rank i, and the block received from rank j is stored in the j-th
  /// block of recv_array. Both arrays have the same size, a multiple of
  /// m_size. Returns recv_array.
  template <template <typename ...SendT> typename SendArrayType,
            template <typename ...RecvT> typename RecvArrayType,
            typename ...SendT, typename ...RecvT>
  RecvArrayType<RecvT...> _allToAll(const SendArrayType<SendT...>& sent_array,
                                    RecvArrayType<RecvT...>& recv_array) const
  {
#ifdef PASTIS_HAS_MPI
    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

    Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE
    Assert(sent_array.size() == recv_array.size()); // LCOV_EXCL_LINE

    const size_t count = sent_array.size()/m_size;

    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<SendDataType>();

    MPI_Alltoall(&(sent_array[0]), count, mpi_datatype,
                 &(recv_array[0]), count, mpi_datatype,
                 MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
    value_copy(sent_array, recv_array);
#endif // PASTIS_HAS_MPI
    return recv_array;
  }

  /// Point-to-point exchange: sent_array_list[i] is sent to rank i and
  /// recv_array_list[i] is filled with the data coming from rank i. Empty
  /// arrays are skipped. Receive arrays must already have their final size.
  template <template <typename ...SendT> typename SendArrayType,
            template <typename ...RecvT> typename RecvArrayType,
            typename ...SendT, typename ...RecvT>
  void _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list,
                 std::vector<RecvArrayType<RecvT...>>& recv_array_list) const
  {
    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PASTIS_HAS_MPI
    std::vector<MPI_Request> request_list;
    request_list.reserve(sent_array_list.size() + recv_array_list.size());

    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<SendDataType>();

    for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
      // Bind by reference: the buffer handed to MPI_Isend must stay valid
      // until MPI_Waitall completes, so it is taken from the list itself
      // rather than from a per-iteration copy.
      const auto& sent_array = sent_array_list[i_send];
      if (sent_array.size()>0) {
        MPI_Request request;
        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request);
        request_list.push_back(request);
      }
    }

    for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) {
      auto& recv_array = recv_array_list[i_recv];
      if (recv_array.size()>0) {
        MPI_Request request;
        MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);
        request_list.push_back(request);
      }
    }

    // When every array is empty there is nothing to wait for; indexing the
    // empty request vector with [0] would be undefined behavior.
    if (not request_list.empty()) {
      std::vector<MPI_Status> status_list(request_list.size());
      if (MPI_SUCCESS != MPI_Waitall(request_list.size(), request_list.data(), status_list.data())) {
        // LCOV_EXCL_START
        std::cerr << "Communication error!\n";
        std::exit(1);
        // LCOV_EXCL_STOP
      }
    }
#else // PASTIS_HAS_MPI
    Assert(sent_array_list.size() == 1);
    Assert(recv_array_list.size() == 1);

    value_copy(sent_array_list[0], recv_array_list[0]);
#endif // PASTIS_HAS_MPI
  }

  /// Same as _exchange(), for trivial non-arithmetic data: every array is
  /// viewed as an array of CastDataType before being exchanged.
  template <typename DataType,
            typename CastDataType>
  void _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list,
                              std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
  {
    std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list;
    sent_cast_array_list.reserve(sent_array_list.size());
    for (size_t i=0; i<sent_array_list.size(); ++i) {
      sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array_list[i]));
    }

    using MutableDataType = std::remove_const_t<DataType>;
    std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list;
    recv_cast_array_list.reserve(recv_array_list.size());
    // Iterate over the receive list itself: it is the one being indexed.
    for (size_t i=0; i<recv_array_list.size(); ++i) {
      recv_cast_array_list.emplace_back(recv_array_list[i]);
    }

    _exchange(sent_cast_array_list, recv_cast_array_list);
  }

 public:
  static void create(int& argc, char* argv[]);
  static void destroy();

  /// Returns the singleton instance; it must have been created first.
  PASTIS_INLINE
  static Messenger& getInstance()
  {
    Assert(m_instance != nullptr); // LCOV_EXCL_LINE
    return *m_instance;
  }

  /// Rank of the calling process.
  PASTIS_INLINE
  const size_t& rank() const
  {
    return m_rank;
  }

  /// Number of processes.
  PASTIS_INLINE
  const size_t& size() const
  {
    return m_size;
  }

  void barrier() const;

  /// Returns the minimum of data over all processes.
  template <typename DataType>
  DataType allReduceMin(const DataType& data) const
  {
#ifdef PASTIS_HAS_MPI
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);

    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<DataType>();

    DataType min_data = data;
    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD);

    return min_data;
#else // PASTIS_HAS_MPI
    return data;
#endif // PASTIS_HAS_MPI
  }

  /// Returns the maximum of data over all processes.
  template <typename DataType>
  DataType allReduceMax(const DataType& data) const
  {
#ifdef PASTIS_HAS_MPI
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);

    MPI_Datatype mpi_datatype
      = Messenger::helper::mpiType<DataType>();

    DataType max_data = data;
    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD);

    return max_data;
#else // PASTIS_HAS_MPI
    return data;
#endif // PASTIS_HAS_MPI
  }

  /// Gathers one value from every process; entry i of the result holds the
  /// value of rank i.
  template <typename DataType>
  PASTIS_INLINE
  Array<DataType>
  allGather(const DataType& data) const
  {
    static_assert(not std::is_const_v<DataType>);

    Array<DataType> gather_array(m_size);

    if constexpr(std::is_arithmetic_v<DataType>) {
      _allGather(data, gather_array);
    } else if constexpr(std::is_trivial_v<DataType>) {
      using CastType = helper::split_cast_t<DataType>;

      CastArray cast_value_array = cast_value_to<const CastType>::from(data);
      CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array);

      _allGather(cast_value_array, cast_gather_array);
    } else {
      static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
    }

    return gather_array;
  }

  /// Gathers array from every process; the result is made of m_size
  /// rank-ordered blocks of array.size() entries each.
  template <typename DataType>
  PASTIS_INLINE
  Array<std::remove_const_t<DataType>>
  allGather(const Array<DataType>& array) const
  {
    using MutableDataType = std::remove_const_t<DataType>;
    Array<MutableDataType> gather_array(m_size*array.size());

    if constexpr(std::is_arithmetic_v<DataType>) {
      _allGather(array, gather_array);
    } else if constexpr(std::is_trivial_v<DataType>) {
      using CastType = helper::split_cast_t<MutableDataType>;

      // The source is only read: view it through a const cast type, as
      // allToAll() and the scalar allGather() do for their sent data.
      CastArray cast_array = cast_array_to<const CastType>::from(array);
      CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array);

      _allGather(cast_array, cast_gather_array);
    } else {
      static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
    }

    return gather_array;
  }

  /// All-to-all personalized exchange: block i of sent_array (whose size is
  /// a multiple of m_size and identical on all processes) goes to rank i.
  template <typename SendDataType>
  PASTIS_INLINE
  Array<std::remove_const_t<SendDataType>>
  allToAll(const Array<SendDataType>& sent_array) const
  {
#ifndef NDEBUG
    const size_t min_size = allReduceMin(sent_array.size());
    const size_t max_size = allReduceMax(sent_array.size());
    Assert(max_size == min_size); // LCOV_EXCL_LINE
#endif // NDEBUG
    Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE

    using DataType = std::remove_const_t<SendDataType>;
    Array<DataType> recv_array(sent_array.size());

    if constexpr(std::is_arithmetic_v<DataType>) {
      _allToAll(sent_array, recv_array);
    } else if constexpr(std::is_trivial_v<DataType>) {
      using CastType = helper::split_cast_t<DataType>;

      auto send_cast_array = cast_array_to<const CastType>::from(sent_array);
      auto recv_cast_array = cast_array_to<CastType>::from(recv_array);
      _allToAll(send_cast_array, recv_cast_array);
    } else {
      static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
    }
    return recv_array;
  }

  /// Broadcasts data from root_rank to every process.
  template <typename DataType>
  PASTIS_INLINE
  void broadcast(DataType& data, const size_t& root_rank) const
  {
    static_assert(not std::is_const_v<DataType>,
                  "cannot broadcast const data");
    if constexpr(std::is_arithmetic_v<DataType>) {
      _broadcast_value(data, root_rank);
    } else if constexpr(std::is_trivial_v<DataType>) {
      using CastType = helper::split_cast_t<DataType>;
      if constexpr(sizeof(CastType) == sizeof(DataType)) {
        CastType& cast_data = reinterpret_cast<CastType&>(data);
        _broadcast_value(cast_data, root_rank);
      } else {
        CastArray cast_array = cast_value_to<CastType>::from(data);
        _broadcast_array(cast_array, root_rank);
      }
    } else {
      static_assert(std::is_trivial_v<DataType>,
                    "unexpected non trivial type of data");
    }
  }

  /// Broadcasts array (size and content) from root_rank to every process;
  /// non-root processes get a freshly allocated array of the root's size.
  template <typename DataType>
  PASTIS_INLINE
  void broadcast(Array<DataType>& array,
                 const size_t& root_rank) const
  {
    static_assert(not std::is_const_v<DataType>,
                  "cannot broadcast array of const");
    if constexpr(std::is_arithmetic_v<DataType>) {
      size_t size = array.size();
      _broadcast_value(size, root_rank);
      if (m_rank != root_rank) {
        array = Array<DataType>(size); // LCOV_EXCL_LINE
      }
      _broadcast_array(array, root_rank);
    } else if constexpr(std::is_trivial_v<DataType>) {
      size_t size = array.size();
      _broadcast_value(size, root_rank);
      if (m_rank != root_rank) {
        array = Array<DataType>(size); // LCOV_EXCL_LINE
      }

      using CastType = helper::split_cast_t<DataType>;
      auto cast_array = cast_array_to<CastType>::from(array);

      _broadcast_array(cast_array, root_rank);
    } else{
      static_assert(std::is_trivial_v<DataType>,
                    "unexpected non trivial type of data");
    }
  }

  /// Point-to-point exchange: send_array_list[i] is sent to rank i and
  /// recv_array_list[i] receives the data of rank i. Both lists hold m_size
  /// arrays and receive arrays must already have the matching sizes (checked
  /// in debug builds).
  template <typename SendDataType,
            typename RecvDataType>
  PASTIS_INLINE
  void exchange(const std::vector<Array<SendDataType>>& send_array_list,
                std::vector<Array<RecvDataType>>& recv_array_list) const
  {
    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>,
                  "send and receive data type do not match");
    static_assert(not std::is_const_v<RecvDataType>,
                  "receive data type cannot be const");
    using DataType = std::remove_const_t<SendDataType>;

    Assert(send_array_list.size() == m_size); // LCOV_EXCL_LINE
    Assert(recv_array_list.size() == m_size); // LCOV_EXCL_LINE
#ifndef NDEBUG
    Array<size_t> send_size(m_size);
    for (size_t i=0; i<m_size; ++i) {
      send_size[i] = send_array_list[i].size();
    }
    Array<size_t> recv_size = allToAll(send_size);
    bool correct_sizes = true;
    for (size_t i=0; i<m_size; ++i) {
      correct_sizes &= (recv_size[i] == recv_array_list[i].size());
    }
    Assert(correct_sizes); // LCOV_EXCL_LINE
#endif // NDEBUG

    if constexpr(std::is_arithmetic_v<DataType>) {
      _exchange(send_array_list, recv_array_list);
    } else if constexpr(std::is_trivial_v<DataType>) {
      using CastType = helper::split_cast_t<DataType>;
      _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list);
    } else {
      static_assert(std::is_trivial_v<RecvDataType>,
                    "unexpected non trivial type of data");
    }
  }

  // The singleton is neither copyable nor copy-assignable.
  Messenger(const Messenger&) = delete;
  Messenger& operator=(const Messenger&) = delete;
  ~Messenger();
};
/// Read-only access to the Messenger singleton.
PASTIS_INLINE
const Messenger& messenger()
{
  const Messenger& instance = Messenger::getInstance();
  return instance;
}
/// Rank of the calling process, as reported by the Messenger singleton.
PASTIS_INLINE
const size_t& rank()
{
  const Messenger& instance = messenger();
  return instance.rank();
}
/// Number of processes, as reported by the Messenger singleton.
PASTIS_INLINE
const size_t& size()
{
  const Messenger& instance = messenger();
  return instance.size();
}
/// Forwards to Messenger::barrier() on the singleton.
PASTIS_INLINE
void barrier()
{
  const Messenger& instance = messenger();
  instance.barrier();
}
/// Maximum of data over all processes (see Messenger::allReduceMax).
template <typename DataType>
PASTIS_INLINE
DataType allReduceMax(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceMax(data);
}
/// Minimum of data over all processes (see Messenger::allReduceMin).
template <typename DataType>
PASTIS_INLINE
DataType allReduceMin(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceMin(data);
}
/// Gathers one value per process (see Messenger::allGather).
template <typename DataType>
PASTIS_INLINE
Array<DataType>
allGather(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allGather(data);
}
/// Gathers one array per process into rank-ordered blocks
/// (see Messenger::allGather).
template <typename DataType>
PASTIS_INLINE
Array<std::remove_const_t<DataType>>
allGather(const Array<DataType>& array)
{
  const Messenger& instance = messenger();
  return instance.allGather(array);
}
/// All-to-all personalized exchange (see Messenger::allToAll).
template <typename DataType>
PASTIS_INLINE
Array<std::remove_const_t<DataType>>
allToAll(const Array<DataType>& array)
{
  const Messenger& instance = messenger();
  return instance.allToAll(array);
}
/// Broadcasts data from root_rank (see Messenger::broadcast).
template <typename DataType>
PASTIS_INLINE
void broadcast(DataType& data, const size_t& root_rank)
{
  const Messenger& instance = messenger();
  instance.broadcast(data, root_rank);
}
/// Broadcasts an array (size and content) from root_rank
/// (see Messenger::broadcast).
template <typename DataType>
PASTIS_INLINE
void broadcast(Array<DataType>& array, const size_t& root_rank)
{
  const Messenger& instance = messenger();
  instance.broadcast(array, root_rank);
}
/// Point-to-point exchange through the singleton (see Messenger::exchange):
/// sent_array_list[i] goes to rank i, recv_array_list[i] comes from rank i.
/// The type checks are repeated here so that mismatches are reported at the
/// caller's level.
template <typename SendDataType,
          typename RecvDataType>
PASTIS_INLINE
void exchange(const std::vector<Array<SendDataType>>& sent_array_list,
              std::vector<Array<RecvDataType>>& recv_array_list)
{
  static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>,
                "send and receive data type do not match");
  static_assert(not std::is_const_v<RecvDataType>,
                "receive data type cannot be const");

  const Messenger& instance = messenger();
  instance.exchange(sent_array_list, recv_array_list);
}
#ifdef PASTIS_HAS_MPI
// Explicit specializations of Messenger::helper::mpiType, one per
// fundamental arithmetic type.
//
// They are written in terms of the fundamental types rather than the
// <cstdint> aliases. The aliases int8_t ... uint64_t name some of these
// types, so they stay covered. The fundamental types are also always
// distinct from one another. Specializing both int64_t and long long (or
// uint64_t and unsigned long long) would redefine the same specialization
// on targets where int64_t is long long. Listing long / unsigned long
// explicitly also covers typedefs such as size_t on every data model.
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<char>() {return MPI_CHAR; }

template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<signed char>() {return MPI_SIGNED_CHAR; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<short>() {return MPI_SHORT; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<int>() {return MPI_INT; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<long>() {return MPI_LONG; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<long long>() {return MPI_LONG_LONG_INT; }

template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned char>() {return MPI_UNSIGNED_CHAR; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned short>() {return MPI_UNSIGNED_SHORT; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned int>() {return MPI_UNSIGNED; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned long>() {return MPI_UNSIGNED_LONG; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned long long>() {return MPI_UNSIGNED_LONG_LONG; }

template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<float>() {return MPI_FLOAT; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<double>() {return MPI_DOUBLE; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<long double>() {return MPI_LONG_DOUBLE; }

template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<wchar_t>() {return MPI_WCHAR; }
template<> PASTIS_INLINE MPI_Datatype
Messenger::helper::mpiType<bool>() {return MPI_CXX_BOOL; }
#endif // PASTIS_HAS_MPI
} // namespace parallel
#endif // MESSENGER_HPP