#ifndef CRS_MATRIX_DESCRIPTOR_HPP
#define CRS_MATRIX_DESCRIPTOR_HPP

#include <algebra/CRSMatrix.hpp>
#include <utils/Array.hpp>
#include <utils/Exceptions.hpp>

#include <iostream>
#include <map>
#include <type_traits>

template <typename DataType = double, typename IndexType = int>
class CRSMatrixDescriptor
{
 public:
  using index_type = IndexType;
  using data_type  = DataType;

  static_assert(std::is_integral_v<IndexType>);

 private:
  const IndexType m_nb_rows;
  const IndexType m_nb_columns;

  Array<const IndexType> m_row_map;
  Array<data_type> m_values;
  Array<IndexType> m_column_indices;

  Array<IndexType> m_nb_values_per_row;
  Array<std::map<IndexType, DataType>> m_overflow_per_row;

  Array<const IndexType>
  _computeRowMap(const Array<IndexType>& non_zeros) const
  {
    Assert(non_zeros.size() - m_nb_rows == 0);

    Array<IndexType> row_map(m_nb_rows + 1);
    row_map[0] = 0;
    for (IndexType i = 0; i < m_nb_rows; ++i) {
      Assert((non_zeros[i] >= 0) and (non_zeros[i] <= m_nb_columns));
      row_map[i + 1] = row_map[i] + non_zeros[i];
    }

    return row_map;
  }

 public:
  IndexType
  numberOfRows() const
  {
    return m_nb_rows;
  }

  IndexType
  numberOfColumns() const
  {
    return m_nb_columns;
  }

  bool PUGS_INLINE
  hasOverflow() const
  {
    for (IndexType i = 0; i < m_nb_rows; ++i) {
      if (m_overflow_per_row[i].size() > 0) {
        return true;
      }
    }
    return false;
  }

  bool PUGS_INLINE
  isFilled() const
  {
    for (IndexType i = 0; i < m_nb_rows; ++i) {
      if (m_nb_values_per_row[i] < m_row_map[i + 1] - m_row_map[i]) {
        return false;
      }
    }
    return true;
  }

  friend PUGS_INLINE std::ostream&
  operator<<(std::ostream& os, const CRSMatrixDescriptor& A)
  {
    for (IndexType i = 0; i < A.m_nb_rows; ++i) {
      const auto& overflow_row = A.m_overflow_per_row[i];
      os << i << "|";

      if (A.m_nb_values_per_row[i] + overflow_row.size() > 0) {
        IndexType j     = 0;
        auto i_overflow = overflow_row.begin();
        while (j < A.m_nb_values_per_row[i] or i_overflow != overflow_row.end()) {
          const IndexType j_index = [&] {
            if (j < A.m_nb_values_per_row[i]) {
              return A.m_column_indices[A.m_row_map[i] + j];
            } else {
              return std::numeric_limits<IndexType>::max();
            }
          }();

          const IndexType overflow_index = [&] {
            if (i_overflow != overflow_row.end()) {
              return i_overflow->first;
            } else {
              return std::numeric_limits<IndexType>::max();
            }
          }();

          if (j_index < overflow_index) {
            os << ' ' << A.m_column_indices[A.m_row_map[i] + j] << ':' << A.m_values[A.m_row_map[i] + j];
            ++j;
          } else {
            os << ' ' << i_overflow->first << ':' << i_overflow->second;
            ++i_overflow;
          }
        }
      }

      os << '\n';
    }
    return os;
  }

  [[nodiscard]] PUGS_INLINE DataType&
  operator()(const IndexType i, const IndexType j)
  {
    Assert(i < m_nb_rows, "invalid row index");
    Assert(j < m_nb_columns, "invalid column index");

    const IndexType row_start = m_row_map[i];
    const IndexType row_end   = m_row_map[i + 1];

    IndexType position_candidate = 0;
    bool found                   = false;
    for (IndexType i_row = 0; i_row < m_nb_values_per_row[i]; ++i_row) {
      if (m_column_indices[row_start + i_row] > j) {
        position_candidate = i_row;
        break;
      } else {
        if (m_column_indices[row_start + i_row] == j) {
          position_candidate = i_row;
          found              = true;
          break;
        } else {
          ++position_candidate;
        }
      }
    }

    if (not found) {
      if (m_nb_values_per_row[i] < row_end - row_start) {
        for (IndexType i_shift = 0; i_shift < m_nb_values_per_row[i] - position_candidate; ++i_shift) {
          Assert(std::make_signed_t<IndexType>(row_start + m_nb_values_per_row[i]) -
                   std::make_signed_t<IndexType>(i_shift + 1) >=
                 0);
          const IndexType i_destination = row_start + m_nb_values_per_row[i] - i_shift;
          const IndexType i_source      = i_destination - 1;

          m_column_indices[i_destination] = m_column_indices[i_source];
          m_values[i_destination]         = m_values[i_source];
        }

        m_column_indices[row_start + position_candidate] = j;
        m_values[row_start + position_candidate]         = 0;
        ++m_nb_values_per_row[i];

        return m_values[row_start + position_candidate];
      } else {
        auto& overflow_row = m_overflow_per_row[i];

        auto iterator = overflow_row.insert(std::make_pair(j, DataType{0})).first;

        return iterator->second;
      }
    } else {
      return m_values[row_start + position_candidate];
    }
  }

  CRSMatrix<DataType, IndexType>
  getCRSMatrix() const
  {
    const bool is_filled    = this->isFilled();
    const bool has_overflow = this->hasOverflow();
    if (is_filled and not has_overflow) {
      return CRSMatrix<DataType, IndexType>{m_nb_rows, m_nb_columns, m_row_map, m_values, m_column_indices};
    } else {
      std::cout << rang::fgB::yellow << "warning:" << rang::style::reset
                << " CRSMatrix profile was not properly set:\n";
      if (not is_filled) {
        std::cout << "- some rows are " << rang::style::bold << "too long" << rang::style::reset << '\n';
      }
      if (has_overflow) {
        std::cout << "- some rows are " << rang::style::bold << "too short" << rang::style::reset << '\n';
      }

      Array<IndexType> row_map(m_nb_rows + 1);

      row_map[0] = 0;
      for (IndexType i = 0; i < m_nb_rows; ++i) {
        row_map[i + 1] = row_map[i] + m_nb_values_per_row[i] + m_overflow_per_row[i].size();
      }

      IndexType nb_values = row_map[row_map.size() - 1];

      Array<data_type> values(nb_values);
      Array<IndexType> column_indices(nb_values);

      IndexType l = 0;
      for (IndexType i = 0; i < m_nb_rows; ++i) {
        const auto& overflow_row = m_overflow_per_row[i];

        if (m_nb_values_per_row[i] + overflow_row.size() > 0) {
          IndexType j     = 0;
          auto i_overflow = overflow_row.begin();
          while (j < m_nb_values_per_row[i] or i_overflow != overflow_row.end()) {
            const IndexType j_index = [&] {
              if (j < m_nb_values_per_row[i]) {
                return m_column_indices[m_row_map[i] + j];
              } else {
                return std::numeric_limits<IndexType>::max();
              }
            }();

            const IndexType overflow_index = [&] {
              if (i_overflow != overflow_row.end()) {
                return i_overflow->first;
              } else {
                return std::numeric_limits<IndexType>::max();
              }
            }();

            if (j_index < overflow_index) {
              column_indices[l] = m_column_indices[m_row_map[i] + j];
              values[l]         = m_values[m_row_map[i] + j];
              ++j;

            } else {
              column_indices[l] = i_overflow->first;
              values[l]         = i_overflow->second;
              ++i_overflow;
            }
            ++l;
          }
        }
      }

      return CRSMatrix<DataType, IndexType>{m_nb_rows, m_nb_columns, row_map, values, column_indices};
    }
  }

  CRSMatrixDescriptor(const IndexType nb_rows, const IndexType nb_columns, const Array<IndexType>& non_zeros)
    : m_nb_rows{nb_rows},
      m_nb_columns{nb_columns},
      m_row_map{this->_computeRowMap(non_zeros)},
      m_values(m_row_map[m_row_map.size() - 1]),
      m_column_indices(m_row_map[m_row_map.size() - 1]),
      m_nb_values_per_row(m_nb_rows),
      m_overflow_per_row(m_nb_rows)
  {
    m_nb_values_per_row.fill(0);
    m_values.fill(0);

    // Diagonal is always set to fulfill PETSc's requirements
    if (m_nb_columns == m_nb_rows) {
      for (IndexType i = 0; i < m_nb_rows; ++i) {
        this->operator()(i, i) = DataType{0};
      }
    }
  }

  ~CRSMatrixDescriptor() = default;
};

#endif   // CRS_MATRIX_DESCRIPTOR_HPP