Select Git revision
PugsTraits.hpp
ItemValueUtils.hpp 9.77 KiB
#ifndef ITEM_VALUE_UTILS_HPP
#define ITEM_VALUE_UTILS_HPP
#include <utils/Messenger.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/Synchronizer.hpp>
#include <mesh/SynchronizerManager.hpp>
#include <utils/PugsTraits.hpp>
#include <iostream>
// Global minimum of an ItemValue.
//
// Each process reduces (with parallel_reduce) only the items that its
// connectivity flags as owned through isOwned<item_type>(); the per-process
// results are then combined with parallel::allReduceMin.
//
// Restricted to arithmetic, non-boolean data (enforced by static_assert).
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
std::remove_const_t<DataType>
min(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value)
{
  using ItemValueType   = ItemValue<DataType, item_type, ConnectivityPtr>;
  using ItemIsOwnedType = ItemValue<const bool, item_type>;
  using data_type       = std::remove_const_t<typename ItemValueType::data_type>;
  using index_type      = typename ItemValueType::index_type;

  static_assert(std::is_arithmetic_v<data_type>, "min cannot be called on non-arithmetic data");
  static_assert(not std::is_same_v<data_type, bool>, "min cannot be called on boolean data");

  // Reduction functor handed to parallel_reduce (init / operator() / join).
  // Converting an instance to data_type runs the local reduction.
  class ItemValueMin
  {
   private:
    const ItemValueType& m_item_value;
    const ItemIsOwnedType m_is_owned;

   public:
    PUGS_INLINE
    ItemValueMin(const ItemValueType& item_value)
      : m_item_value(item_value), m_is_owned([&](const IConnectivity& connectivity) {
          Assert((connectivity.dimension() > 0) and (connectivity.dimension() <= 3),
                 "unexpected connectivity dimension");
          // Downcast to the concrete connectivity to fetch its ownership flags.
          const auto dimension = connectivity.dimension();
          if (dimension == 1) {
            return static_cast<const Connectivity1D&>(connectivity).isOwned<item_type>();
          } else if (dimension == 2) {
            return static_cast<const Connectivity2D&>(connectivity).isOwned<item_type>();
          } else if (dimension == 3) {
            return static_cast<const Connectivity3D&>(connectivity).isOwned<item_type>();
          }
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected dimension");
          // LCOV_EXCL_STOP
        }(*item_value.connectivity_ptr()))
    {}

    PUGS_INLINE
    void
    init(data_type& value) const
    {
      // Identity element of a min-reduction.
      value = std::numeric_limits<data_type>::max();
    }

    PUGS_INLINE
    void
    operator()(const index_type& item_index, data_type& current_min) const
    {
      // Only owned items contribute to the local result.
      if (not m_is_owned[item_index]) {
        return;
      }
      if (m_item_value[item_index] < current_min) {
        current_min = m_item_value[item_index];
      }
    }

    PUGS_INLINE
    void
    join(volatile data_type& destination, const volatile data_type& source) const
    {
      if (source < destination) {
        // cannot be reached if initial value is the min
        destination = source;   // LCOV_EXCL_LINE
      }
    }

    PUGS_INLINE
    operator data_type()
    {
      data_type local_result;
      parallel_reduce(m_item_value.numberOfItems(), *this, local_result);
      return local_result;
    }

    PUGS_INLINE
    ~ItemValueMin() = default;
  };

  const DataType process_min = ItemValueMin{item_value};
  return parallel::allReduceMin(process_min);
}
// Global maximum of an ItemValue.
//
// Each process reduces (with parallel_reduce) only the items that its
// connectivity flags as owned through isOwned<item_type>(); the per-process
// results are then combined with parallel::allReduceMax.
//
// Restricted to arithmetic, non-boolean data (enforced by static_assert).
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
std::remove_const_t<DataType>
max(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value)
{
  using ItemValueType   = ItemValue<DataType, item_type, ConnectivityPtr>;
  using ItemIsOwnedType = ItemValue<const bool, item_type>;
  using data_type       = std::remove_const_t<typename ItemValueType::data_type>;
  using index_type      = typename ItemValueType::index_type;

  static_assert(std::is_arithmetic_v<data_type>, "max cannot be called on non arithmetic data");
  static_assert(not std::is_same_v<data_type, bool>, "max cannot be called on boolean data");

  // Reduction functor handed to parallel_reduce (init / operator() / join).
  // Converting an instance to data_type runs the local reduction.
  class ItemValueMax
  {
   private:
    const ItemValueType& m_item_value;
    const ItemIsOwnedType m_is_owned;

   public:
    PUGS_INLINE
    operator data_type()
    {
      data_type reduced_value;
      parallel_reduce(m_item_value.numberOfItems(), *this, reduced_value);
      return reduced_value;
    }

    PUGS_INLINE
    void
    operator()(const index_type& i, data_type& data) const
    {
      // Only owned items contribute to the local result.
      if ((m_is_owned[i]) and (m_item_value[i] > data)) {
        data = m_item_value[i];
      }
    }

    PUGS_INLINE
    void
    join(volatile data_type& dst, const volatile data_type& src) const
    {
      if (src > dst) {
        // cannot be reached if initial value is the max
        dst = src;   // LCOV_EXCL_LINE
      }
    }

    PUGS_INLINE
    void
    init(data_type& value) const
    {
      // The identity of a max-reduction is the most negative representable
      // value. numeric_limits<T>::min() is NOT that for floating-point types:
      // it is the smallest positive normalized value, so a set of negative
      // doubles would wrongly reduce to ~2.2e-308. lowest() is correct for
      // both integer and floating-point types.
      value = std::numeric_limits<data_type>::lowest();
    }

    PUGS_INLINE
    ItemValueMax(const ItemValueType& item_value)
      : m_item_value(item_value), m_is_owned([&](const IConnectivity& connectivity) {
          Assert((connectivity.dimension() > 0) and (connectivity.dimension() <= 3),
                 "unexpected connectivity dimension");
          // Downcast to the concrete connectivity to fetch its ownership flags.
          switch (connectivity.dimension()) {
          case 1: {
            const auto& connectivity_1d = static_cast<const Connectivity1D&>(connectivity);
            return connectivity_1d.isOwned<item_type>();
          }
          case 2: {
            const auto& connectivity_2d = static_cast<const Connectivity2D&>(connectivity);
            return connectivity_2d.isOwned<item_type>();
          }
          case 3: {
            const auto& connectivity_3d = static_cast<const Connectivity3D&>(connectivity);
            return connectivity_3d.isOwned<item_type>();
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError("unexpected dimension");
          }
            // LCOV_EXCL_STOP
          }
        }(*item_value.connectivity_ptr()))
    {}

    PUGS_INLINE
    ~ItemValueMax() = default;
  };

  const DataType local_max = ItemValueMax{item_value};
  return parallel::allReduceMax(local_max);
}
// Global sum of an ItemValue.
//
// Each process accumulates (with parallel_reduce) the values of the items
// that its connectivity flags as owned through isOwned<item_type>(); the
// partial sums are then combined with parallel::allReduceSum.
//
// Boolean data is rejected at compile time. Non-arithmetic data must be a
// TinyVector or TinyMatrix (checked in init()).
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
std::remove_const_t<DataType>
sum(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value)
{
  using ItemValueType   = ItemValue<DataType, item_type, ConnectivityPtr>;
  using ItemIsOwnedType = ItemValue<const bool, item_type>;
  using data_type       = std::remove_const_t<typename ItemValueType::data_type>;
  using index_type      = typename ItemValueType::index_type;

  static_assert(not std::is_same_v<data_type, bool>, "sum cannot be called on boolean data");

  // Reduction functor handed to parallel_reduce (init / operator() / join).
  // Converting an instance to data_type runs the local reduction.
  class ItemValueSum
  {
   private:
    const ItemValueType& m_item_value;
    const ItemIsOwnedType m_is_owned;

   public:
    PUGS_INLINE
    ItemValueSum(const ItemValueType& item_value)
      : m_item_value(item_value), m_is_owned([&](const IConnectivity& connectivity) {
          Assert((connectivity.dimension() > 0) and (connectivity.dimension() <= 3),
                 "unexpected connectivity dimension");
          // Downcast to the concrete connectivity to fetch its ownership flags.
          const auto dimension = connectivity.dimension();
          if (dimension == 1) {
            return static_cast<const Connectivity1D&>(connectivity).isOwned<item_type>();
          } else if (dimension == 2) {
            return static_cast<const Connectivity2D&>(connectivity).isOwned<item_type>();
          } else if (dimension == 3) {
            return static_cast<const Connectivity3D&>(connectivity).isOwned<item_type>();
          }
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected dimension");
          // LCOV_EXCL_STOP
        }(*item_value.connectivity_ptr()))
    {}

    PUGS_INLINE
    void
    init(data_type& value) const
    {
      // Identity element of a sum: the zero of the data type.
      if constexpr (not std::is_arithmetic_v<data_type>) {
        static_assert(is_tiny_vector_v<data_type> or is_tiny_matrix_v<data_type>, "invalid data type");
        value = zero;
      } else {
        value = 0;
      }
    }

    PUGS_INLINE
    void
    operator()(const index_type& item_index, data_type& partial_sum) const
    {
      // Only owned items contribute to the local result.
      if (not m_is_owned[item_index]) {
        return;
      }
      partial_sum += m_item_value[item_index];
    }

    PUGS_INLINE
    void
    join(volatile data_type& destination, const volatile data_type& source) const
    {
      destination += source;
    }

    PUGS_INLINE
    operator data_type()
    {
      data_type local_result;
      parallel_reduce(m_item_value.numberOfItems(), *this, local_result);
      return local_result;
    }

    PUGS_INLINE
    ~ItemValueSum() = default;
  };

  const DataType process_sum = ItemValueSum{item_value};
  return parallel::allReduceSum(process_sum);
}
// Synchronizes item_value across processes by delegating to the Synchronizer
// that SynchronizerManager associates with the value's connectivity.
// Does nothing when running on a single process. item_value is taken by value
// and passed on as-is to the Synchronizer; the caller's data is expected to be
// updated through it (NOTE(review): assumes ItemValue copies share storage).
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
void
synchronize(ItemValue<DataType, item_type, ConnectivityPtr> item_value)
{
  static_assert(not std::is_const_v<DataType>, "cannot synchronize ItemValue of const data");

  if (parallel::size() <= 1) {
    return;
  }

  const IConnectivity* connectivity_raw_ptr = item_value.connectivity_ptr().get();
  Synchronizer& connectivity_synchronizer =
    SynchronizerManager::instance().getConnectivitySynchronizer(connectivity_raw_ptr);
  connectivity_synchronizer.synchronize(item_value);
}
// Checks whether item_value is already synchronized.
//
// A copy of item_value is synchronized and compared entry by entry with the
// original. The local verdict is combined over all processes with
// parallel::allReduceAnd. Always true when running on a single process.
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
bool
isSynchronized(ItemValue<DataType, item_type, ConnectivityPtr> item_value)
{
  if (parallel::size() <= 1) {
    return true;
  }

  // Work on a copy so that the checked value itself is left untouched.
  ItemValue<std::remove_const_t<DataType>, item_type> synchronized_copy = copy(item_value);
  synchronize(synchronized_copy);

  // Stop scanning at the first entry that synchronization modified.
  bool locally_synchronized = true;
  for (ItemIdT<item_type> item_id = 0; locally_synchronized and (item_id < synchronized_copy.numberOfItems());
       ++item_id) {
    if (synchronized_copy[item_id] != item_value[item_id]) {
      locally_synchronized = false;
    }
  }

  return parallel::allReduceAnd(locally_synchronized);
}
#endif // ITEM_VALUE_UTILS_HPP