Skip to content
Snippets Groups Projects
Commit 55a62ac0 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add allGather function for single value (ie. not array)

parent b5f563a1
No related branches found
No related tags found
1 merge request!11Feature/mpi
...@@ -58,10 +58,29 @@ void Messenger::barrier() const ...@@ -58,10 +58,29 @@ void Messenger::barrier() const
#endif // PASTIS_HAS_MPI #endif // PASTIS_HAS_MPI
} }
Array<int> Messenger::
_allGather(int& data) const
{
Array<int> gather(m_size);
#ifdef PASTIS_HAS_MPI
MPI_Allgather(&data, 1, MPI_INT, &(gather[0]), 1, MPI_INT,
MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
gather[0] = data;
#endif // PASTIS_HAS_MPI
return gather;
}
int Messenger:: int Messenger::
_broadcast(int& data, int root_rank) const _broadcast(int& data, int root_rank) const
{ {
#ifdef PASTIS_HAS_MPI
MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD); MPI_Bcast(&data, 1, MPI_INT, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
return data; return data;
} }
...@@ -69,11 +88,13 @@ _broadcast(int& data, int root_rank) const ...@@ -69,11 +88,13 @@ _broadcast(int& data, int root_rank) const
Array<int> Messenger:: Array<int> Messenger::
_broadcast(Array<int>& array, int root_rank) const _broadcast(Array<int>& array, int root_rank) const
{ {
#ifdef PASTIS_HAS_MPI
int size = array.size(); int size = array.size();
_broadcast(size, root_rank); _broadcast(size, root_rank);
if (commRank() != root_rank) { if (commRank() != root_rank) {
array = Array<int>(size); array = Array<int>(size);
} }
MPI_Bcast(&(array[0]), array.size(), MPI_INT, root_rank, MPI_COMM_WORLD); MPI_Bcast(&(array[0]), array.size(), MPI_INT, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
return array; return array;
} }
...@@ -15,6 +15,8 @@ class Messenger ...@@ -15,6 +15,8 @@ class Messenger
int m_rank{0}; int m_rank{0};
int m_size{1}; int m_size{1};
Array<int> _allGather(int& data) const;
int _broadcast(int& data, int root_rank) const; int _broadcast(int& data, int root_rank) const;
Array<int> _broadcast(Array<int>& array, int root_rank) const; Array<int> _broadcast(Array<int>& array, int root_rank) const;
...@@ -43,6 +45,18 @@ class Messenger ...@@ -43,6 +45,18 @@ class Messenger
void barrier() const; void barrier() const;
template <typename DataType>
PASTIS_INLINE
Array<DataType> allGather(const DataType& data) const
{
if constexpr(std::is_same<DataType, int>()) {
int int_data = data;
return _allGather(int_data);
} else {
static_assert(std::is_same<DataType, int>(), "unexpected type of data");
}
}
template <typename DataType> template <typename DataType>
PASTIS_INLINE PASTIS_INLINE
DataType broadcast(const DataType& data, int root_rank) const DataType broadcast(const DataType& data, int root_rank) const
...@@ -99,16 +113,23 @@ void barrier() ...@@ -99,16 +113,23 @@ void barrier()
template <typename DataType> template <typename DataType>
PASTIS_INLINE PASTIS_INLINE
Array<DataType> broadcast(const Array<DataType>& array, int root_rank) DataType broadcast(const DataType& data, int root_rank)
{ {
return messenger().broadcast(array, root_rank); return messenger().broadcast(data, root_rank);
} }
template <typename DataType> template <typename DataType>
PASTIS_INLINE PASTIS_INLINE
DataType broadcast(const DataType& data, int root_rank) Array<DataType> allGather(const DataType& data)
{ {
return messenger().broadcast(data, root_rank); return messenger().allGather(data);
}
template <typename DataType>
PASTIS_INLINE
Array<DataType> broadcast(const Array<DataType>& array, int root_rank)
{
return messenger().broadcast(array, root_rank);
} }
#endif // MESSENGER_HPP #endif // MESSENGER_HPP
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment