#ifndef CRS_MATRIX_HPP
#define CRS_MATRIX_HPP

#include <Array.hpp>
#include <Kokkos_StaticCrsGraph.hpp>
#include <PugsAssert.hpp>

#include <SparseMatrixDescriptor.hpp>

#include <Vector.hpp>

#include <iostream>

template <typename DataType = double, typename IndexType = size_t>
class CRSMatrix
{
  using MutableDataType = std::remove_const_t<DataType>;

 private:
  using HostMatrix = Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>;

  HostMatrix m_host_matrix;
  Array<const DataType> m_values;

 public:
  PUGS_INLINE
  size_t
  numberOfRows() const
  {
    return m_host_matrix.numRows();
  }

  template <typename DataType2>
  Vector<MutableDataType> operator*(const Vector<DataType2>& x) const
  {
    static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(),
                  "Cannot multiply matrix and vector of different type");

    Assert(x.size() == m_host_matrix.numRows());

    Vector<MutableDataType> Ax{m_host_matrix.numRows()};
    auto host_row_map = m_host_matrix.row_map;

    parallel_for(
      m_host_matrix.numRows(), PUGS_LAMBDA(const IndexType& i_row) {
        const auto row_begin = host_row_map(i_row);
        const auto row_end   = host_row_map(i_row + 1);
        MutableDataType sum{0};
        for (IndexType j = row_begin; j < row_end; ++j) {
          sum += m_values[j] * x[m_host_matrix.entries(j)];
        }
        Ax[i_row] = sum;
      });

    return Ax;
  }

  CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M)
  {
    m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector());
    m_values      = M.valueArray();
  }
  ~CRSMatrix() = default;
};

#endif   // CRS_MATRIX_HPP