#ifndef REPRODUCIBLE_SUM_UTILS_HPP
#define REPRODUCIBLE_SUM_UTILS_HPP

#include <utils/PugsUtils.hpp>
#include <utils/Types.hpp>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>

template <typename DataType>
class Array;

namespace reproducible_sum_utils
{
// Signed fixed-width integer type having NumberOfBits bits. Only 8, 16, 32 and
// 64 bits are specialized: any other size has no integer_t member.
template <size_t NumberOfBits>
struct IntegerFromBitSize
{
};

template <>
struct IntegerFromBitSize<8>
{
  using integer_t = std::int8_t;
};

template <>
struct IntegerFromBitSize<16>
{
  using integer_t = std::int16_t;
};

template <>
struct IntegerFromBitSize<32>
{
  using integer_t = std::int32_t;
};

template <>
struct IntegerFromBitSize<64>
{
  using integer_t = std::int64_t;
};

// Signed integer type of the same size as DataType (used to reach the bit
// pattern of floating point values)
template <typename DataType>
struct IntegerType
{
  using integer_t = typename IntegerFromBitSize<sizeof(DataType) * 8>::integer_t;
};

// Unit in the last place, as used by the binned summation:
//   ulp(x) = 2^(ilogb(|x|) - std::numeric_limits<DataType>::digits)
// digits counts the implicit bit (53 for double), so this is half of the usual
// IEEE ulp: for double, ulp(1.) == 2^-53 == std::numeric_limits<double>::epsilon() / 2.
// ulp(0) is the smallest denormal; non-finite values return |x|.
template <typename DataType>
DataType
ulp(const DataType& x) noexcept(NO_ASSERT)
{
  static_assert(std::is_floating_point_v<DataType>, "expecting floating point value");

  if (x == 0) {
    return std::numeric_limits<DataType>::denorm_min();
  }
  if (not std::isfinite(x)) {
    // ilogb(NaN) may be INT_MIN (FP_ILOGBNAN): subtracting digits from it
    // would overflow an int
    return std::abs(x);
  }

  // ldexp builds the power of two exactly, without a generic pow() call
  return std::ldexp(DataType{1}, std::ilogb(std::abs(x)) - std::numeric_limits<DataType>::digits);
}

// Unit in the first place: the largest power of 2 not greater than |x|,
//   ufp(x) = 2^ilogb(|x|)
// ufp(0) underflows to 0 (ilogb(0) == FP_ILOGB0); non-finite values return |x|.
template <typename DataType>
DataType
ufp(const DataType& x) noexcept(NO_ASSERT)
{
  static_assert(std::is_floating_point_v<DataType>, "expecting floating point value");

  if (not std::isfinite(x)) {
    return std::abs(x);
  }

  return std::ldexp(DataType{1}, std::ilogb(std::abs(x)));
}

// Number W of useful bits per bin
template <typename DataType>
constexpr inline size_t bin_size = 0;
template <>
constexpr inline size_t bin_size<double> = 40;
template <>
constexpr inline size_t bin_size<float> = 12;

// IEEE 754 machine epsilon of DataType (stored as a double)
template <typename DataType>
constexpr inline double bin_epsilon = 0;
template <>
constexpr inline double bin_epsilon<double> = std::numeric_limits<double>::epsilon();
template <>
constexpr inline double bin_epsilon<float> = std::numeric_limits<float>::epsilon();

// Number K of bins: more bins improve precision
template <typename DataType>
constexpr inline size_t bin_number = 0;

template <>
constexpr inline size_t bin_number<double> = 3;
template <>
constexpr inline size_t bin_number<float> = 4;

// Maximum number of values deposited into a bin between two renormalizations
// (used as the per-team block size by the summation classes); bounding it
// avoids overflowing the accumulators
template <typename DataType>
constexpr inline size_t bin_max_size = 0;

template <>
constexpr inline size_t bin_max_size<double> = 2048;
template <>
constexpr inline size_t bin_max_size<float> = 1024;

// Default mask: every entry is selected
struct NoMask
{
  PUGS_INLINE bool
  operator[](size_t) const
  {
    return true;
  }
};

}   // namespace reproducible_sum_utils

template <typename ArrayT, typename MaskType = reproducible_sum_utils::NoMask>
class ReproducibleScalarSum
{
 public:
  using DataType = std::decay_t<typename ArrayT::data_type>;
  static_assert(std::is_floating_point_v<DataType>);

  static constexpr size_t K     = reproducible_sum_utils::bin_number<DataType>;
  static constexpr DataType eps = reproducible_sum_utils::bin_epsilon<DataType>;
  static constexpr size_t W     = reproducible_sum_utils::bin_size<DataType>;

  struct Bin
  {
    std::array<DataType, K> S;   // sum
    std::array<DataType, K> C;   // carry

    Bin& operator=(const Bin&) = default;
    Bin& operator=(Bin&&)      = default;

    Bin(Bin&&)      = default;
    Bin(const Bin&) = default;

    Bin(ZeroType)
    {
      for (size_t k = 0; k < K; ++k) {
        S[k] = 0.75 * eps * std::pow(DataType{2}, (K - k - 1) * W);
        C[k] = 0;
      }
    }

    Bin() = default;

    ~Bin() = default;
  };

 private:
  Bin m_summation_bin;

  PUGS_INLINE
  static void
  _shift(const size_t g, Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    for (size_t k = K - 1; k >= g; --k) {
      bin.S[k] = bin.S[k - g];
      bin.C[k] = bin.C[k - g];
    }
    for (size_t k = 0; k < std::min(K, g); ++k) {
      bin.S[k] = 1.5 * std::pow(DataType{2}, g * W) * ufp(bin.S[k]);
      bin.C[k] = 0;
    }
  }

  PUGS_INLINE static void
  _update(const DataType& m, Bin& bin) noexcept(NO_ASSERT)
  {
    Assert(m >= 0);

    using namespace reproducible_sum_utils;
    if (m >= std::pow(DataType{2}, W - 1.) * ulp(bin.S[0])) {
      const size_t shift = 1 + std::floor(std::log2(m / (std::pow(DataType{2}, W - 1.) * ulp(bin.S[0]))) / W);

      _shift(shift, bin);
    }
  }

  PUGS_INLINE
  void static _split2(DataType& S, DataType& x) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;
    union
    {
      DataType as_DataType;
      typename IntegerType<DataType>::integer_t as_integer;
    } x_bar;

    x_bar.as_DataType = x;
    x_bar.as_integer |= 0x1;

    const DataType S0 = S;
    S += x_bar.as_DataType;
    x -= S - S0;
  }

  PUGS_INLINE static void
  _renormalize(Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    for (size_t k = 0; k < K; ++k) {
      if (bin.S[k] >= 1.75 * ufp(bin.S[k])) {
        bin.S[k] -= 0.25 * ufp(bin.S[k]);
        bin.C[k] += 1;
      } else if (bin.S[k] < 1.25 * ufp(bin.S[k])) {
        bin.S[k] += 0.5 * ufp(bin.S[k]);
        bin.C[k] -= 2;
      } else if (bin.S[k] < 1.5 * ufp(bin.S[k])) {
        bin.S[k] += 0.25 * ufp(bin.S[k]);
        bin.C[k] -= 1;
      }
    }
  }

 public:
  static void
  addBinTo(Bin& bin, Bin& bin_sum)
  {
    using namespace reproducible_sum_utils;

    DataType ulp_bin = ulp(bin.S[0]);
    DataType ulp_sum = ulp(bin_sum.S[0]);

    if (ulp_bin < ulp_sum) {
      const size_t shift = std::floor(std::log2(ulp_sum / ulp_bin) / W);
      if (shift > 0) {
        _shift(shift, bin);
      }
    } else if (ulp_bin > ulp_sum) {
      const size_t shift = std::floor(std::log2(ulp_bin / ulp_sum) / W);
      if (shift > 0) {
        _shift(shift, bin_sum);
      }
    }

    for (size_t k = 0; k < K; ++k) {
      bin_sum.S[k] += bin.S[k] - 1.5 * ufp(bin.S[k]);
      bin_sum.C[k] += bin.C[k];
    }

    _renormalize(bin_sum);
  }

  PUGS_INLINE
  static DataType
  getValue(const Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    DataType value = 0;
    for (size_t k = 0; k < K; ++k) {
      value += 0.25 * (bin.C[k] - 6) * ufp(bin.S[k]) + bin.S[k];
    }
    return value;
  }

  PUGS_INLINE Bin
  getSummationBin() const noexcept(NO_ASSERT)
  {
    return m_summation_bin;
  }

  PUGS_INLINE DataType
  getSum() const noexcept(NO_ASSERT)
  {
    return getValue(m_summation_bin);
  }

  ReproducibleScalarSum(const ArrayT& array, const MaskType mask = reproducible_sum_utils::NoMask{})
  {
    if constexpr (not std::is_same_v<MaskType, reproducible_sum_utils::NoMask>) {
      static_assert(std::is_same_v<std::decay_t<typename MaskType::data_type>, bool>,
                    "when provided, mask must be an array of bool");
    }

    using TeamPolicyT = Kokkos::TeamPolicy<Kokkos::IndexType<int>>;
    using TeamMemberT = TeamPolicyT::member_type;

    int nx = reproducible_sum_utils::bin_max_size<DataType>;
    int ny = std::max(array.size() / nx, 1ul);

    const TeamPolicyT policy(ny, Kokkos::AUTO());

    Array<DataType> thread_sum(policy.team_size() * policy.league_size());

    Array<Bin> bin_by_thread(policy.team_size() * policy.league_size());
    bin_by_thread.fill(zero);

    Array<DataType> local_max(policy.team_size() * policy.league_size());
    local_max.fill(0);

    parallel_for(
      policy, PUGS_LAMBDA(const TeamMemberT& member) {
        const int i_team   = member.league_rank();
        const int i_thread = member.team_rank();

        const int thread_id = i_team * member.team_size() + i_thread;

        const int league_start = nx * i_team;
        const int block_size   = [&] {
          int size = nx;
          if (i_team == member.league_size() - 1) {
            size = array.size() - league_start;
          }
          return size;
        }();

        parallel_for(
          Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
            if (mask[league_start + i]) {
              DataType& m        = local_max[thread_id];
              DataType abs_value = std::abs(array[league_start + i]);
              if (abs_value > m) {
                m = abs_value;
              }
            }
          });

        _update(local_max[thread_id], bin_by_thread[thread_id]);

        parallel_for(
          Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
            if (mask[league_start + i]) {
              DataType x = array[nx * i_team + i];
              for (size_t k = 0; k < K; ++k) {
                _split2(bin_by_thread[thread_id].S[k], x);
              };
            }
          });

        _renormalize(bin_by_thread[thread_id]);
      });

    m_summation_bin = bin_by_thread[0];
    for (size_t i = 1; i < bin_by_thread.size(); ++i) {
      addBinTo(bin_by_thread[i], m_summation_bin);
    }
  }

  ~ReproducibleScalarSum() = default;
};

template <typename ArrayT, typename MaskType = reproducible_sum_utils::NoMask>
class ReproducibleTinyVectorSum
{
 public:
  using TinyVectorType = std::decay_t<typename ArrayT::data_type>;
  static_assert(is_tiny_vector_v<TinyVectorType>);

  using DataType = std::decay_t<typename TinyVectorType::data_type>;
  static_assert(std::is_floating_point_v<DataType>);

  static constexpr size_t K     = reproducible_sum_utils::bin_number<DataType>;
  static constexpr DataType eps = reproducible_sum_utils::bin_epsilon<DataType>;
  static constexpr size_t W     = reproducible_sum_utils::bin_size<DataType>;

  struct Bin
  {
    std::array<TinyVectorType, K> S;   // sum
    std::array<TinyVectorType, K> C;   // carry

    Bin& operator=(const Bin&) = default;
    Bin& operator=(Bin&&)      = default;

    Bin(Bin&&)      = default;
    Bin(const Bin&) = default;

    Bin(ZeroType)
    {
      for (size_t k = 0; k < K; ++k) {
        const DataType init_value = 0.75 * eps * std::pow(DataType{2}, (K - k - 1) * W);
        for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
          S[k][i_component] = init_value;
        }
        C[k] = zero;
      }
    }

    Bin() = default;

    ~Bin() = default;
  };

 private:
  Bin m_summation_bin;

  PUGS_INLINE
  static void
  _shift(const size_t g, Bin& bin, const size_t& i_component) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    for (size_t k = K - 1; k >= g; --k) {
      bin.S[k][i_component] = bin.S[k - g][i_component];
      bin.C[k][i_component] = bin.C[k - g][i_component];
    }
    for (size_t k = 0; k < std::min(K, g); ++k) {
      bin.S[k][i_component] = 1.5 * std::pow(DataType{2}, g * W) * ufp(bin.S[k][i_component]);
      bin.C[k][i_component] = 0;
    }
  }

  PUGS_INLINE static void
  _update(const TinyVectorType& m, Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;
    for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
      if (m[i_component] >= std::pow(DataType{2}, W - 1.) * ulp(bin.S[0][i_component])) {
        const size_t shift =
          1 + std::floor(std::log2(m[i_component] / (std::pow(DataType{2}, W - 1.) * ulp(bin.S[0][i_component]))) / W);

        _shift(shift, bin, i_component);
      }
    }
  }

  PUGS_INLINE
  void static _split2(TinyVectorType& S, TinyVectorType& x) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;
    union
    {
      DataType as_DataType;
      typename IntegerType<DataType>::integer_t as_integer;
    } x_bar;

    for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
      x_bar.as_DataType = x[i_component];
      x_bar.as_integer |= 0x1;

      DataType& S_i = S[i_component];
      DataType& x_i = x[i_component];

      const DataType S0 = S_i;
      S_i += x_bar.as_DataType;
      x_i -= S_i - S0;
    }
  }

  PUGS_INLINE static void
  _renormalize(Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    for (size_t k = 0; k < K; ++k) {
      TinyVectorType& S_k = bin.S[k];
      TinyVectorType& C_k = bin.C[k];
      for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
        DataType& Sk_i = S_k[i_component];
        DataType& Ck_i = C_k[i_component];

        if (Sk_i >= 1.75 * ufp(Sk_i)) {
          Sk_i -= 0.25 * ufp(Sk_i);
          Ck_i += 1;
        } else if (Sk_i < 1.25 * ufp(Sk_i)) {
          Sk_i += 0.5 * ufp(Sk_i);
          Ck_i -= 2;
        } else if (Sk_i < 1.5 * ufp(Sk_i)) {
          Sk_i += 0.25 * ufp(Sk_i);
          Ck_i -= 1;
        }
      }
    }
  }

 public:
  static void
  addBinTo(Bin& bin, Bin& bin_sum)
  {
    using namespace reproducible_sum_utils;

    for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
      DataType ulp_bin = ulp(bin.S[0][i_component]);
      DataType ulp_sum = ulp(bin_sum.S[0][i_component]);

      if (ulp_bin < ulp_sum) {
        const size_t shift = std::floor(std::log2(ulp_sum / ulp_bin) / W);
        if (shift > 0) {
          _shift(shift, bin, i_component);
        }
      } else if (ulp_bin > ulp_sum) {
        const size_t shift = std::floor(std::log2(ulp_bin / ulp_sum) / W);
        if (shift > 0) {
          _shift(shift, bin_sum, i_component);
        }
      }

      for (size_t k = 0; k < K; ++k) {
        bin_sum.S[k][i_component] += bin.S[k][i_component] - 1.5 * ufp(bin.S[k][i_component]);
        bin_sum.C[k][i_component] += bin.C[k][i_component];
      }
    }

    _renormalize(bin_sum);
  }

  PUGS_INLINE
  static TinyVectorType
  getValue(const Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    TinyVectorType value = zero;
    for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
      for (size_t k = 0; k < K; ++k) {
        value[i_component] += 0.25 * (bin.C[k][i_component] - 6) * ufp(bin.S[k][i_component]) + bin.S[k][i_component];
      }
    }
    return value;
  }

  PUGS_INLINE Bin
  getSummationBin() const noexcept(NO_ASSERT)
  {
    return m_summation_bin;
  }

  PUGS_INLINE TinyVectorType
  getSum() const noexcept(NO_ASSERT)
  {
    return getValue(m_summation_bin);
  }

  ReproducibleTinyVectorSum(const ArrayT& array, const MaskType mask = reproducible_sum_utils::NoMask{})
  {
    if constexpr (not std::is_same_v<MaskType, reproducible_sum_utils::NoMask>) {
      static_assert(std::is_same_v<std::decay_t<typename MaskType::data_type>, bool>,
                    "when provided, mask must be an array of bool");
    }

    using TeamPolicyT = Kokkos::TeamPolicy<Kokkos::IndexType<int>>;
    using TeamMemberT = TeamPolicyT::member_type;

    int nx = reproducible_sum_utils::bin_max_size<DataType>;
    int ny = std::max(array.size() / nx, 1ul);

    const TeamPolicyT policy(ny, Kokkos::AUTO());

    Array<TinyVectorType> thread_sum(policy.team_size() * policy.league_size());

    Array<Bin> bin_by_thread(policy.team_size() * policy.league_size());
    bin_by_thread.fill(zero);

    Array<TinyVectorType> local_max(policy.team_size() * policy.league_size());
    local_max.fill(zero);

    parallel_for(
      policy, PUGS_LAMBDA(const TeamMemberT& member) {
        const int i_team   = member.league_rank();
        const int i_thread = member.team_rank();

        const int thread_id = i_team * member.team_size() + i_thread;

        const int league_start = nx * i_team;
        const int block_size   = [&] {
          int size = nx;
          if (i_team == member.league_size() - 1) {
            size = array.size() - league_start;
          }
          return size;
        }();

        parallel_for(
          Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
            if (mask[league_start + i]) {
              for (size_t i_component = 0; i_component < TinyVectorType::Dimension; ++i_component) {
                DataType& m        = local_max[thread_id][i_component];
                DataType abs_value = std::abs(array[league_start + i][i_component]);
                if (abs_value > m) {
                  m = abs_value;
                }
              }
            }
          });

        _update(local_max[thread_id], bin_by_thread[thread_id]);

        parallel_for(
          Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
            if (mask[league_start + i]) {
              TinyVectorType x = array[nx * i_team + i];
              for (size_t k = 0; k < K; ++k) {
                _split2(bin_by_thread[thread_id].S[k], x);
              }
            }
          });

        _renormalize(bin_by_thread[thread_id]);
      });

    m_summation_bin = bin_by_thread[0];
    for (size_t i = 1; i < bin_by_thread.size(); ++i) {
      addBinTo(bin_by_thread[i], m_summation_bin);
    }
  }

  ~ReproducibleTinyVectorSum() = default;
};

template <typename ArrayT, typename MaskType = reproducible_sum_utils::NoMask>
class ReproducibleTinyMatrixSum
{
 public:
  using TinyMatrixType = std::decay_t<typename ArrayT::data_type>;
  static_assert(is_tiny_matrix_v<TinyMatrixType>);

  using DataType = std::decay_t<typename TinyMatrixType::data_type>;
  static_assert(std::is_floating_point_v<DataType>);

  static constexpr size_t K     = reproducible_sum_utils::bin_number<DataType>;
  static constexpr DataType eps = reproducible_sum_utils::bin_epsilon<DataType>;
  static constexpr size_t W     = reproducible_sum_utils::bin_size<DataType>;

  struct Bin
  {
    std::array<TinyMatrixType, K> S;   // sum
    std::array<TinyMatrixType, K> C;   // carry

    Bin& operator=(const Bin&) = default;
    Bin& operator=(Bin&&)      = default;

    Bin(Bin&&)      = default;
    Bin(const Bin&) = default;

    Bin(ZeroType)
    {
      for (size_t k = 0; k < K; ++k) {
        const DataType init_value = 0.75 * eps * std::pow(DataType{2}, (K - k - 1) * W);
        for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
          for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
            S[k](i_component, j_component) = init_value;
          }
        }
        C[k] = zero;
      }
    }

    Bin()  = default;
    ~Bin() = default;
  };

 private:
  Bin m_summation_bin;

  PUGS_INLINE
  static void
  _shift(const size_t g, Bin& bin, const size_t i_component, const size_t j_component) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    for (size_t k = K - 1; k >= g; --k) {
      bin.S[k](i_component, j_component) = bin.S[k - g](i_component, j_component);
      bin.C[k](i_component, j_component) = bin.C[k - g](i_component, j_component);
    }
    for (size_t k = 0; k < std::min(K, g); ++k) {
      bin.S[k](i_component, j_component) = 1.5 * std::pow(DataType{2}, g * W) * ufp(bin.S[k](i_component, j_component));
      bin.C[k](i_component, j_component) = 0;
    }
  }

  PUGS_INLINE static void
  _update(const TinyMatrixType& m, Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;
    for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
      for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
        if (m(i_component, j_component) >= std::pow(DataType{2}, W - 1.) * ulp(bin.S[0](i_component, j_component))) {
          const size_t shift =
            1 + std::floor(std::log2(m(i_component, j_component) /
                                     (std::pow(DataType{2}, W - 1.) * ulp(bin.S[0](i_component, j_component)))) /
                           W);

          _shift(shift, bin, i_component, j_component);
        }
      }
    }
  }

  PUGS_INLINE
  void static _split2(TinyMatrixType& S, TinyMatrixType& x) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;
    union
    {
      DataType as_DataType;
      typename IntegerType<DataType>::integer_t as_integer;
    } x_bar;

    for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
      for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
        DataType& S_ij = S(i_component, j_component);
        DataType& x_ij = x(i_component, j_component);

        x_bar.as_DataType = x_ij;
        x_bar.as_integer |= 0x1;

        const DataType S0 = S_ij;
        S_ij += x_bar.as_DataType;
        x_ij -= S_ij - S0;
      }
    }
  }

  PUGS_INLINE static void
  _renormalize(Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    for (size_t k = 0; k < K; ++k) {
      TinyMatrixType& S_k = bin.S[k];
      TinyMatrixType& C_k = bin.C[k];
      for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
        for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
          DataType& Sk_ij = S_k(i_component, j_component);
          DataType& Ck_ij = C_k(i_component, j_component);

          if (Sk_ij >= 1.75 * ufp(Sk_ij)) {
            Sk_ij -= 0.25 * ufp(Sk_ij);
            Ck_ij += 1;
          } else if (Sk_ij < 1.25 * ufp(Sk_ij)) {
            Sk_ij += 0.5 * ufp(Sk_ij);
            Ck_ij -= 2;
          } else if (Sk_ij < 1.5 * ufp(Sk_ij)) {
            Sk_ij += 0.25 * ufp(Sk_ij);
            Ck_ij -= 1;
          }
        }
      }
    }
  }

 public:
  static void
  addBinTo(Bin& bin, Bin& bin_sum)
  {
    using namespace reproducible_sum_utils;

    for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
      for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
        DataType ulp_bin = ulp(bin.S[0](i_component, j_component));
        DataType ulp_sum = ulp(bin_sum.S[0](i_component, j_component));

        if (ulp_bin < ulp_sum) {
          const size_t shift = std::floor(std::log2(ulp_sum / ulp_bin) / W);
          if (shift > 0) {
            _shift(shift, bin, i_component, j_component);
          }
        } else if (ulp_bin > ulp_sum) {
          const size_t shift = std::floor(std::log2(ulp_bin / ulp_sum) / W);
          if (shift > 0) {
            _shift(shift, bin_sum, i_component, j_component);
          }
        }

        for (size_t k = 0; k < K; ++k) {
          bin_sum.S[k](i_component, j_component) +=
            bin.S[k](i_component, j_component) - 1.5 * ufp(bin.S[k](i_component, j_component));
          bin_sum.C[k](i_component, j_component) += bin.C[k](i_component, j_component);
        }
      }
    }

    _renormalize(bin_sum);
  }

  PUGS_INLINE
  static TinyMatrixType
  getValue(const Bin& bin) noexcept(NO_ASSERT)
  {
    using namespace reproducible_sum_utils;

    TinyMatrixType value = zero;
    for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
      for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
        for (size_t k = 0; k < K; ++k) {
          value(i_component, j_component) +=
            0.25 * (bin.C[k](i_component, j_component) - 6) * ufp(bin.S[k](i_component, j_component)) +
            bin.S[k](i_component, j_component);
        }
      }
    }
    return value;
  }

  PUGS_INLINE Bin
  getSummationBin() const noexcept(NO_ASSERT)
  {
    return m_summation_bin;
  }

  PUGS_INLINE TinyMatrixType
  getSum() const noexcept(NO_ASSERT)
  {
    return getValue(m_summation_bin);
  }

  ReproducibleTinyMatrixSum(const ArrayT& array, const MaskType mask = reproducible_sum_utils::NoMask{})
  {
    if constexpr (not std::is_same_v<MaskType, reproducible_sum_utils::NoMask>) {
      static_assert(std::is_same_v<std::decay_t<typename MaskType::data_type>, bool>,
                    "when provided, mask must be an array of bool");
    }

    using TeamPolicyT = Kokkos::TeamPolicy<Kokkos::IndexType<int>>;
    using TeamMemberT = TeamPolicyT::member_type;

    int nx = reproducible_sum_utils::bin_max_size<DataType>;
    int ny = std::max(array.size() / nx, 1ul);

    const TeamPolicyT policy(ny, Kokkos::AUTO());

    Array<TinyMatrixType> thread_sum(policy.team_size() * policy.league_size());

    Array<Bin> bin_by_thread(policy.team_size() * policy.league_size());
    bin_by_thread.fill(zero);

    Array<TinyMatrixType> local_max(policy.team_size() * policy.league_size());
    local_max.fill(zero);

    parallel_for(
      policy, PUGS_LAMBDA(const TeamMemberT& member) {
        const int i_team   = member.league_rank();
        const int i_thread = member.team_rank();

        const int thread_id = i_team * member.team_size() + i_thread;

        const int league_start = nx * i_team;
        const int block_size   = [&] {
          int size = nx;
          if (i_team == member.league_size() - 1) {
            size = array.size() - league_start;
          }
          return size;
        }();

        parallel_for(
          Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
            if (mask[league_start + i]) {
              for (size_t i_component = 0; i_component < TinyMatrixType::NumberOfRows; ++i_component) {
                for (size_t j_component = 0; j_component < TinyMatrixType::NumberOfColumns; ++j_component) {
                  DataType& m        = local_max[thread_id](i_component, j_component);
                  DataType abs_value = std::abs(array[league_start + i](i_component, j_component));
                  if (abs_value > m) {
                    m = abs_value;
                  }
                }
              }
            }
          });

        _update(local_max[thread_id], bin_by_thread[thread_id]);

        parallel_for(
          Kokkos::TeamThreadRange(member, block_size), PUGS_LAMBDA(int i) {
            if (mask[league_start + i]) {
              TinyMatrixType x = array[nx * i_team + i];
              for (size_t k = 0; k < K; ++k) {
                _split2(bin_by_thread[thread_id].S[k], x);
              }
            }
          });

        _renormalize(bin_by_thread[thread_id]);
      });

    m_summation_bin = bin_by_thread[0];
    for (size_t i = 1; i < bin_by_thread.size(); ++i) {
      addBinTo(bin_by_thread[i], m_summation_bin);
    }
  }

  ~ReproducibleTinyMatrixSum() = default;
};

#endif   // REPRODUCIBLE_SUM_UTILS_HPP