#ifndef CRS_MATRIX_HPP
#define CRS_MATRIX_HPP

#include <Array.hpp>
#include <Kokkos_StaticCrsGraph.hpp>
#include <PugsAssert.hpp>

#include <Vector.hpp>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Sparse matrix with a fixed number of rows whose entries are created on
// demand. Each row keeps its entries in an ordered (column id -> value) map, so
// the sparsity pattern does not have to be known in advance. The graph and the
// values can then be extracted (see graphVector() and valueArray()).
class FlexibleMatrix
{
 public:
  // A single row of the matrix: values indexed by column id.
  class FlexibleRow
  {
    friend class FlexibleMatrix;
    std::map<uint64_t, double> m_id_value_map;

   public:
    // Read-only access to the entry of column j.
    // The entry must already exist; this is checked through Assert.
    const double& operator[](const size_t& j) const
    {
      auto i_value = m_id_value_map.find(j);
      Assert(i_value != m_id_value_map.end());
      return i_value->second;
    }

    // Read-write access to the entry of column j.
    // If the row has no entry for this column yet, one is created with value 0.
    double& operator[](const size_t& j)
    {
      // try_emplace does a single lookup and leaves an existing entry untouched.
      return m_id_value_map.try_emplace(j, 0.0).first->second;
    }

    // Prints the stored entries as " column:value", in increasing column order
    // (the order of the underlying std::map).
    friend std::ostream&
    operator<<(std::ostream& os, const FlexibleRow& row)
    {
      for (const auto& [j, value] : row.m_id_value_map) {
        os << ' ' << j << ':' << value;
      }
      return os;
    }

    FlexibleRow() = default;
  };

 private:
  Array<FlexibleRow> m_row_array;

 public:
  // Returns row i for modification.
  FlexibleRow&
  row(const size_t i)
  {
    return m_row_array[i];
  }

  // Returns row i for reading.
  const FlexibleRow&
  row(const size_t i) const
  {
    return m_row_array[i];
  }

  // Read-write access to entry (i,j); the entry is created (value 0) if needed.
  double&
  operator()(const size_t& i, const size_t& j)
  {
    return m_row_array[i][j];
  }

  // Access to entry (i,j) on a const matrix; forwards to m_row_array[i][j].
  const double&
  operator()(const size_t& i, const size_t& j) const
  {
    return m_row_array[i][j];
  }

  // Prints one line per row: "<row index> |" followed by the row's entries.
  friend std::ostream&
  operator<<(std::ostream& os, const FlexibleMatrix& M)
  {
    for (size_t i = 0; i < M.m_row_array.size(); ++i) {
      os << i << " |" << M.m_row_array[i] << '\n';
    }
    return os;
  }

  // Returns, for each row, the list of column ids that hold an entry, in
  // increasing column order.
  auto
  graphVector() const
  {
    std::vector<std::vector<uint64_t>> graph_vector(m_row_array.size());
    for (size_t i_row = 0; i_row < m_row_array.size(); ++i_row) {
      const auto& id_value_map = m_row_array[i_row].m_id_value_map;
      auto& column_ids         = graph_vector[i_row];

      // The number of entries is known: allocate once per row.
      column_ids.reserve(id_value_map.size());
      for (const auto& id_value : id_value_map) {
        column_ids.push_back(id_value.first);
      }
    }
    return graph_vector;
  }

  // Returns all stored values in a single flat array, row after row, each row
  // being traversed in the same order as in graphVector().
  Array<double>
  valueArray() const
  {
    // First pass: total number of stored entries.
    size_t size = 0;
    for (size_t i_row = 0; i_row < m_row_array.size(); ++i_row) {
      size += m_row_array[i_row].m_id_value_map.size();
    }

    Array<double> values(size);

    // Second pass: copy the values contiguously.
    size_t cpt = 0;
    for (size_t i_row = 0; i_row < m_row_array.size(); ++i_row) {
      const auto& id_value_map = m_row_array[i_row].m_id_value_map;
      for (const auto& id_value : id_value_map) {
        values[cpt++] = id_value.second;
      }
    }
    return values;
  }

  // Builds a matrix made of nb_row rows without any entry.
  FlexibleMatrix(size_t nb_row) : m_row_array{nb_row} {}

  ~FlexibleMatrix() = default;
};

class CRSMatrix
{
 private:
  using HostMatrix = Kokkos::StaticCrsGraph<unsigned int, Kokkos::HostSpace>;

  HostMatrix m_host_matrix;
  Array<double> m_values;

 public:
  Vector<double> operator*(const Vector<double>& x) const
  {
    Vector<double> Ax{m_host_matrix.numRows()};
    auto host_row_map = m_host_matrix.row_map;

    parallel_for(
      m_host_matrix.numRows(), PUGS_LAMBDA(const size_t& i_row) {
        const auto row_begin = host_row_map(i_row);
        const auto row_end   = host_row_map(i_row + 1);
        double sum{0};
        for (size_t j = row_begin; j < row_end; ++j) {
          sum += m_values[j] * x[m_host_matrix.entries(j)];
        }
        Ax[i_row] = sum;
      });

    return Ax;
  }

  friend std::ostream&
  operator<<(std::ostream& os, const CRSMatrix& M)
  {
    auto host_row_map = M.m_host_matrix.row_map;
    for (size_t i_row = 0; i_row < M.m_host_matrix.numRows(); ++i_row) {
      const auto& row_begin = host_row_map(i_row);
      const auto& row_end   = host_row_map(i_row + 1);
      os << i_row << " #";
      for (size_t j = row_begin; j < row_end; ++j) {
        os << ' ' << M.m_host_matrix.entries(j) << ':' << M.m_values[j];
      }
      os << '\n';
    }
    return os;
  }

  CRSMatrix(const FlexibleMatrix& M)
  {
    m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector());
    m_values      = M.valueArray();
  }
  ~CRSMatrix() = default;
};

#endif   // CRS_MATRIX_HPP
