#ifndef SPARSE_MATRIX_DESCRIPTOR_HPP
#define SPARSE_MATRIX_DESCRIPTOR_HPP

#include <utils/Array.hpp>

#include <cstddef>
#include <map>
#include <ostream>
#include <type_traits>
#include <vector>

// Incremental builder ("descriptor") for a sparse matrix.
//
// Entries are stored row by row. Each row keeps its non-zero pattern in a
// std::map, so entries can be added in any order and columns stay sorted.
// graphVector() and valueArray() walk rows in increasing order and, within a
// row, columns in increasing order. Both extractions therefore list entries in
// the same order.
//
// DataType  : type of the stored coefficients (default: double)
// IndexType : unsigned integral type used for row/column indices
template <typename DataType = double, typename IndexType = size_t>
class SparseMatrixDescriptor
{
  static_assert(std::is_integral_v<IndexType>, "Index type must be an integral type");
  static_assert(std::is_unsigned_v<IndexType>, "Index type must be unsigned");

 public:
  using data_type  = DataType;
  using index_type = IndexType;

  // Sparse storage of one matrix row: maps column index -> coefficient.
  class SparseRowDescriptor
  {
    friend class SparseMatrixDescriptor;
    std::map<IndexType, DataType> m_id_value_map;

   public:
    // Number of explicitly stored entries in this row.
    IndexType
    numberOfValues() const noexcept
    {
      return m_id_value_map.size();
    }

    // Read-only access to the entry stored in column j.
    // The entry must already exist. This is only checked by Assert, so a
    // missing entry is undefined behavior when assertions are disabled.
    const DataType&
    operator[](const IndexType& j) const
    {
      auto i_value = m_id_value_map.find(j);
      Assert(i_value != m_id_value_map.end());
      return i_value->second;
    }

    // Access to the entry in column j. If the entry is not stored yet, it is
    // created with the value DataType(0).
    DataType&
    operator[](const IndexType& j)
    {
      // try_emplace does a single lookup. It constructs the value from 0 only
      // when j is absent and leaves an existing value untouched.
      return m_id_value_map.try_emplace(j, 0).first->second;
    }

    // Prints the row as a sequence of " column:value" items.
    friend std::ostream&
    operator<<(std::ostream& os, const SparseRowDescriptor& row)
    {
      for (const auto& [j, value] : row.m_id_value_map) {
        // Cast so that narrow index types (e.g. unsigned char) print as numbers.
        os << ' ' << static_cast<size_t>(j) << ':' << value;
      }
      return os;
    }

    SparseRowDescriptor() = default;
  };

 private:
  Array<SparseRowDescriptor> m_row_array;

 public:
  PUGS_INLINE
  size_t
  numberOfRows() const noexcept
  {
    return m_row_array.size();
  }

  // Mutable access to row i (i must be less than numberOfRows()).
  SparseRowDescriptor&
  row(const IndexType i)
  {
    Assert(i < m_row_array.size());
    return m_row_array[i];
  }

  // Read-only access to row i (i must be less than numberOfRows()).
  const SparseRowDescriptor&
  row(const IndexType i) const
  {
    Assert(i < m_row_array.size());
    return m_row_array[i];
  }

  // Access to entry (i, j). The entry is created with value DataType(0) if it
  // is absent. The column index is also checked against the number of rows,
  // so the matrix is assumed to be square.
  DataType&
  operator()(const IndexType& i, const IndexType& j)
  {
    Assert(i < m_row_array.size());
    Assert(j < m_row_array.size());
    return m_row_array[i][j];
  }

  // Read-only access to entry (i, j), which must already be stored (only
  // checked by Assert). The matrix is assumed to be square.
  const DataType&
  operator()(const IndexType& i, const IndexType& j) const
  {
    Assert(i < m_row_array.size());
    Assert(j < m_row_array.size());
    const auto& r = m_row_array[i];   // split to ensure const-ness of call
    return r[j];
  }

  // Prints one line per row: "<row index> | col:value col:value ...".
  friend std::ostream&
  operator<<(std::ostream& os, const SparseMatrixDescriptor& M)
  {
    // Use a size_t counter. An IndexType counter narrower than size_t could
    // wrap around before reaching size() and never terminate.
    for (size_t i = 0; i < M.m_row_array.size(); ++i) {
      os << i << " |" << M.m_row_array[i] << '\n';
    }
    return os;
  }

  // Returns, for each row, the sorted list of stored column indices.
  auto
  graphVector() const
  {
    std::vector<std::vector<IndexType>> graph_vector(m_row_array.size());
    for (size_t i_row = 0; i_row < m_row_array.size(); ++i_row) {
      const SparseRowDescriptor& row        = m_row_array[i_row];
      std::vector<IndexType>& row_graph     = graph_vector[i_row];
      row_graph.reserve(row.m_id_value_map.size());
      for (const auto& id_value : row.m_id_value_map) {
        row_graph.push_back(id_value.first);
      }
    }
    return graph_vector;
  }

  // Returns all stored values, row by row and column-sorted within each row.
  // The order matches graphVector().
  Array<DataType>
  valueArray() const
  {
    // Count in size_t. The total number of entries of the whole matrix can
    // exceed the range of a narrow IndexType, even when each row does not.
    size_t nb_values = 0;
    for (size_t i_row = 0; i_row < m_row_array.size(); ++i_row) {
      nb_values += m_row_array[i_row].m_id_value_map.size();
    }

    Array<DataType> values(nb_values);

    size_t cpt = 0;
    for (size_t i_row = 0; i_row < m_row_array.size(); ++i_row) {
      for (const auto& id_value : m_row_array[i_row].m_id_value_map) {
        values[cpt++] = id_value.second;
      }
    }
    return values;
  }

  // Creates a descriptor with nb_row empty rows.
  // The destructor is deliberately not user-declared, so the implicit move
  // operations are still generated (Rule of Zero).
  SparseMatrixDescriptor(size_t nb_row) : m_row_array{nb_row} {}
};

#endif   // SPARSE_MATRIX_DESCRIPTOR_HPP