// Messenger.hpp
// Authored by Stéphane Del Pino
#ifndef MESSENGER_HPP
#define MESSENGER_HPP
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>
#include <utils/Array.hpp>
#include <utils/CastArray.hpp>
#include <type_traits>
#include <vector>
#include <utils/pugs_config.hpp>
#ifdef PUGS_HAS_MPI
#include <mpi.h>
#endif // PUGS_HAS_MPI
#include <utils/Exceptions.hpp>
#include <utils/PugsTraits.hpp>
namespace parallel
{
/// Singleton facade over the parallel communication layer.
///
/// When pugs is built with MPI (PUGS_HAS_MPI) every collective below is
/// forwarded to the matching MPI call on m_pugs_comm_world. Without MPI the
/// default member values describe a single process (rank 0 of size 1) and the
/// collectives degenerate to copies or to the identity.
///
/// Arithmetic values are sent as-is. Other `is_trivially_castable` types are
/// reinterpreted (through CastArray) as arrays of the integer word type
/// helper::split_cast_t<T> and sent as such.
class Messenger
{
 public:
  /// Compile-time helpers: MPI datatype lookup and word-type selection for
  /// trivially castable values.
  struct helper
  {
#ifdef PUGS_HAS_MPI
    /// MPI datatype matching DataType; const qualification is ignored.
    ///
    /// Only types with an explicit specialization (declared after the class)
    /// are supported; any other type is rejected at compile time.
    template <typename DataType>
    static PUGS_INLINE MPI_Datatype
    mpiType()
    {
      if constexpr (std::is_const_v<DataType>) {
        return mpiType<std::remove_const_t<DataType>>();
      } else {
        // First assertion rejects non-arithmetic types; the second one is only
        // reached by arithmetic types lacking an explicit specialization.
        static_assert(std::is_arithmetic_v<DataType>, "MPI_Datatype are only defined for arithmetic types!");
        static_assert(is_false_v<DataType>, "Unexpected arithmetic type! Should not occur!");
        return MPI_Datatype();
      }
    }
#endif   // PUGS_HAS_MPI

   private:
    // split_cast<T>::type is the widest of int64_t, int32_t, int16_t and
    // int8_t whose size divides sizeof(T). The enable_if conditions of the
    // partial specializations are mutually exclusive, and since
    // sizeof(int8_t) == 1 divides every size, exactly one of them matches.
    template <typename T, typename Allowed = void>
    struct split_cast
    {
    };

    template <typename T>
    struct split_cast<T, std::enable_if_t<not(sizeof(T) % sizeof(int64_t))>>
    {
      using type = int64_t;
      static_assert(not(sizeof(T) % sizeof(int64_t)));
    };

    template <typename T>
    struct split_cast<T, std::enable_if_t<not(sizeof(T) % sizeof(int32_t)) and (sizeof(T) % sizeof(int64_t))>>
    {
      using type = int32_t;
      static_assert(not(sizeof(T) % sizeof(int32_t)));
    };

    template <typename T>
    struct split_cast<T,
                      std::enable_if_t<not(sizeof(T) % sizeof(int16_t)) and (sizeof(T) % sizeof(int32_t)) and
                                       (sizeof(T) % sizeof(int64_t))>>
    {
      using type = int16_t;
      static_assert(not(sizeof(T) % sizeof(int16_t)));
    };

    template <typename T>
    struct split_cast<T,
                      std::enable_if_t<not(sizeof(T) % sizeof(int8_t)) and (sizeof(T) % sizeof(int16_t)) and
                                       (sizeof(T) % sizeof(int32_t)) and (sizeof(T) % sizeof(int64_t))>>
    {
      using type = int8_t;
      static_assert(not(sizeof(T) % sizeof(int8_t)));
    };

   public:
    /// Integer word type used to view a T as an array of words.
    template <typename T>
    using split_cast_t = typename split_cast<T>::type;
  };

  // NOTE(review): no access specifier follows `helper`, so everything down to
  // the next `public:` (constructor, data members, `_` helpers) is public. The
  // later `public:` suggests a `private:` was intended here; confirm that no
  // external code relies on these members before restricting access.

  // Singleton instance; presumably set by create() and reset by destroy(),
  // both defined out of line.
  static Messenger* m_instance;
  Messenger(int& argc, char* argv[], bool parallel_output);

#ifdef PUGS_HAS_MPI
  // Communicator used by every collective of this class
  MPI_Comm m_pugs_comm_world = MPI_COMM_WORLD;
#endif   // PUGS_HAS_MPI

  // Rank and size used by the collectives (presumably within m_pugs_comm_world)
  size_t m_rank{0};
  size_t m_size{1};

  // Rank and size in the whole MPI_COMM_WORLD of the process
  size_t m_global_rank{0};
  size_t m_global_size{1};

#ifdef PUGS_HAS_MPI
  // Exclusive prefix sum of `sizes_array`: displacement of each process'
  // block in a variable-size gather buffer (MPI_Gatherv / MPI_Allgatherv).
  static Array<int>
  _startPositions(Array<int> sizes_array)
  {
    Array<int> start_positions{sizes_array.size()};
    if (start_positions.size() > 0) {
      start_positions[0] = 0;
      for (size_t i = 1; i < start_positions.size(); ++i) {
        start_positions[i] = start_positions[i - 1] + sizes_array[i - 1];
      }
    }
    return start_positions;
  }
#endif   // PUGS_HAS_MPI

  // Gathers one arithmetic value per process into `gather` on process
  // `rank`. `gather` must hold m_size entries on `rank` and none elsewhere.
  template <typename DataType>
  void
  _gather(const DataType& data, Array<DataType> gather, size_t rank) const
  {
    static_assert(std::is_arithmetic_v<DataType>);
    Assert(gather.size() == m_size * (rank == m_rank));   // LCOV_EXCL_LINE

#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    auto gather_address = (gather.size() > 0) ? &(gather[0]) : nullptr;

    MPI_Gather(&data, 1, mpi_datatype, gather_address, 1, mpi_datatype, rank, m_pugs_comm_world);
#else    // PUGS_HAS_MPI
    // Only the receiving process owns a non-empty buffer
    if (gather.size() > 0) {
      gather[0] = data;
    }
#endif   // PUGS_HAS_MPI
  }

  // Gathers equally-sized arrays from every process on `rank`; `gather_array`
  // must hold data_array.size() * m_size entries on `rank` and none elsewhere.
  template <template <typename... SendT> typename SendArrayType,
            template <typename... RecvT>
            typename RecvArrayType,
            typename... SendT,
            typename... RecvT>
  void
  _gather(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array, size_t rank) const
  {
    Assert(gather_array.size() == data_array.size() * m_size * (rank == m_rank));   // LCOV_EXCL_LINE

    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>();

    auto data_address   = (data_array.size() > 0) ? &(data_array[0]) : nullptr;
    auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;

    MPI_Gather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype, rank,
               m_pugs_comm_world);
#else    // PUGS_HAS_MPI
    copy_to(data_array, gather_array);
#endif   // PUGS_HAS_MPI
  }

  // Gathers arrays of possibly different sizes on `rank`. On `rank`,
  // `sizes_array` holds the element count sent by each process and
  // `gather_array` must hold their sum; elsewhere `gather_array` is empty.
  template <template <typename... SendT> typename SendArrayType,
            template <typename... RecvT>
            typename RecvArrayType,
            typename... SendT,
            typename... RecvT>
  void
  _gatherVariable(const SendArrayType<SendT...>& data_array,
                  RecvArrayType<RecvT...> gather_array,
                  Array<int> sizes_array,
                  size_t rank) const
  {
    Assert(gather_array.size() - sum(sizes_array) * (rank == m_rank) == 0);   // LCOV_EXCL_LINE

    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>();

    Array<int> start_positions = _startPositions(sizes_array);

    auto data_address      = (data_array.size() > 0) ? &(data_array[0]) : nullptr;
    auto sizes_address     = (sizes_array.size() > 0) ? &(sizes_array[0]) : nullptr;
    auto positions_address = (start_positions.size() > 0) ? &(start_positions[0]) : nullptr;
    auto gather_address    = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;

    MPI_Gatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
                mpi_datatype, rank, m_pugs_comm_world);
#else    // PUGS_HAS_MPI
    copy_to(data_array, gather_array);
#endif   // PUGS_HAS_MPI
  }

  // Gathers one arithmetic value per process on every process; `gather`
  // must hold m_size entries.
  template <typename DataType>
  void
  _allGather(const DataType& data, Array<DataType> gather) const
  {
    static_assert(std::is_arithmetic_v<DataType>);
    Assert(gather.size() == m_size);   // LCOV_EXCL_LINE

#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    MPI_Allgather(&data, 1, mpi_datatype, &(gather[0]), 1, mpi_datatype, m_pugs_comm_world);
#else    // PUGS_HAS_MPI
    gather[0] = data;
#endif   // PUGS_HAS_MPI
  }

  // Gathers equally-sized arrays on every process; `gather_array` must hold
  // data_array.size() * m_size entries.
  template <template <typename... SendT> typename SendArrayType,
            template <typename... RecvT>
            typename RecvArrayType,
            typename... SendT,
            typename... RecvT>
  void
  _allGather(const SendArrayType<SendT...>& data_array, RecvArrayType<RecvT...> gather_array) const
  {
    Assert(gather_array.size() == data_array.size() * m_size);   // LCOV_EXCL_LINE

    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>();

    auto data_address   = (data_array.size() > 0) ? &(data_array[0]) : nullptr;
    auto gather_address = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;

    MPI_Allgather(data_address, data_array.size(), mpi_datatype, gather_address, data_array.size(), mpi_datatype,
                  m_pugs_comm_world);
#else    // PUGS_HAS_MPI
    copy_to(data_array, gather_array);
#endif   // PUGS_HAS_MPI
  }

  // Gathers arrays of possibly different sizes on every process;
  // `sizes_array` holds each process' element count and `gather_array`
  // must hold their sum.
  template <template <typename... SendT> typename SendArrayType,
            template <typename... RecvT>
            typename RecvArrayType,
            typename... SendT,
            typename... RecvT>
  void
  _allGatherVariable(const SendArrayType<SendT...>& data_array,
                     RecvArrayType<RecvT...> gather_array,
                     Array<int> sizes_array) const
  {
    Assert(gather_array.size() == static_cast<size_t>(sum(sizes_array)));   // LCOV_EXCL_LINE

    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<RecvDataType>();

    Array<int> start_positions = _startPositions(sizes_array);

    auto data_address      = (data_array.size() > 0) ? &(data_array[0]) : nullptr;
    auto sizes_address     = (sizes_array.size() > 0) ? &(sizes_array[0]) : nullptr;
    auto positions_address = (start_positions.size() > 0) ? &(start_positions[0]) : nullptr;
    auto gather_address    = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;

    MPI_Allgatherv(data_address, data_array.size(), mpi_datatype, gather_address, sizes_address, positions_address,
                   mpi_datatype, m_pugs_comm_world);
#else    // PUGS_HAS_MPI
    copy_to(data_array, gather_array);
#endif   // PUGS_HAS_MPI
  }

  // Broadcasts one arithmetic value from `root_rank`; no-op without MPI.
  template <typename DataType>
  void
  _broadcast_value([[maybe_unused]] DataType& data, [[maybe_unused]] size_t root_rank) const
  {
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);

#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    MPI_Bcast(&data, 1, mpi_datatype, root_rank, m_pugs_comm_world);
#endif   // PUGS_HAS_MPI
  }

  // Broadcasts the contents of an already-sized array from `root_rank`; the
  // size is not exchanged here (broadcast(Array) sends it first). No-op
  // without MPI.
  template <typename ArrayType>
  void
  _broadcast_array([[maybe_unused]] ArrayType& array, [[maybe_unused]] size_t root_rank) const
  {
    using DataType = typename ArrayType::data_type;
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);

#ifdef PUGS_HAS_MPI
    auto array_address = (array.size() > 0) ? &(array[0]) : nullptr;

    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, m_pugs_comm_world);
#endif   // PUGS_HAS_MPI
  }

  // All-to-all with equal chunks: `sent_array` is split into m_size chunks of
  // sent_array.size() / m_size entries, chunk i going to process i;
  // `recv_array` must have the same size as `sent_array`.
  template <template <typename... SendT> typename SendArrayType,
            template <typename... RecvT>
            typename RecvArrayType,
            typename... SendT,
            typename... RecvT>
  void
  _allToAll(const SendArrayType<SendT...>& sent_array, RecvArrayType<RecvT...>& recv_array) const
  {
    // Checked in both builds, consistently with the other collectives
    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

    Assert((sent_array.size() % m_size) == 0);         // LCOV_EXCL_LINE
    Assert(sent_array.size() == recv_array.size());    // LCOV_EXCL_LINE

#ifdef PUGS_HAS_MPI
    const size_t count = sent_array.size() / m_size;

    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>();

    auto sent_address = (sent_array.size() > 0) ? &(sent_array[0]) : nullptr;
    auto recv_address = (recv_array.size() > 0) ? &(recv_array[0]) : nullptr;

    MPI_Alltoall(sent_address, count, mpi_datatype, recv_address, count, mpi_datatype, m_pugs_comm_world);
#else    // PUGS_HAS_MPI
    copy_to(sent_array, recv_array);
#endif   // PUGS_HAS_MPI
  }

  // Point-to-point exchange: sent_array_list[i] is sent to process i and
  // recv_array_list[i] is filled from process i. Empty arrays are skipped.
  // Non-blocking sends/receives (tag 0) are completed by one MPI_Waitall,
  // whose failure throws NormalError. Without MPI, each list must contain a
  // single array, which is copied.
  template <template <typename... SendT> typename SendArrayType,
            template <typename... RecvT>
            typename RecvArrayType,
            typename... SendT,
            typename... RecvT>
  void
  _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list,
            std::vector<RecvArrayType<RecvT...>>& recv_array_list) const
  {
    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PUGS_HAS_MPI
    std::vector<MPI_Request> request_list;
    request_list.reserve(sent_array_list.size() + recv_array_list.size());

    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>();

    for (size_t i_send = 0; i_send < sent_array_list.size(); ++i_send) {
      const auto& sent_array = sent_array_list[i_send];
      if (sent_array.size() > 0) {
        MPI_Request request;
        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, m_pugs_comm_world, &request);
        request_list.push_back(request);
      }
    }

    for (size_t i_recv = 0; i_recv < recv_array_list.size(); ++i_recv) {
      auto& recv_array = recv_array_list[i_recv];
      if (recv_array.size() > 0) {
        MPI_Request request;
        MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, m_pugs_comm_world, &request);
        request_list.push_back(request);
      }
    }

    if (request_list.size() > 0) {
      // Individual statuses are never inspected
      if (MPI_SUCCESS != MPI_Waitall(request_list.size(), request_list.data(), MPI_STATUSES_IGNORE)) {
        // LCOV_EXCL_START
        throw NormalError("Communication error");
        // LCOV_EXCL_STOP
      }
    }

#else    // PUGS_HAS_MPI
    Assert(sent_array_list.size() == 1);
    Assert(recv_array_list.size() == 1);

    copy_to(sent_array_list[0], recv_array_list[0]);
#endif   // PUGS_HAS_MPI
  }

  // Exchange of trivially castable data: each array is viewed as an array of
  // CastDataType words (through CastArray) and sent with _exchange.
  template <typename DataType, typename CastDataType>
  void
  _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list,
                         std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
  {
    std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list;
    sent_cast_array_list.reserve(sent_array_list.size());
    for (const auto& sent_array : sent_array_list) {
      sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array));
    }

    using MutableDataType = std::remove_const_t<DataType>;
    std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list;
    recv_cast_array_list.reserve(recv_array_list.size());
    // Iterate over the receive list itself (not over the send list)
    for (auto& recv_array : recv_array_list) {
      recv_cast_array_list.emplace_back(recv_array);
    }

    _exchange(sent_cast_array_list, recv_cast_array_list);
  }

 public:
  // Creation/destruction of the singleton (defined out of line)
  static void create(int& argc, char* argv[], bool parallel_output = false);
  static void destroy();

  /// Returns the singleton; asserts that it has been created.
  PUGS_INLINE
  static Messenger&
  getInstance()
  {
    Assert(m_instance != nullptr);   // LCOV_EXCL_LINE
    return *m_instance;
  }

  /// Rank of the calling process (the one used by the collectives)
  PUGS_INLINE
  const size_t&
  rank() const
  {
    return m_rank;
  }

  /// Number of processes taking part in the collectives
  PUGS_INLINE
  const size_t&
  size() const
  {
    return m_size;
  }

  // The global rank is the rank in the whole MPI_COMM_WORLD, one
  // generally needs rank() for classical parallelism
  PUGS_INLINE
  const size_t&
  globalRank() const
  {
    return m_global_rank;
  }

  // The global size is the size in the whole MPI_COMM_WORLD, one
  // generally needs size() for classical parallelism
  PUGS_INLINE
  const size_t&
  globalSize() const
  {
    return m_global_size;
  }

#ifdef PUGS_HAS_MPI
  /// Communicator used by all collectives of this class
  PUGS_INLINE
  const MPI_Comm&
  comm() const
  {
    return m_pugs_comm_world;
  }
#endif   // PUGS_HAS_MPI

  // Synchronization point (defined out of line)
  void barrier() const;

  /// Minimum of an arithmetic (non-bool) value over all processes.
  template <typename DataType>
  DataType
  allReduceMin(const DataType& data) const
  {
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);
    static_assert(not std::is_same_v<DataType, bool>);
#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    DataType min_data = data;
    MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, m_pugs_comm_world);

    return min_data;
#else    // PUGS_HAS_MPI
    return data;
#endif   // PUGS_HAS_MPI
  }

  /// Logical AND of a bool over all processes.
  template <typename DataType>
  DataType
  allReduceAnd(const DataType& data) const
  {
    static_assert(std::is_same_v<DataType, bool>);
#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    DataType and_data = data;
    MPI_Allreduce(&data, &and_data, 1, mpi_datatype, MPI_LAND, m_pugs_comm_world);

    return and_data;
#else    // PUGS_HAS_MPI
    return data;
#endif   // PUGS_HAS_MPI
  }

  /// Logical OR of a bool over all processes.
  template <typename DataType>
  DataType
  allReduceOr(const DataType& data) const
  {
    static_assert(std::is_same_v<DataType, bool>);
#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    DataType or_data = data;
    MPI_Allreduce(&data, &or_data, 1, mpi_datatype, MPI_LOR, m_pugs_comm_world);

    return or_data;
#else    // PUGS_HAS_MPI
    return data;
#endif   // PUGS_HAS_MPI
  }

  /// Maximum of an arithmetic (non-bool) value over all processes.
  template <typename DataType>
  DataType
  allReduceMax(const DataType& data) const
  {
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);
    static_assert(not std::is_same_v<DataType, bool>);
#ifdef PUGS_HAS_MPI
    MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

    DataType max_data = data;
    MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, m_pugs_comm_world);

    return max_data;
#else    // PUGS_HAS_MPI
    return data;
#endif   // PUGS_HAS_MPI
  }

  /// Sum over all processes. Arithmetic (non-bool) values are summed
  /// directly; trivially castable values are summed component-wise as
  /// sizeof(DataType) / sizeof(DataType::data_type) values of data_type.
  template <typename DataType>
  DataType
  allReduceSum(const DataType& data) const
  {
    static_assert(not std::is_const_v<DataType>);
    static_assert(not std::is_same_v<DataType, bool>);
#ifdef PUGS_HAS_MPI
    if constexpr (std::is_arithmetic_v<DataType>) {
      MPI_Datatype mpi_datatype = Messenger::helper::mpiType<DataType>();

      DataType data_sum = data;
      MPI_Allreduce(&data, &data_sum, 1, mpi_datatype, MPI_SUM, m_pugs_comm_world);

      return data_sum;
    } else if constexpr (is_trivially_castable<DataType>) {
      using InnerDataType = typename DataType::data_type;

      MPI_Datatype mpi_datatype = Messenger::helper::mpiType<InnerDataType>();
      const int size            = sizeof(DataType) / sizeof(InnerDataType);
      DataType data_sum         = data;
      MPI_Allreduce(&data, &data_sum, size, mpi_datatype, MPI_SUM, m_pugs_comm_world);

      return data_sum;
    } else {
      // NOTE(review): unlike the other entry points, unsupported types fail
      // at run time (and only in MPI builds) instead of at compile time.
      throw UnexpectedError("invalid data type for reduce sum");
    }
#else    // PUGS_HAS_MPI
    return data;
#endif   // PUGS_HAS_MPI
  }

  /// Gathers one value per process on process `rank`. Returns an array of
  /// size() values on `rank` and an empty array elsewhere.
  template <typename DataType>
  PUGS_INLINE Array<DataType>
  gather(const DataType& data, size_t rank) const
  {
    static_assert(not std::is_const_v<DataType>);

    Array<DataType> gather_array((m_rank == rank) ? m_size : 0);

    if constexpr (std::is_arithmetic_v<DataType>) {
      _gather(data, gather_array, rank);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType = helper::split_cast_t<DataType>;

      CastArray cast_value_array  = cast_value_to<const CastType>::from(data);
      CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array);

      _gather(cast_value_array, cast_gather_array, rank);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
    return gather_array;
  }

  /// Gathers equally-sized arrays on process `rank`. Returns the
  /// concatenation (size() * array.size() values) on `rank`, empty elsewhere.
  template <typename DataType>
  PUGS_INLINE Array<std::remove_const_t<DataType>>
  gather(const Array<DataType>& array, size_t rank) const
  {
    using MutableDataType = std::remove_const_t<DataType>;
    Array<MutableDataType> gather_array((m_rank == rank) ? (m_size * array.size()) : 0);

    if constexpr (std::is_arithmetic_v<DataType>) {
      _gather(array, gather_array, rank);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType        = helper::split_cast_t<DataType>;
      using MutableCastType = helper::split_cast_t<MutableDataType>;

      CastArray cast_array        = cast_array_to<CastType>::from(array);
      CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array);

      _gather(cast_array, cast_gather_array, rank);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
    return gather_array;
  }

  /// Gathers arrays of possibly different sizes on process `rank`. Sizes are
  /// gathered first; returns the concatenation on `rank`, empty elsewhere.
  template <typename DataType>
  PUGS_INLINE Array<std::remove_const_t<DataType>>
  gatherVariable(const Array<DataType>& array, size_t rank) const
  {
    int send_size          = array.size();
    Array<int> sizes_array = gather(send_size, rank);

    using MutableDataType = std::remove_const_t<DataType>;
    Array<MutableDataType> gather_array(sum(sizes_array));

    if constexpr (std::is_arithmetic_v<DataType>) {
      _gatherVariable(array, gather_array, sizes_array, rank);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType        = helper::split_cast_t<DataType>;
      using MutableCastType = helper::split_cast_t<MutableDataType>;

      // Sizes are counted in CastType words once the arrays are cast
      int size_ratio = sizeof(DataType) / sizeof(CastType);
      for (size_t i = 0; i < sizes_array.size(); ++i) {
        sizes_array[i] *= size_ratio;
      }

      CastArray cast_array        = cast_array_to<CastType>::from(array);
      CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array);

      _gatherVariable(cast_array, cast_gather_array, sizes_array, rank);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
    return gather_array;
  }

  /// Gathers one value per process on every process (size() values).
  template <typename DataType>
  PUGS_INLINE Array<DataType>
  allGather(const DataType& data) const
  {
    static_assert(not std::is_const_v<DataType>);

    Array<DataType> gather_array(m_size);

    if constexpr (std::is_arithmetic_v<DataType>) {
      _allGather(data, gather_array);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType = helper::split_cast_t<DataType>;

      CastArray cast_value_array  = cast_value_to<const CastType>::from(data);
      CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array);

      _allGather(cast_value_array, cast_gather_array);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
    return gather_array;
  }

  /// Gathers equally-sized arrays on every process (size() * array.size()
  /// values).
  template <typename DataType>
  PUGS_INLINE Array<std::remove_const_t<DataType>>
  allGather(const Array<DataType>& array) const
  {
    using MutableDataType = std::remove_const_t<DataType>;
    Array<MutableDataType> gather_array(m_size * array.size());

    if constexpr (std::is_arithmetic_v<DataType>) {
      _allGather(array, gather_array);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType        = helper::split_cast_t<DataType>;
      using MutableCastType = helper::split_cast_t<MutableDataType>;

      CastArray cast_array        = cast_array_to<CastType>::from(array);
      CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array);

      _allGather(cast_array, cast_gather_array);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
    return gather_array;
  }

  /// Gathers arrays of possibly different sizes on every process. Sizes are
  /// all-gathered first; returns their concatenation.
  template <typename DataType>
  PUGS_INLINE Array<std::remove_const_t<DataType>>
  allGatherVariable(const Array<DataType>& array) const
  {
    int send_size          = array.size();
    Array<int> sizes_array = allGather(send_size);

    using MutableDataType = std::remove_const_t<DataType>;
    Array<MutableDataType> gather_array(sum(sizes_array));

    if constexpr (std::is_arithmetic_v<DataType>) {
      _allGatherVariable(array, gather_array, sizes_array);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType        = helper::split_cast_t<DataType>;
      using MutableCastType = helper::split_cast_t<MutableDataType>;

      // Sizes are counted in CastType words once the arrays are cast
      int size_ratio = sizeof(DataType) / sizeof(CastType);
      for (size_t i = 0; i < sizes_array.size(); ++i) {
        sizes_array[i] *= size_ratio;
      }

      CastArray cast_array        = cast_array_to<CastType>::from(array);
      CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array);

      _allGatherVariable(cast_array, cast_gather_array, sizes_array);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
    return gather_array;
  }

  /// All-to-all with equal chunks: `sent_array` (same size on every process,
  /// checked in debug builds) is split into size() chunks, chunk i going to
  /// process i. Returns the received chunks, ordered by source rank.
  template <typename SendDataType>
  PUGS_INLINE Array<std::remove_const_t<SendDataType>>
  allToAll(const Array<SendDataType>& sent_array) const
  {
#ifndef NDEBUG
    const size_t min_size = allReduceMin(sent_array.size());
    const size_t max_size = allReduceMax(sent_array.size());
    Assert(max_size == min_size);   // LCOV_EXCL_LINE
#endif                              // NDEBUG
    Assert((sent_array.size() % m_size) == 0);   // LCOV_EXCL_LINE

    using DataType = std::remove_const_t<SendDataType>;
    Array<DataType> recv_array(sent_array.size());

    if constexpr (std::is_arithmetic_v<DataType>) {
      _allToAll(sent_array, recv_array);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType = helper::split_cast_t<DataType>;

      auto send_cast_array = cast_array_to<const CastType>::from(sent_array);
      auto recv_cast_array = cast_array_to<CastType>::from(recv_array);
      _allToAll(send_cast_array, recv_cast_array);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }

    return recv_array;
  }

  /// Broadcasts `data` from `root_rank` to every process. Supports arithmetic
  /// values, trivially castable values and std::string (resized on
  /// non-root processes).
  template <typename DataType>
  PUGS_INLINE void
  broadcast(DataType& data, size_t root_rank) const
  {
    static_assert(not std::is_const_v<DataType>, "cannot broadcast const data");
    if constexpr (std::is_arithmetic_v<DataType>) {
      _broadcast_value(data, root_rank);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType = helper::split_cast_t<DataType>;
      if constexpr (sizeof(CastType) == sizeof(DataType)) {
        // A single word: broadcast it as a value
        CastType& cast_data = reinterpret_cast<CastType&>(data);
        _broadcast_value(cast_data, root_rank);
      } else {
        CastArray cast_array = cast_value_to<CastType>::from(data);
        _broadcast_array(cast_array, root_rank);
      }
    } else if constexpr (std::is_same_v<std::string, DataType>) {
      Array s = convert_to_array(data);
      broadcast(s, root_rank);
      if (m_rank != root_rank) {
        data.resize(s.size());
        for (size_t i = 0; i < s.size(); ++i) {
          data[i] = s[i];
        }
      }
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
  }

  /// Broadcasts `array` from `root_rank`; the size is broadcast first and
  /// non-root processes reallocate `array` accordingly.
  template <typename DataType>
  PUGS_INLINE void
  broadcast(Array<DataType>& array, size_t root_rank) const
  {
    static_assert(not std::is_const_v<DataType>, "cannot broadcast array of const");
    if constexpr (std::is_arithmetic_v<DataType>) {
      size_t size = array.size();
      _broadcast_value(size, root_rank);
      if (m_rank != root_rank) {
        array = Array<DataType>(size);   // LCOV_EXCL_LINE
      }
      _broadcast_array(array, root_rank);
    } else if constexpr (is_trivially_castable<DataType>) {
      size_t size = array.size();
      _broadcast_value(size, root_rank);
      if (m_rank != root_rank) {
        array = Array<DataType>(size);   // LCOV_EXCL_LINE
      }

      using CastType  = helper::split_cast_t<DataType>;
      auto cast_array = cast_array_to<CastType>::from(array);
      _broadcast_array(cast_array, root_rank);
    } else {
      static_assert(is_false_v<DataType>, "unexpected type of data");
    }
  }

  /// Point-to-point exchange: send_array_list[i] goes to process i and
  /// recv_array_list[i] receives from process i. Both lists must have size()
  /// entries; matching message lengths are checked in debug builds.
  template <typename SendDataType, typename RecvDataType>
  PUGS_INLINE void
  exchange(const std::vector<Array<SendDataType>>& send_array_list,
           std::vector<Array<RecvDataType>>& recv_array_list) const
  {
    static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>,
                  "send and receive data type do not match");
    static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const");
    using DataType = std::remove_const_t<SendDataType>;

    Assert(send_array_list.size() == m_size);   // LCOV_EXCL_LINE
    Assert(recv_array_list.size() == m_size);   // LCOV_EXCL_LINE
#ifndef NDEBUG
    Array<size_t> send_size(m_size);
    for (size_t i = 0; i < m_size; ++i) {
      send_size[i] = send_array_list[i].size();
    }
    Array<size_t> recv_size = allToAll(send_size);
    bool correct_sizes      = true;
    for (size_t i = 0; i < m_size; ++i) {
      correct_sizes &= (recv_size[i] == recv_array_list[i].size());
    }
    Assert(correct_sizes, "incompatible send/recv messages length");   // LCOV_EXCL_LINE
#endif   // NDEBUG

    if constexpr (std::is_arithmetic_v<DataType>) {
      _exchange(send_array_list, recv_array_list);
    } else if constexpr (is_trivially_castable<DataType>) {
      using CastType = helper::split_cast_t<DataType>;
      _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list);
    } else {
      static_assert(is_false_v<RecvDataType>, "unexpected type of data");
    }
  }

  // Singleton: neither copyable nor copy-assignable
  Messenger(const Messenger&) = delete;
  Messenger& operator=(const Messenger&) = delete;
  ~Messenger();
};
/// Read-only access to the Messenger singleton.
PUGS_INLINE
const Messenger&
messenger()
{
  const Messenger& instance = Messenger::getInstance();
  return instance;
}
/// Rank of the calling process, as reported by the Messenger singleton.
PUGS_INLINE
const size_t&
rank()
{
  const Messenger& instance = messenger();
  return instance.rank();
}
/// Number of processes, as reported by the Messenger singleton.
PUGS_INLINE
const size_t&
size()
{
  const Messenger& instance = messenger();
  return instance.size();
}
/// Synchronization point; forwards to Messenger::barrier().
PUGS_INLINE
void
barrier()
{
  const Messenger& instance = messenger();
  instance.barrier();
}
/// Logical AND over all processes; forwards to Messenger::allReduceAnd.
template <typename DataType>
PUGS_INLINE DataType
allReduceAnd(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceAnd(data);
}
/// Logical OR over all processes; forwards to Messenger::allReduceOr.
template <typename DataType>
PUGS_INLINE DataType
allReduceOr(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceOr(data);
}
/// Maximum over all processes; forwards to Messenger::allReduceMax.
template <typename DataType>
PUGS_INLINE DataType
allReduceMax(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceMax(data);
}
/// Minimum over all processes; forwards to Messenger::allReduceMin.
template <typename DataType>
PUGS_INLINE DataType
allReduceMin(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceMin(data);
}
/// Sum over all processes; forwards to Messenger::allReduceSum.
template <typename DataType>
PUGS_INLINE DataType
allReduceSum(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceSum(data);
}
/// Gathers one value per process on process `rank`; forwards to
/// Messenger::gather.
template <typename DataType>
PUGS_INLINE Array<DataType>
gather(const DataType& data, size_t rank)
{
  const Messenger& instance = messenger();
  return instance.gather(data, rank);
}
/// Gathers equally-sized arrays on process `rank`; forwards to
/// Messenger::gather.
template <typename DataType>
PUGS_INLINE Array<std::remove_const_t<DataType>>
gather(const Array<DataType>& array, size_t rank)
{
  const Messenger& instance = messenger();
  return instance.gather(array, rank);
}
/// Gathers arrays of possibly different sizes on process `rank`; forwards to
/// Messenger::gatherVariable.
template <typename DataType>
PUGS_INLINE Array<std::remove_const_t<DataType>>
gatherVariable(const Array<DataType>& array, size_t rank)
{
  const Messenger& instance = messenger();
  return instance.gatherVariable(array, rank);
}
/// Gathers one value per process on every process; forwards to
/// Messenger::allGather.
template <typename DataType>
PUGS_INLINE Array<DataType>
allGather(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allGather(data);
}
/// Gathers equally-sized arrays on every process; forwards to
/// Messenger::allGather.
template <typename DataType>
PUGS_INLINE Array<std::remove_const_t<DataType>>
allGather(const Array<DataType>& array)
{
  const Messenger& instance = messenger();
  return instance.allGather(array);
}
/// Gathers arrays of possibly different sizes on every process; forwards to
/// Messenger::allGatherVariable.
template <typename DataType>
PUGS_INLINE Array<std::remove_const_t<DataType>>
allGatherVariable(const Array<DataType>& array)
{
  const Messenger& instance = messenger();
  return instance.allGatherVariable(array);
}
/// Equal-chunk all-to-all exchange; forwards to Messenger::allToAll.
template <typename DataType>
PUGS_INLINE Array<std::remove_const_t<DataType>>
allToAll(const Array<DataType>& array)
{
  const Messenger& instance = messenger();
  return instance.allToAll(array);
}
/// Broadcasts `data` from `root_rank`; forwards to Messenger::broadcast.
template <typename DataType>
PUGS_INLINE void
broadcast(DataType& data, size_t root_rank)
{
  const Messenger& instance = messenger();
  instance.broadcast(data, root_rank);
}
/// Broadcasts `array` (size and contents) from `root_rank`; forwards to
/// Messenger::broadcast.
template <typename DataType>
PUGS_INLINE void
broadcast(Array<DataType>& array, size_t root_rank)
{
  const Messenger& instance = messenger();
  instance.broadcast(array, root_rank);
}
/// Point-to-point exchange (list entry i goes to / comes from process i);
/// type compatibility is checked at compile time before forwarding to
/// Messenger::exchange.
template <typename SendDataType, typename RecvDataType>
PUGS_INLINE void
exchange(const std::vector<Array<SendDataType>>& sent_array_list, std::vector<Array<RecvDataType>>& recv_array_list)
{
  static_assert(not std::is_const_v<RecvDataType>, "receive data type cannot be const");
  static_assert(std::is_same_v<std::remove_const_t<SendDataType>, RecvDataType>,
                "send and receive data type do not match");

  const Messenger& instance = messenger();
  instance.exchange(sent_array_list, recv_array_list);
}
#ifdef PUGS_HAS_MPI
// Explicit specializations of Messenger::helper::mpiType.
//
// They are written for the *fundamental* C++ types rather than for the
// fixed-width aliases (int8_t ... uint64_t). Those aliases are typedefs whose
// underlying type is platform dependent: int64_t is `long` on LP64 Linux but
// `long long` on macOS and on 32-bit targets. Specializing both an alias and
// the type it names is a redefinition error there, and specializing only the
// aliases leaves some fundamental types (e.g. the one behind size_t) without
// an MPI datatype. Covering every fundamental integer type exactly once
// handles all fixed-width aliases on every platform.

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<char>()
{
  return MPI_CHAR;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<signed char>()
{
  return MPI_SIGNED_CHAR;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned char>()
{
  return MPI_UNSIGNED_CHAR;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<short>()
{
  return MPI_SHORT;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned short>()
{
  return MPI_UNSIGNED_SHORT;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<int>()
{
  return MPI_INT;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned int>()
{
  return MPI_UNSIGNED;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<long>()
{
  return MPI_LONG;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned long>()
{
  return MPI_UNSIGNED_LONG;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<signed long long int>()
{
  return MPI_LONG_LONG_INT;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<unsigned long long int>()
{
  return MPI_UNSIGNED_LONG_LONG;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<float>()
{
  return MPI_FLOAT;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<double>()
{
  return MPI_DOUBLE;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<long double>()
{
  return MPI_LONG_DOUBLE;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<wchar_t>()
{
  return MPI_WCHAR;
}

template <>
PUGS_INLINE MPI_Datatype
Messenger::helper::mpiType<bool>()
{
  return MPI_CXX_BOOL;
}

#endif   // PUGS_HAS_MPI
} // namespace parallel
#endif // MESSENGER_HPP