Skip to content
Snippets Groups Projects
Commit 068d55c6 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Define allGather of trivial type or of arrays of trivial types

parent 6cbe0d9b
No related branches found
No related tags found
1 merge request!11Feature/mpi
......@@ -51,18 +51,3 @@ void Messenger::barrier() const
MPI_Barrier(MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
}
Array<int> Messenger::
_allGather(int& data) const
{
Array<int> gather(m_size);
#ifdef PASTIS_HAS_MPI
MPI_Allgather(&data, 1, MPI_INT, &(gather[0]), 1, MPI_INT,
MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
gather[0] = data;
#endif // PASTIS_HAS_MPI
return gather;
}
......@@ -82,15 +82,60 @@ class Messenger
int m_rank{0};
int m_size{1};
Array<int> _allGather(int& data) const;
template <typename DataType>
void _allGather(const DataType& data,
Array<DataType> gather) const
{
static_assert(std::is_arithmetic_v<DataType>);
Assert(gather.size() == m_size);
#ifdef PASTIS_HAS_MPI
MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<DataType>();
MPI_Allgather(&data, 1, mpi_datatype,
&(gather[0]), 1, mpi_datatype,
MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
gather[0] = data;
#endif // PASTIS_HAS_MPI
}
template <template <typename ...SendT> typename SendArrayType,
template <typename ...RecvT> typename RecvArrayType,
typename ...SendT, typename ...RecvT>
void _allGather(const SendArrayType<SendT...>& data_array,
RecvArrayType<RecvT...> gather_array) const
{
Assert(gather_array.size() == data_array.size()*m_size);
using SendDataType = typename SendArrayType<SendT...>::data_type;
using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
static_assert(std::is_arithmetic_v<SendDataType>);
#ifdef PASTIS_HAS_MPI
MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<RecvDataType>();
MPI_Allgather(&(data_array[0]), data_array.size(), mpi_datatype,
&(gather_array[0]), data_array.size(), mpi_datatype,
MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
static_assert(false, "NIY");
#endif // PASTIS_HAS_MPI
}
template <typename DataType>
void _broadcast_value(DataType& data, int root_rank) const
{
#ifdef PASTIS_HAS_MPI
static_assert(not std::is_const_v<DataType>);
static_assert(std::is_arithmetic_v<DataType>);
#ifdef PASTIS_HAS_MPI
MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<DataType>();
......@@ -236,14 +281,50 @@ class Messenger
template <typename DataType>
PASTIS_INLINE
Array<DataType> allGather(const DataType& data) const
Array<DataType>
allGather(const DataType& data) const
{
static_assert(not std::is_const_v<DataType>);
Array<DataType> gather_array(m_size);
if constexpr(std::is_arithmetic_v<DataType>) {
_allGather(data, gather_array);
} else if constexpr(std::is_trivial_v<DataType>) {
using CastType = helper::split_cast_t<DataType>;
CastArray cast_value_array = cast_value_to<const CastType>::from(data);
CastArray cast_gather_array = cast_array_to<CastType>::from(gather_array);
_allGather(cast_value_array, cast_gather_array);
} else {
static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
}
return gather_array;
}
template <typename DataType>
PASTIS_INLINE
Array<std::remove_const_t<DataType>>
allGather(const Array<DataType>& array) const
{
if constexpr(std::is_same<DataType, int>()) {
int int_data = data;
return _allGather(int_data);
using MutableDataType = std::remove_const_t<DataType>;
Array<MutableDataType> gather_array(m_size*array.size());
if constexpr(std::is_arithmetic_v<DataType>) {
_allGather(array, gather_array);
} else if constexpr(std::is_trivial_v<DataType>) {
using CastType = helper::split_cast_t<DataType>;
using MutableCastType = helper::split_cast_t<MutableDataType>;
CastArray cast_array = cast_array_to<CastType>::from(array);
CastArray cast_gather_array = cast_array_to<MutableCastType>::from(gather_array);
_allGather(cast_array, cast_gather_array);
} else {
static_assert(std::is_same<DataType, int>(), "unexpected type of data");
static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
}
return gather_array;
}
template <typename SendDataType>
......@@ -264,7 +345,7 @@ class Messenger
auto recv_cast_array = cast_array_to<CastType>::from(recv_array);
_allToAll(send_cast_array, recv_cast_array);
} else {
static_assert(std::is_same<DataType, int>(), "unexpected type of data");
static_assert(std::is_trivial_v<DataType>, "unexpected type of data");
}
return recv_array;
}
......@@ -387,11 +468,20 @@ void broadcast(DataType& data, int root_rank)
template <typename DataType>
PASTIS_INLINE
Array<DataType> allGather(const DataType& data)
Array<DataType>
allGather(const DataType& data)
{
return messenger().allGather(data);
}
template <typename DataType>
PASTIS_INLINE
Array<std::remove_const_t<DataType>>
allGather(const Array<DataType>& array)
{
return messenger().allGather(array);
}
template <typename DataType>
PASTIS_INLINE
Array<std::remove_const_t<DataType>>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment