Select Git revision
BoundaryConditionDescriptor.hpp
-
Stéphane Del Pino authoredStéphane Del Pino authored
ReproducibleSumUtils.hpp 25.37 KiB
#ifndef REPRODUCIBLE_SUM_UTILS_HPP
#define REPRODUCIBLE_SUM_UTILS_HPP
#include <utils/PugsUtils.hpp>
#include <utils/Types.hpp>
// Forward declaration of the Array container. The summation classes below
// construct Array<...> objects, so its full definition must be visible where
// they are instantiated.
template <typename DataType>
class Array;
namespace reproducible_sum_utils
{
// Maps a bit width to the signed integer type of exactly that width. Only 8,
// 16, 32 and 64 are specialized; other widths provide no integer_t.
template <size_t NumberOfBits>
struct IntegerFromBitSize
{
};

template <>
struct IntegerFromBitSize<8>
{
  using integer_t = std::int8_t;
};

template <>
struct IntegerFromBitSize<16>
{
  using integer_t = std::int16_t;
};

template <>
struct IntegerFromBitSize<32>
{
  using integer_t = std::int32_t;
};

template <>
struct IntegerFromBitSize<64>
{
  using integer_t = std::int64_t;
};

// Signed integer type with the same size as DataType. The summation classes
// use it to access the bit pattern of a floating point value.
template <typename DataType>
struct IntegerType
{
  using integer_t = typename IntegerFromBitSize<sizeof(DataType) * 8>::integer_t;
};

// "Unit in the last place" of x as defined for the binned sums:
// 2^(ilogb(|x|) - digits), e.g. ulp(1.0) == 2^-53 for double. Returns
// denorm_min() when x == 0.
//
// std::ldexp scales by a power of two exactly (when the result is
// representable), whereas the accuracy of std::pow is implementation-defined.
// The bin alignment relies on this value being an exact power of two.
template <typename DataType>
DataType
ulp(const DataType& x) noexcept(NO_ASSERT)
{
  static_assert(std::is_floating_point_v<DataType>, "expecting floating point value");
  if (x == 0) {
    return std::numeric_limits<DataType>::denorm_min();
  }
  return std::ldexp(DataType{1}, std::ilogb(std::abs(x)) - std::numeric_limits<DataType>::digits);
}

// "Unit in the first place" of x: the largest power of two not greater than
// |x|, i.e. 2^ilogb(|x|). For x == 0 this evaluates to 0 (ilogb(0) is
// FP_ILOGB0). Uses std::ldexp for an exact power of two, like ulp().
template <typename DataType>
DataType
ufp(const DataType& x) noexcept(NO_ASSERT)
{
  static_assert(std::is_floating_point_v<DataType>, "expecting floating point value");
  return std::ldexp(DataType{1}, std::ilogb(std::abs(x)));
}

// Useful bits per bin (W): consecutive bin slots are W bits apart
template <typename DataType>
constexpr inline size_t bin_size = 0;
template <>
constexpr inline size_t bin_size<double> = 40;
template <>
constexpr inline size_t bin_size<float> = 12;

// IEEE 754 epsilon
template <typename DataType>
constexpr inline double bin_epsilon = 0;
template <>
constexpr inline double bin_epsilon<double> = std::numeric_limits<double>::epsilon();
template <>
constexpr inline double bin_epsilon<float> = std::numeric_limits<float>::epsilon();

// number of bins (K): improves precision
template <typename DataType>
constexpr inline size_t bin_number = 0;
template <>
constexpr inline size_t bin_number<double> = 3;
template <>
constexpr inline size_t bin_number<float> = 4;

// Max local sum size to avoid overflow: maximum number of values deposited in
// a bin before it is renormalized. Equal to 2^(digits - W - 2) (cf. the
// "endurance" of ReproBLAS binned types).
template <typename DataType>
constexpr inline size_t bin_max_size = 0;
template <>
constexpr inline size_t bin_max_size<double> = 2048;
template <>
constexpr inline size_t bin_max_size<float> = 1024;

// Keep bin_max_size consistent with bin_size if either is ever retuned.
static_assert(bin_max_size<double> ==
                (size_t{1} << (std::numeric_limits<double>::digits - bin_size<double> - 2)),
              "bin_max_size<double> must be 2^(digits - W - 2)");
static_assert(bin_max_size<float> == (size_t{1} << (std::numeric_limits<float>::digits - bin_size<float> - 2)),
              "bin_max_size<float> must be 2^(digits - W - 2)");

// Mask that selects every entry: used when no mask array is provided.
struct NoMask
{
  PUGS_INLINE bool
  operator[](size_t) const
  {
    return true;
  }
};

}   // namespace reproducible_sum_utils
template <typename ArrayT, typename MaskType = reproducible_sum_utils::NoMask>
class ReproducibleScalarSum
{
public:
using DataType = std::decay_t<typename ArrayT::data_type>;
static_assert(std::is_floating_point_v<DataType>);
static constexpr size_t K = reproducible_sum_utils::bin_number<DataType>;
static constexpr DataType eps = reproducible_sum_utils::bin_epsilon<DataType>;
static constexpr size_t W = reproducible_sum_utils::bin_size<DataType>;
struct Bin
{
std::array<DataType, K> S; // sum
std::array<DataType, K> C; // carry
Bin& operator=(const Bin&) = default;
Bin& operator=(Bin&&) = default;
Bin(Bin&&) = default;
Bin(const Bin&) = default;
Bin(ZeroType)
{
for (size_t k = 0; k < K; ++k) {
S[k] = 0.75 * eps * std::pow(DataType{2}, (K - k - 1) * W);
C[k] = 0;
}
}
Bin() = default;
~Bin() = default;
};
private:
Bin m_summation_bin;
PUGS_INLINE
static void
_shift(const size_t g, Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t k = K - 1; k >= g; --k) {
bin.S[k] = bin.S[k - g];
bin.C[k] = bin.C[k - g];
}
for (size_t k = 0; k < std::min(K, g); ++k) {
bin.S[k] = 1.5 * std::pow(DataType{2}, g * W) * ufp(bin.S[k]);
bin.C[k] = 0;
}
}
PUGS_INLINE static void
_update(const DataType& m, Bin& bin) noexcept(NO_ASSERT)
{
Assert(m >= 0);
using namespace reproducible_sum_utils;
if (m >= std::pow(DataType{2}, W - 1.) * ulp(bin.S[0])) {
const size_t shift = 1 + std::floor(std::log2(m / (std::pow(DataType{2}, W - 1.) * ulp(bin.S[0]))) / W);
_shift(shift, bin);
}
}
PUGS_INLINE
void static _split2(DataType& S, DataType& x) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
union
{
DataType as_DataType;
typename IntegerType<DataType>::integer_t as_integer;
} x_bar;
x_bar.as_DataType = x;
x_bar.as_integer |= 0x1;
const DataType S0 = S;
S += x_bar.as_DataType;
x -= S - S0;
}
PUGS_INLINE static void
_renormalize(Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t k = 0; k < K; ++k) {
if (bin.S[k] >= 1.75 * ufp(bin.S[k])) {
bin.S[k] -= 0.25 * ufp(bin.S[k]);
bin.C[k] += 1;
} else if (bin.S[k] < 1.25 * ufp(bin.S[k])) {
bin.S[k] += 0.5 * ufp(bin.S[k]);
bin.C[k] -= 2;
} else if (bin.S[k] < 1.5 * ufp(bin.S[k])) {
bin.S[k] += 0.25 * ufp(bin.S[k]);
bin.C[k] -= 1;
}
}
}
public:
static void
addBinTo(Bin& bin, Bin& bin_sum)
{
using namespace reproducible_sum_utils;
DataType ulp_bin = ulp(bin.S[0]);
DataType ulp_sum = ulp(bin_sum.S[0]);
if (ulp_bin < ulp_sum) {
const size_t shift = std::floor(std::log2(ulp_sum / ulp_bin) / W);
if (shift > 0) {
_shift(shift, bin);
}
} else if (ulp_bin > ulp_sum) {
const size_t shift = std::floor(std::log2(ulp_bin / ulp_sum) / W);
if (shift > 0) {
_shift(shift, bin_sum);
}
}
for (size_t k = 0; k < K; ++k) {
bin_sum.S[k] += bin.S[k] - 1.5 * ufp(bin.S[k]);
bin_sum.C[k] += bin.C[k];
}
_renormalize(bin_sum);
}
PUGS_INLINE
static DataType
getValue(const Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
DataType value = 0;
for (size_t k = 0; k < K; ++k) {
value += 0.25 * (bin.C[k] - 6) * ufp(bin.S[k]) + bin.S[k];
}
return value;
}
PUGS_INLINE Bin
getSummationBin() const noexcept(NO_ASSERT)
{
return m_summation_bin;
}
PUGS_INLINE DataType
getSum() const noexcept(NO_ASSERT)
{
return getValue(m_summation_bin);
}
ReproducibleScalarSum(const ArrayT& array, const MaskType mask = reproducible_sum_utils::NoMask{})
{
if constexpr (not std::is_same_v<MaskType, reproducible_sum_utils::NoMask>) {
static_assert(std::is_same_v<std::decay_t<typename MaskType::data_type>, bool>,
"when provided, mask must be an array of bool");
}
using TeamPolicyT = Kokkos::TeamPolicy<Kokkos::IndexType<int>>;
using TeamMemberT = TeamPolicyT::member_type;
int nx = reproducible_sum_utils::bin_max_size<DataType>;
int ny = std::max(array.size() / nx, 1ul);
const TeamPolicyT policy(ny, Kokkos::AUTO());
Array<DataType> thread_sum(policy.team_size() * policy.league_size());
Array<Bin> bin_by_thread(policy.team_size() * policy.league_size());
bin_by_thread.fill(zero);
Array<DataType> local_max(policy.team_size() * policy.league_size());
local_max.fill(0);
parallel_for(
policy, PUGS_LAMBDA(const TeamMemberT& member) {
const int i_team = member.league_rank();
const int i_thread = member.team_rank();
const int thread_id = i_team * member.team_size() + i_thread;
const int league_start = nx * i_team;
const int block_size = [&] {
int size = nx;
if (i_team == member.league_size() - 1) {
size = array.size() - league_start;
}
return size;
}();
parallel_for(
Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
if (mask[league_start + i]) {
DataType& m = local_max[thread_id];
DataType abs_value = std::abs(array[league_start + i]);
if (abs_value > m) {
m = abs_value;
}
}
});
_update(local_max[thread_id], bin_by_thread[thread_id]);
parallel_for(
Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
if (mask[league_start + i]) {
DataType x = array[nx * i_team + i];
for (size_t k = 0; k < K; ++k) {
_split2(bin_by_thread[thread_id].S[k], x);
};
}
});
_renormalize(bin_by_thread[thread_id]);
});
m_summation_bin = bin_by_thread[0];
for (size_t i = 1; i < bin_by_thread.size(); ++i) {
addBinTo(bin_by_thread[i], m_summation_bin);
}
}
~ReproducibleScalarSum() = default;
};
template <typename ArrayT, typename MaskType = reproducible_sum_utils::NoMask>
class ReproducibleTinyVectorSum
{
public:
using TinyVectorType = std::decay_t<typename ArrayT::data_type>;
static_assert(is_tiny_vector_v<TinyVectorType>);
using DataType = std::decay_t<typename TinyVectorType::data_type>;
static_assert(std::is_floating_point_v<DataType>);
static constexpr size_t K = reproducible_sum_utils::bin_number<DataType>;
static constexpr DataType eps = reproducible_sum_utils::bin_epsilon<DataType>;
static constexpr size_t W = reproducible_sum_utils::bin_size<DataType>;
struct Bin
{
std::array<TinyVectorType, K> S; // sum
std::array<TinyVectorType, K> C; // carry
Bin& operator=(const Bin&) = default;
Bin& operator=(Bin&&) = default;
Bin(Bin&&) = default;
Bin(const Bin&) = default;
Bin(ZeroType)
{
for (size_t k = 0; k < K; ++k) {
const DataType init_value = 0.75 * eps * std::pow(DataType{2}, (K - k - 1) * W);
for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
S[k][i_component] = init_value;
}
C[k] = zero;
}
}
Bin() = default;
~Bin() = default;
};
private:
Bin m_summation_bin;
PUGS_INLINE
static void
_shift(const size_t g, Bin& bin, const size_t& i_component) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t k = K - 1; k >= g; --k) {
bin.S[k][i_component] = bin.S[k - g][i_component];
bin.C[k][i_component] = bin.C[k - g][i_component];
}
for (size_t k = 0; k < std::min(K, g); ++k) {
bin.S[k][i_component] = 1.5 * std::pow(DataType{2}, g * W) * ufp(bin.S[k][i_component]);
bin.C[k][i_component] = 0;
}
}
PUGS_INLINE static void
_update(const TinyVectorType& m, Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
if (m[i_component] >= std::pow(DataType{2}, W - 1.) * ulp(bin.S[0][i_component])) {
const size_t shift =
1 + std::floor(std::log2(m[i_component] / (std::pow(DataType{2}, W - 1.) * ulp(bin.S[0][i_component]))) / W);
_shift(shift, bin, i_component);
}
}
}
PUGS_INLINE
void static _split2(TinyVectorType& S, TinyVectorType& x) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
union
{
DataType as_DataType;
typename IntegerType<DataType>::integer_t as_integer;
} x_bar;
for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
x_bar.as_DataType = x[i_component];
x_bar.as_integer |= 0x1;
DataType& S_i = S[i_component];
DataType& x_i = x[i_component];
const DataType S0 = S_i;
S_i += x_bar.as_DataType;
x_i -= S_i - S0;
}
}
PUGS_INLINE static void
_renormalize(Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t k = 0; k < K; ++k) {
TinyVectorType& S_k = bin.S[k];
TinyVectorType& C_k = bin.C[k];
for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
DataType& Sk_i = S_k[i_component];
DataType& Ck_i = C_k[i_component];
if (Sk_i >= 1.75 * ufp(Sk_i)) {
Sk_i -= 0.25 * ufp(Sk_i);
Ck_i += 1;
} else if (Sk_i < 1.25 * ufp(Sk_i)) {
Sk_i += 0.5 * ufp(Sk_i);
Ck_i -= 2;
} else if (Sk_i < 1.5 * ufp(Sk_i)) {
Sk_i += 0.25 * ufp(Sk_i);
Ck_i -= 1;
}
}
}
}
public:
static void
addBinTo(Bin& bin, Bin& bin_sum)
{
using namespace reproducible_sum_utils;
for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
DataType ulp_bin = ulp(bin.S[0][i_component]);
DataType ulp_sum = ulp(bin_sum.S[0][i_component]);
if (ulp_bin < ulp_sum) {
const size_t shift = std::floor(std::log2(ulp_sum / ulp_bin) / W);
if (shift > 0) {
_shift(shift, bin, i_component);
}
} else if (ulp_bin > ulp_sum) {
const size_t shift = std::floor(std::log2(ulp_bin / ulp_sum) / W);
if (shift > 0) {
_shift(shift, bin_sum, i_component);
}
}
for (size_t k = 0; k < K; ++k) {
bin_sum.S[k][i_component] += bin.S[k][i_component] - 1.5 * ufp(bin.S[k][i_component]);
bin_sum.C[k][i_component] += bin.C[k][i_component];
}
}
_renormalize(bin_sum);
}
PUGS_INLINE
static TinyVectorType
getValue(const Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
TinyVectorType value = zero;
for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
for (size_t k = 0; k < K; ++k) {
value[i_component] += 0.25 * (bin.C[k][i_component] - 6) * ufp(bin.S[k][i_component]) + bin.S[k][i_component];
}
}
return value;
}
PUGS_INLINE Bin
getSummationBin() const noexcept(NO_ASSERT)
{
return m_summation_bin;
}
PUGS_INLINE TinyVectorType
getSum() const noexcept(NO_ASSERT)
{
return getValue(m_summation_bin);
}
ReproducibleTinyVectorSum(const ArrayT& array, const MaskType mask = reproducible_sum_utils::NoMask{})
{
if constexpr (not std::is_same_v<MaskType, reproducible_sum_utils::NoMask>) {
static_assert(std::is_same_v<std::decay_t<typename MaskType::data_type>, bool>,
"when provided, mask must be an array of bool");
}
using TeamPolicyT = Kokkos::TeamPolicy<Kokkos::IndexType<int>>;
using TeamMemberT = TeamPolicyT::member_type;
int nx = reproducible_sum_utils::bin_max_size<DataType>;
int ny = std::max(array.size() / nx, 1ul);
const TeamPolicyT policy(ny, Kokkos::AUTO());
Array<TinyVectorType> thread_sum(policy.team_size() * policy.league_size());
Array<Bin> bin_by_thread(policy.team_size() * policy.league_size());
bin_by_thread.fill(zero);
Array<TinyVectorType> local_max(policy.team_size() * policy.league_size());
local_max.fill(zero);
parallel_for(
policy, PUGS_LAMBDA(const TeamMemberT& member) {
const int i_team = member.league_rank();
const int i_thread = member.team_rank();
const int thread_id = i_team * member.team_size() + i_thread;
const int league_start = nx * i_team;
const int block_size = [&] {
int size = nx;
if (i_team == member.league_size() - 1) {
size = array.size() - league_start;
}
return size;
}();
parallel_for(
Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
if (mask[league_start + i]) {
for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
DataType& m = local_max[thread_id][i_component];
DataType abs_value = std::abs(array[league_start + i][i_component]);
if (abs_value > m) {
m = abs_value;
}
}
}
});
_update(local_max[thread_id], bin_by_thread[thread_id]);
parallel_for(
Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
if (mask[league_start + i]) {
TinyVectorType x = array[nx * i_team + i];
for (size_t k = 0; k < K; ++k) {
_split2(bin_by_thread[thread_id].S[k], x);
}
}
});
_renormalize(bin_by_thread[thread_id]);
});
m_summation_bin = bin_by_thread[0];
for (size_t i = 1; i < bin_by_thread.size(); ++i) {
addBinTo(bin_by_thread[i], m_summation_bin);
}
}
~ReproducibleTinyVectorSum() = default;
};
template <typename ArrayT, typename MaskType = reproducible_sum_utils::NoMask>
class ReproducibleTinyMatrixSum
{
public:
using TinyMatrixType = std::decay_t<typename ArrayT::data_type>;
static_assert(is_tiny_matrix_v<TinyMatrixType>);
using DataType = std::decay_t<typename TinyMatrixType::data_type>;
static_assert(std::is_floating_point_v<DataType>);
static constexpr size_t K = reproducible_sum_utils::bin_number<DataType>;
static constexpr DataType eps = reproducible_sum_utils::bin_epsilon<DataType>;
static constexpr size_t W = reproducible_sum_utils::bin_size<DataType>;
struct Bin
{
std::array<TinyMatrixType, K> S; // sum
std::array<TinyMatrixType, K> C; // carry
Bin& operator=(const Bin&) = default;
Bin& operator=(Bin&&) = default;
Bin(Bin&&) = default;
Bin(const Bin&) = default;
Bin(ZeroType)
{
for (size_t k = 0; k < K; ++k) {
const DataType init_value = 0.75 * eps * std::pow(DataType{2}, (K - k - 1) * W);
for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
S[k](i_component, j_component) = init_value;
}
}
C[k] = zero;
}
}
Bin() = default;
~Bin() = default;
};
private:
Bin m_summation_bin;
PUGS_INLINE
static void
_shift(const size_t g, Bin& bin, const size_t i_component, const size_t j_component) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t k = K - 1; k >= g; --k) {
bin.S[k](i_component, j_component) = bin.S[k - g](i_component, j_component);
bin.C[k](i_component, j_component) = bin.C[k - g](i_component, j_component);
}
for (size_t k = 0; k < std::min(K, g); ++k) {
bin.S[k](i_component, j_component) = 1.5 * std::pow(DataType{2}, g * W) * ufp(bin.S[k](i_component, j_component));
bin.C[k](i_component, j_component) = 0;
}
}
PUGS_INLINE static void
_update(const TinyMatrixType& m, Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
if (m(i_component, j_component) >= std::pow(DataType{2}, W - 1.) * ulp(bin.S[0](i_component, j_component))) {
const size_t shift =
1 + std::floor(std::log2(m(i_component, j_component) /
(std::pow(DataType{2}, W - 1.) * ulp(bin.S[0](i_component, j_component)))) /
W);
_shift(shift, bin, i_component, j_component);
}
}
}
}
PUGS_INLINE
void static _split2(TinyMatrixType& S, TinyMatrixType& x) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
union
{
DataType as_DataType;
typename IntegerType<DataType>::integer_t as_integer;
} x_bar;
for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
DataType& S_ij = S(i_component, j_component);
DataType& x_ij = x(i_component, j_component);
x_bar.as_DataType = x_ij;
x_bar.as_integer |= 0x1;
const DataType S0 = S_ij;
S_ij += x_bar.as_DataType;
x_ij -= S_ij - S0;
}
}
}
PUGS_INLINE static void
_renormalize(Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
for (size_t k = 0; k < K; ++k) {
TinyMatrixType& S_k = bin.S[k];
TinyMatrixType& C_k = bin.C[k];
for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
DataType& Sk_ij = S_k(i_component, j_component);
DataType& Ck_ij = C_k(i_component, j_component);
if (Sk_ij >= 1.75 * ufp(Sk_ij)) {
Sk_ij -= 0.25 * ufp(Sk_ij);
Ck_ij += 1;
} else if (Sk_ij < 1.25 * ufp(Sk_ij)) {
Sk_ij += 0.5 * ufp(Sk_ij);
Ck_ij -= 2;
} else if (Sk_ij < 1.5 * ufp(Sk_ij)) {
Sk_ij += 0.25 * ufp(Sk_ij);
Ck_ij -= 1;
}
}
}
}
}
public:
static void
addBinTo(Bin& bin, Bin& bin_sum)
{
using namespace reproducible_sum_utils;
for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
DataType ulp_bin = ulp(bin.S[0](i_component, j_component));
DataType ulp_sum = ulp(bin_sum.S[0](i_component, j_component));
if (ulp_bin < ulp_sum) {
const size_t shift = std::floor(std::log2(ulp_sum / ulp_bin) / W);
if (shift > 0) {
_shift(shift, bin, i_component, j_component);
}
} else if (ulp_bin > ulp_sum) {
const size_t shift = std::floor(std::log2(ulp_bin / ulp_sum) / W);
if (shift > 0) {
_shift(shift, bin_sum, i_component, j_component);
}
}
for (size_t k = 0; k < K; ++k) {
bin_sum.S[k](i_component, j_component) +=
bin.S[k](i_component, j_component) - 1.5 * ufp(bin.S[k](i_component, j_component));
bin_sum.C[k](i_component, j_component) += bin.C[k](i_component, j_component);
}
}
}
_renormalize(bin_sum);
}
PUGS_INLINE
static TinyMatrixType
getValue(const Bin& bin) noexcept(NO_ASSERT)
{
using namespace reproducible_sum_utils;
TinyMatrixType value = zero;
for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
for (size_t k = 0; k < K; ++k) {
value(i_component, j_component) +=
0.25 * (bin.C[k](i_component, j_component) - 6) * ufp(bin.S[k](i_component, j_component)) +
bin.S[k](i_component, j_component);
}
}
}
return value;
}
PUGS_INLINE Bin
getSummationBin() const noexcept(NO_ASSERT)
{
return m_summation_bin;
}
PUGS_INLINE TinyMatrixType
getSum() const noexcept(NO_ASSERT)
{
return getValue(m_summation_bin);
}
ReproducibleTinyMatrixSum(const ArrayT& array, const MaskType mask = reproducible_sum_utils::NoMask{})
{
if constexpr (not std::is_same_v<MaskType, reproducible_sum_utils::NoMask>) {
static_assert(std::is_same_v<std::decay_t<typename MaskType::data_type>, bool>,
"when provided, mask must be an array of bool");
}
using TeamPolicyT = Kokkos::TeamPolicy<Kokkos::IndexType<int>>;
using TeamMemberT = TeamPolicyT::member_type;
int nx = reproducible_sum_utils::bin_max_size<DataType>;
int ny = std::max(array.size() / nx, 1ul);
const TeamPolicyT policy(ny, Kokkos::AUTO());
Array<TinyMatrixType> thread_sum(policy.team_size() * policy.league_size());
Array<Bin> bin_by_thread(policy.team_size() * policy.league_size());
bin_by_thread.fill(zero);
Array<TinyMatrixType> local_max(policy.team_size() * policy.league_size());
local_max.fill(zero);
parallel_for(
policy, PUGS_LAMBDA(const TeamMemberT& member) {
const int i_team = member.league_rank();
const int i_thread = member.team_rank();
const int thread_id = i_team * member.team_size() + i_thread;
const int league_start = nx * i_team;
const int block_size = [&] {
int size = nx;
if (i_team == member.league_size() - 1) {
size = array.size() - league_start;
}
return size;
}();
parallel_for(
Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
if (mask[league_start + i]) {
for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
DataType& m = local_max[thread_id](i_component, j_component);
DataType abs_value = std::abs(array[league_start + i](i_component, j_component));
if (abs_value > m) {
m = abs_value;
}
}
}
}
});
_update(local_max[thread_id], bin_by_thread[thread_id]);
parallel_for(
Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
if (mask[league_start + i]) {
TinyMatrixType x = array[nx * i_team + i];
for (size_t k = 0; k < K; ++k) {
_split2(bin_by_thread[thread_id].S[k], x);
}
}
});
_renormalize(bin_by_thread[thread_id]);
});
m_summation_bin = bin_by_thread[0];
for (size_t i = 1; i < bin_by_thread.size(); ++i) {
addBinTo(bin_by_thread[i], m_summation_bin);
}
}
~ReproducibleTinyMatrixSum() = default;
};
#endif // REPRODUCIBLE_SUM_UTILS_HPP