Skip to content
Snippets Groups Projects
Commit b27836c9 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Rewrite array broadcast in a generic way (was only defined for int)

parent 21777e5a
No related branches found
No related tags found
1 merge request!11Feature/mpi
...@@ -68,25 +68,10 @@ _allGather(int& data) const ...@@ -68,25 +68,10 @@ _allGather(int& data) const
} }
int Messenger:: int Messenger::
_broadcast(int& data, int root_rank) const _broadcast_value(int& data, int root_rank) const
{ {
#ifdef PASTIS_HAS_MPI #ifdef PASTIS_HAS_MPI
MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD); MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI #endif // PASTIS_HAS_MPI
return data; return data;
} }
Array<int> Messenger::
_broadcast(Array<int>& array, int root_rank) const
{
#ifdef PASTIS_HAS_MPI
int size = array.size();
_broadcast(size, root_rank);
if (commRank() != root_rank) {
array = Array<int>(size);
}
MPI_Bcast(&(array[0]), array.size(), MPI_INT, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
return array;
}
...@@ -84,8 +84,21 @@ class Messenger ...@@ -84,8 +84,21 @@ class Messenger
Array<int> _allGather(int& data) const; Array<int> _allGather(int& data) const;
int _broadcast(int& data, int root_rank) const; int _broadcast_value(int& data, int root_rank) const;
Array<int> _broadcast(Array<int>& array, int root_rank) const;
template <typename ArrayType>
void _broadcast_array(ArrayType& array, int root_rank) const
{
using DataType = typename ArrayType::data_type;
static_assert(not std::is_const_v<DataType>);
#ifdef PASTIS_HAS_MPI
MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<DataType>();
MPI_Bcast(&(array[0]), array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
}
template <typename SendArrayType, template <typename SendArrayType,
typename RecvArrayType> typename RecvArrayType>
...@@ -103,7 +116,8 @@ class Messenger ...@@ -103,7 +116,8 @@ class Messenger
const int count = sent_array.size()/m_size; const int count = sent_array.size()/m_size;
MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<SendDataType>();
MPI_Alltoall(&(sent_array[0]), count, mpi_datatype, MPI_Alltoall(&(sent_array[0]), count, mpi_datatype,
&(recv_array[0]), count, mpi_datatype, &(recv_array[0]), count, mpi_datatype,
...@@ -128,7 +142,8 @@ class Messenger ...@@ -128,7 +142,8 @@ class Messenger
#ifdef PASTIS_HAS_MPI #ifdef PASTIS_HAS_MPI
std::vector<MPI_Request> request_list; std::vector<MPI_Request> request_list;
MPI_Datatype mpi_datatype = Messenger::helper::mpiType<SendDataType>(); MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<SendDataType>();
for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) { for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
const SendArrayType sent_array = sent_array_list[i_send]; const SendArrayType sent_array = sent_array_list[i_send];
...@@ -245,7 +260,7 @@ class Messenger ...@@ -245,7 +260,7 @@ class Messenger
{ {
if constexpr(std::is_same<DataType, int>()) { if constexpr(std::is_same<DataType, int>()) {
int int_data = data; int int_data = data;
return _broadcast(int_data, root_rank); return _broadcast_value(int_data, root_rank);
} else { } else {
static_assert(std::is_same<DataType, int>(), "unexpected type of data"); static_assert(std::is_same<DataType, int>(), "unexpected type of data");
} }
...@@ -255,12 +270,30 @@ class Messenger ...@@ -255,12 +270,30 @@ class Messenger
PASTIS_INLINE PASTIS_INLINE
Array<DataType> broadcast(const Array<DataType>& array, int root_rank) const Array<DataType> broadcast(const Array<DataType>& array, int root_rank) const
{ {
if constexpr(std::is_same<DataType, int>()) { static_assert(not std::is_const_v<DataType>,
Array<int> int_array = array; "cannot broadcast array of const");
return _broadcast(int_array, root_rank); if constexpr(std::is_arithmetic_v<DataType>) {
int size = array.size();
_broadcast_value(size, root_rank);
if (m_rank != root_rank) {
array = Array<DataType>(size);
}
_broadcast_array(array, root_rank);
} else if constexpr(std::is_trivial_v<DataType>) {
int size = array.size();
_broadcast_value(size, root_rank);
if (m_rank != root_rank) {
array = Array<DataType>(size);
}
using CastType = helper::split_cast_t<DataType>;
auto cast_array = cast_array_to<CastType>::from(array);
_broadcast_array(cast_array, root_rank);
} else{ } else{
static_assert(std::is_same<DataType, int>(), "unexpected type of data"); static_assert(std::is_trivial_v<DataType>,
"unexpected non trivial type of data");
} }
return array;
} }
template <typename SendDataType, template <typename SendDataType,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment