#ifndef ITEM_VALUE_UTILS_HPP
#define ITEM_VALUE_UTILS_HPP

#include <Messenger.hpp>
#include <ItemValue.hpp>

#include <limits>
#include <type_traits>

/// @brief Smallest entry of an ItemValue, reduced over all processes.
///
/// A process-local minimum is first obtained by running parallel_reduce with
/// the ItemValueMin functor defined below; the local results are then
/// combined through parallel::allReduceMin.
///
/// @param item_value the values to scan
/// @return the global minimum, with any const qualifier of DataType removed
template <typename DataType,
          ItemType item_type>
std::remove_const_t<DataType>
min(const ItemValue<DataType, item_type>& item_value)
{
  using ItemValueType = ItemValue<DataType, item_type>;
  using index_type    = typename ItemValueType::index_type;
  using data_type     = std::remove_const_t<typename ItemValueType::data_type>;

  // Reduction functor handed to parallel_reduce: it provides init,
  // operator() and join (presumably the Kokkos-style reducer interface).
  class ItemValueMin
  {
   private:
    const ItemValueType& m_values;

   public:
    PASTIS_INLINE
    ItemValueMin(const ItemValueType& item_value)
        : m_values(item_value)
    {
      ;
    }

    PASTIS_INLINE
    ~ItemValueMin() = default;

    // Identity of the min reduction: the largest representable value.
    PASTIS_INLINE
    void init(data_type& value) const
    {
      value = std::numeric_limits<data_type>::max();
    }

    // Folds the entry at `index` into the running minimum.
    PASTIS_INLINE
    void operator()(const index_type& index, data_type& running_min) const
    {
      if (!(m_values[index] < running_min)) {
        return;
      }
      running_min = m_values[index];
    }

    // Merges a partial minimum `contribution` into `accumulated`.
    PASTIS_INLINE
    void join(volatile data_type& accumulated,
              const volatile data_type& contribution) const
    {
      if (!(contribution < accumulated)) {
        return;
      }
      accumulated = contribution;
    }

    // Converting the functor to data_type runs the local reduction.
    PASTIS_INLINE
    operator data_type()
    {
      data_type local_result;
      parallel_reduce(m_values.size(), *this, local_result);
      return local_result;
    }
  };

  const DataType process_min = ItemValueMin{item_value};
  return parallel::allReduceMin(process_min);
}

/// @brief Largest entry of an ItemValue, reduced over all processes.
///
/// A process-local maximum is first obtained by running parallel_reduce with
/// the ItemValueMax functor defined below; the local results are then
/// combined through parallel::allReduceMax.
///
/// @param item_value the values to scan
/// @return the global maximum, with any const qualifier of DataType removed
template <typename DataType,
          ItemType item_type>
std::remove_const_t<DataType>
max(const ItemValue<DataType, item_type>& item_value)
{
  using ItemValueType = ItemValue<DataType, item_type>;
  using data_type = std::remove_const_t<typename ItemValueType::data_type>;
  using index_type = typename ItemValueType::index_type;

  // Reduction functor handed to parallel_reduce: it provides init,
  // operator() and join (presumably the Kokkos-style reducer interface).
  class ItemValueMax
  {
   private:
    const ItemValueType& m_item_value;

   public:
    // Converting the functor to data_type runs the local reduction.
    PASTIS_INLINE
    operator data_type()
    {
      data_type reduced_value;
      parallel_reduce(m_item_value.size(), *this, reduced_value);
      return reduced_value;
    }

    // Folds entry i into the running maximum.
    PASTIS_INLINE
    void operator()(const index_type& i, data_type& data) const
    {
      if (m_item_value[i] > data) {
        data = m_item_value[i];
      }
    }

    // Merges a partial maximum src into dst.
    PASTIS_INLINE
    void join(volatile data_type& dst,
              const volatile data_type& src) const
    {
      if (src > dst) {
        dst = src;
      }
    }

    // Identity of the max reduction: the lowest representable value.
    // numeric_limits<T>::min() is the smallest *positive* normalized value
    // for floating-point T, which made the max of all-negative data come out
    // as a tiny positive number. lowest() is the most negative finite value
    // for both integral and floating-point types.
    PASTIS_INLINE
    void init(data_type& value) const
    {
      value = std::numeric_limits<data_type>::lowest();
    }

    PASTIS_INLINE
    ItemValueMax(const ItemValueType& item_value)
        : m_item_value(item_value)
    {
      ;
    }

    PASTIS_INLINE
    ~ItemValueMax() = default;
  };

  const DataType local_max = ItemValueMax{item_value};
  return parallel::allReduceMax(local_max);
}


/// @brief Sum of all entries of an ItemValue, reduced over all processes.
///
/// A process-local sum is first obtained by running parallel_reduce with the
/// ItemValueSum functor defined below; the local results are then combined
/// through parallel::allReduceSum.
///
/// @param item_value the values to accumulate
/// @return the global sum, with any const qualifier of DataType removed
template <typename DataType,
          ItemType item_type>
std::remove_const_t<DataType>
sum(const ItemValue<DataType, item_type>& item_value)
{
  using ItemValueType = ItemValue<DataType, item_type>;
  using index_type    = typename ItemValueType::index_type;
  using data_type     = std::remove_const_t<typename ItemValueType::data_type>;

  // Reduction functor handed to parallel_reduce: it provides init,
  // operator() and join (presumably the Kokkos-style reducer interface).
  class ItemValueSum
  {
   private:
    const ItemValueType& m_values;

   public:
    PASTIS_INLINE
    ItemValueSum(const ItemValueType& item_value)
        : m_values(item_value)
    {
      ;
    }

    PASTIS_INLINE
    ~ItemValueSum() = default;

    // Additive identity: literal 0 for arithmetic types, otherwise `zero`.
    // NOTE(review): `zero` is declared outside this header; presumably a
    // tag assignable to non-arithmetic value types (e.g. small vectors).
    PASTIS_INLINE
    void init(data_type& value) const
    {
      if constexpr (!std::is_arithmetic_v<data_type>) {
        value = zero;
      } else {
        value = 0;
      }
    }

    // Adds the entry at `index` to the running total.
    PASTIS_INLINE
    void operator()(const index_type& index, data_type& running_sum) const
    {
      running_sum += m_values[index];
    }

    // Merges a partial sum `contribution` into `accumulated`.
    PASTIS_INLINE
    void join(volatile data_type& accumulated,
              const volatile data_type& contribution) const
    {
      accumulated += contribution;
    }

    // Converting the functor to data_type runs the local reduction.
    PASTIS_INLINE
    operator data_type()
    {
      data_type local_result;
      parallel_reduce(m_values.size(), *this, local_result);
      return local_result;
    }
  };

  const DataType process_sum = ItemValueSum{item_value};
  return parallel::allReduceSum(process_sum);
}

#endif // ITEM_VALUE_UTILS_HPP