Skip to content
Snippets Groups Projects
Commit cd07ecf6 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add generic reduce{min,max} capability

parent 344c5591
Branches
Tags
1 merge request!11Feature/mpi
......@@ -87,7 +87,7 @@ class Messenger
Array<DataType> gather) const
{
static_assert(std::is_arithmetic_v<DataType>);
Assert(gather.size() == m_size);
Assert(gather.size() == m_size); // LCOV_EXCL_LINE
#ifdef PASTIS_HAS_MPI
MPI_Datatype mpi_datatype
......@@ -109,7 +109,7 @@ class Messenger
void _allGather(const SendArrayType<SendT...>& data_array,
RecvArrayType<RecvT...> gather_array) const
{
Assert(gather_array.size() == data_array.size()*m_size);
Assert(gather_array.size() == data_array.size()*m_size); // LCOV_EXCL_LINE
using SendDataType = typename SendArrayType<SendT...>::data_type;
using RecvDataType = typename RecvArrayType<RecvT...>::data_type;
......@@ -170,8 +170,8 @@ class Messenger
static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
static_assert(std::is_arithmetic_v<SendDataType>);
Assert((sent_array.size() % m_size) == 0);
Assert(sent_array.size() == recv_array.size());
Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE
Assert(sent_array.size() == recv_array.size()); // LCOV_EXCL_LINE
const int count = sent_array.size()/m_size;
......@@ -261,7 +261,7 @@ class Messenger
PASTIS_INLINE
static Messenger& getInstance()
{
Assert(m_instance != nullptr);
Assert(m_instance != nullptr); // LCOV_EXCL_LINE
return *m_instance;
}
......@@ -279,6 +279,44 @@ class Messenger
void barrier() const;
template <typename DataType>
DataType allReduceMin(const DataType& data) const
{
#ifdef PASTIS_HAS_MPI
static_assert(not std::is_const_v<DataType>);
static_assert(std::is_arithmetic_v<DataType>);
MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<DataType>();
DataType min_data = data;
MPI_Allreduce(&data, &min_data, 1, mpi_datatype, MPI_MIN, MPI_COMM_WORLD);
return min_data;
#else // PASTIS_HAS_MPI
return data;
#endif // PASTIS_HAS_MPI
}
template <typename DataType>
DataType allReduceMax(const DataType& data) const
{
#ifdef PASTIS_HAS_MPI
static_assert(not std::is_const_v<DataType>);
static_assert(std::is_arithmetic_v<DataType>);
MPI_Datatype mpi_datatype
= Messenger::helper::mpiType<DataType>();
DataType max_data = data;
MPI_Allreduce(&data, &max_data, 1, mpi_datatype, MPI_MAX, MPI_COMM_WORLD);
return max_data;
#else // PASTIS_HAS_MPI
return data;
#endif // PASTIS_HAS_MPI
}
template <typename DataType>
PASTIS_INLINE
Array<DataType>
......@@ -332,10 +370,16 @@ class Messenger
Array<std::remove_const_t<SendDataType>>
allToAll(const Array<SendDataType>& sent_array) const
{
Assert((sent_array.size() % m_size) == 0);
using DataType = std::remove_const_t<SendDataType>;
#ifndef NDEBUG
const size_t min_size = allReduceMin(sent_array.size());
const size_t max_size = allReduceMax(sent_array.size());
Assert(max_size == min_size); // LCOV_EXCL_LINE
#endif // NDEBUG
Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE
using DataType = std::remove_const_t<SendDataType>;
Array<DataType> recv_array(sent_array.size());
if constexpr(std::is_arithmetic_v<DataType>) {
_allToAll(sent_array, recv_array);
} else if constexpr(std::is_trivial_v<DataType>) {
......@@ -383,14 +427,14 @@ class Messenger
int size = array.size();
_broadcast_value(size, root_rank);
if (m_rank != root_rank) {
array = Array<DataType>(size);
array = Array<DataType>(size); // LCOV_EXCL_LINE
}
_broadcast_array(array, root_rank);
} else if constexpr(std::is_trivial_v<DataType>) {
int size = array.size();
_broadcast_value(size, root_rank);
if (m_rank != root_rank) {
array = Array<DataType>(size);
array = Array<DataType>(size); // LCOV_EXCL_LINE
}
using CastType = helper::split_cast_t<DataType>;
......@@ -455,6 +499,20 @@ void barrier()
return messenger().barrier();
}
template <typename DataType>
PASTIS_INLINE
DataType allReduceMax(const DataType& data)
{
return messenger().allReduceMax(data);
}
template <typename DataType>
PASTIS_INLINE
DataType allReduceMin(const DataType& data)
{
return messenger().allReduceMin(data);
}
template <typename DataType>
PASTIS_INLINE
Array<DataType>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment