#ifndef CRS_MATRIX_HPP
#define CRS_MATRIX_HPP

#include <algebra/Vector.hpp>
#include <utils/Array.hpp>
#include <utils/NaNHelper.hpp>
#include <utils/PugsAssert.hpp>

#include <iostream>
#include <type_traits>

template <typename DataType, typename IndexType>
class CRSMatrixDescriptor;

/// Sparse matrix stored in compressed row storage (CRS) format.
///
/// The non-zero entries of row i are stored at positions
/// [rowMap()[i], rowMap()[i + 1]) of values() and columnIndices(), so the row
/// map must hold at least numberOfRows() + 1 entries. All arrays are exposed
/// with const elements.
///
/// Instances can only be constructed by CRSMatrixDescriptor (the constructor
/// is private). Because the dimensions are const members, a CRSMatrix can be
/// copy- or move-constructed but not assigned.
template <typename DataType = double, typename IndexType = int>
class CRSMatrix
{
 public:
  using MutableDataType = std::remove_const_t<DataType>;
  using index_type      = IndexType;
  using data_type       = DataType;

 private:
  const IndexType m_nb_rows;
  const IndexType m_nb_columns;

  // Entries of row i live in [m_row_map[i], m_row_map[i + 1]) of the two arrays below.
  Array<const IndexType> m_row_map;
  Array<const DataType> m_values;
  Array<const IndexType> m_column_indices;

 public:
  /// @return true when the matrix has as many rows as columns
  PUGS_INLINE
  bool
  isSquare() const noexcept
  {
    return m_nb_rows == m_nb_columns;
  }

  /// @return number of rows of the matrix
  PUGS_INLINE
  IndexType
  numberOfRows() const noexcept
  {
    return m_nb_rows;
  }

  /// @return number of columns of the matrix
  PUGS_INLINE
  IndexType
  numberOfColumns() const noexcept
  {
    return m_nb_columns;
  }

  /// @return the stored non-zero coefficients, row after row
  auto
  values() const
  {
    return m_values;
  }

  /// @return the row map: row i owns entries [rowMap()[i], rowMap()[i + 1])
  auto
  rowMap() const
  {
    return m_row_map;
  }

  /// @return the column index of each stored coefficient (parallel to values())
  auto
  columnIndices() const
  {
    return m_column_indices;
  }

  /// Sparse matrix-vector product A * x.
  ///
  /// @param x vector whose size must equal numberOfColumns() and whose element
  ///          type must match DataType up to const-qualification
  /// @return a newly allocated vector of size numberOfRows()
  template <typename DataType2>
  Vector<MutableDataType>
  operator*(const Vector<DataType2>& x) const
  {
    static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(),
                  "Cannot multiply matrix and vector of different type");

    // Written as a difference rather than an equality, presumably to avoid a
    // signed/unsigned comparison warning between x.size() and IndexType.
    Assert(x.size() - m_nb_columns == 0);

    Vector<MutableDataType> Ax(m_nb_rows);

    // Copy the member arrays into locals so that the kernel captures these
    // views by value instead of implicitly capturing `this`. PUGS_LAMBDA is
    // presumably a by-copy ([=]) lambda. A captured `this` is a host pointer
    // that must not be dereferenced if parallel_for runs the kernel on a
    // device, and implicit `this` capture via [=] is deprecated in C++20.
    const auto row_map        = m_row_map;
    const auto coefficients   = m_values;
    const auto column_indices = m_column_indices;

    parallel_for(
      m_nb_rows, PUGS_LAMBDA(const IndexType& i_row) {
        const auto row_begin = row_map[i_row];
        const auto row_end   = row_map[i_row + 1];
        MutableDataType sum{0};
        for (IndexType j = row_begin; j < row_end; ++j) {
          sum += coefficients[j] * x[column_indices[j]];
        }
        Ax[i_row] = sum;
      });

    return Ax;
  }

  /// Prints one line per row as `i| col:value col:value ...`; values are
  /// wrapped in NaNHelper before being streamed.
  friend PUGS_INLINE std::ostream&
  operator<<(std::ostream& os, const CRSMatrix& A)
  {
    for (IndexType i = 0; i < A.m_nb_rows; ++i) {
      const auto row_begin = A.m_row_map[i];
      const auto row_end   = A.m_row_map[i + 1];
      os << i << "|";
      for (IndexType j = row_begin; j < row_end; ++j) {
        os << ' ' << A.m_column_indices[j] << ':' << NaNHelper(A.m_values[j]);
      }
      os << '\n';
    }
    return os;
  }

 private:
  friend class CRSMatrixDescriptor<DataType, IndexType>;

  /// Only CRSMatrixDescriptor may build a CRSMatrix.
  /// @param nb_rows        number of rows
  /// @param nb_columns     number of columns
  /// @param row_map        row offsets (at least nb_rows + 1 entries)
  /// @param values         stored coefficients
  /// @param column_indices column index of each stored coefficient
  CRSMatrix(const IndexType nb_rows,
            const IndexType nb_columns,
            const Array<const IndexType>& row_map,
            const Array<const DataType>& values,
            const Array<const IndexType>& column_indices)
    : m_nb_rows{nb_rows},
      m_nb_columns{nb_columns},
      m_row_map{row_map},
      m_values{values},
      m_column_indices{column_indices}
  {}

 public:
  // Copy and move construction are defaulted explicitly. With only a
  // user-declared destructor, the implicit move constructor would be
  // suppressed and every move would silently fall back to a copy.
  CRSMatrix(const CRSMatrix&) = default;
  CRSMatrix(CRSMatrix&&)      = default;

  // Assignment is impossible anyway: the dimensions are const members.
  CRSMatrix& operator=(const CRSMatrix&) = delete;
  CRSMatrix& operator=(CRSMatrix&&)      = delete;

  ~CRSMatrix() = default;
};

#endif   // CRS_MATRIX_HPP