#ifndef DENSE_MATRIX_HPP
#define DENSE_MATRIX_HPP

#include <algebra/TinyMatrix.hpp>
#include <algebra/Vector.hpp>
#include <utils/Array.hpp>
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>
#include <utils/PugsUtils.hpp>
#include <utils/Types.hpp>

#include <iostream>

// Dense matrix of arithmetic values with run-time dimensions.
//
// Coefficients are stored contiguously in a single Array, row after row:
// coefficient (i, j) lives at linear position i * numberOfColumns() + j.
//
// DataType may be const-qualified; a DenseMatrix<const T> can be built from
// or assigned a DenseMatrix<T>, but not the other way around.
template <typename DataType>
class DenseMatrix   // LCOV_EXCL_LINE
{
 public:
  using data_type  = DataType;
  using index_type = size_t;

 private:
  size_t m_nb_rows;
  size_t m_nb_columns;
  // Row-major storage of the m_nb_rows * m_nb_columns coefficients
  Array<DataType> m_values;

  static_assert(std::is_same_v<typename decltype(m_values)::index_type, index_type>);
  static_assert(std::is_arithmetic_v<DataType>, "Dense matrices expect arithmetic data");

  // Every DenseMatrix instantiation may access our data. The const version
  // needs it to be built from the non-const one, and the arithmetic operators
  // below read B.m_values for operands whose type differs from ours only by
  // constness (in either direction). The static_asserts of each operation
  // still restrict which combinations are actually accepted.
  template <typename DataType2>
  friend class DenseMatrix;

 public:
  // Returns true when the matrix has as many rows as columns.
  PUGS_INLINE
  bool
  isSquare() const noexcept
  {
    return m_nb_rows == m_nb_columns;
  }

  // Returns a matrix with the same dimensions whose storage is obtained from
  // copy(A.m_values). The result is always non-const.
  friend PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
  copy(const DenseMatrix& A) noexcept
  {
    return {A.m_nb_rows, A.m_nb_columns, copy(A.m_values)};
  }

  // Returns a new numberOfColumns() x numberOfRows() matrix B such that
  // B(j, i) == A(i, j).
  friend PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
  transpose(const DenseMatrix& A)
  {
    DenseMatrix<std::remove_const_t<DataType>> A_transpose{A.m_nb_columns, A.m_nb_rows};
    for (size_t i = 0; i < A.m_nb_rows; ++i) {
      for (size_t j = 0; j < A.m_nb_columns; ++j) {
        A_transpose(j, i) = A(i, j);
      }
    }
    return A_transpose;
  }

  // Scalar-matrix product a * A. A is left untouched: the scaling is applied
  // to a copy of it.
  friend PUGS_INLINE DenseMatrix
  operator*(const DataType& a, const DenseMatrix& A)
  {
    DenseMatrix<std::remove_const_t<DataType>> aA = copy(A);
    return aA *= a;
  }

  // Matrix-vector product A * x.
  //
  // Requires x.size() == numberOfColumns() (checked by Assert). Returns a
  // vector of size numberOfRows().
  template <typename DataType2>
  PUGS_INLINE Vector<std::remove_const_t<DataType>>
  operator*(const Vector<DataType2>& x) const
  {
    static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
                  "incompatible data types");
    Assert(m_nb_columns == x.size(), "cannot compute matrix-vector product: incompatible sizes");
    const DenseMatrix& A = *this;
    Vector<std::remove_const_t<DataType>> Ax{m_nb_rows};
    for (size_t i = 0; i < m_nb_rows; ++i) {
      // Accumulate from zero (as the matrix-matrix product does) so that no
      // coefficient is read when the matrix has no column: each entry of the
      // result is then the empty sum.
      std::remove_const_t<DataType> Axi = 0;
      for (size_t j = 0; j < m_nb_columns; ++j) {
        Axi += A(i, j) * x[j];
      }
      Ax[i] = Axi;
    }
    return Ax;
  }

  // Matrix-matrix product A * B.
  //
  // Requires numberOfColumns() == B.numberOfRows() (checked by Assert).
  // Returns a numberOfRows() x B.numberOfColumns() matrix.
  template <typename DataType2>
  PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
  operator*(const DenseMatrix<DataType2>& B) const
  {
    static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
                  "incompatible data types");
    Assert(m_nb_columns == B.numberOfRows(), "cannot compute matrix product: incompatible sizes");
    const DenseMatrix& A = *this;
    DenseMatrix<std::remove_const_t<DataType>> AB{m_nb_rows, B.numberOfColumns()};

    for (size_t i = 0; i < m_nb_rows; ++i) {
      for (size_t j = 0; j < B.numberOfColumns(); ++j) {
        std::remove_const_t<DataType> ABij = 0;
        for (size_t k = 0; k < m_nb_columns; ++k) {
          ABij += A(i, k) * B(k, j);
        }
        AB(i, j) = ABij;
      }
    }
    return AB;
  }

  // Divides every coefficient by a. Implemented as a multiplication by the
  // reciprocal, which is formed as `1. / a` (hence in floating point).
  template <typename DataType2>
  PUGS_INLINE DenseMatrix&
  operator/=(const DataType2& a)
  {
    const auto inv_a = 1. / a;
    return (*this) *= inv_a;
  }

  // Multiplies every coefficient by a, in place.
  template <typename DataType2>
  PUGS_INLINE DenseMatrix&
  operator*=(const DataType2& a)
  {
    parallel_for(
      m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] *= a; });
    return *this;
  }

  // In-place coefficient-wise subtraction of B.
  //
  // Requires B to have the same dimensions as this matrix (checked by Assert).
  template <typename DataType2>
  PUGS_INLINE DenseMatrix&
  operator-=(const DenseMatrix<DataType2>& B)
  {
    static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
                  "incompatible data types");
    Assert((m_nb_rows == B.numberOfRows()) and (m_nb_columns == B.numberOfColumns()),
           "cannot substract matrix: incompatible sizes");

    parallel_for(
      m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] -= B.m_values[i]; });
    return *this;
  }

  // In-place coefficient-wise addition of B.
  //
  // Requires B to have the same dimensions as this matrix (checked by Assert).
  template <typename DataType2>
  PUGS_INLINE DenseMatrix&
  operator+=(const DenseMatrix<DataType2>& B)
  {
    static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
                  "incompatible data types");
    Assert((m_nb_rows == B.numberOfRows()) and (m_nb_columns == B.numberOfColumns()), "incompatible matrix sizes");

    parallel_for(
      m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] += B.m_values[i]; });
    return *this;
  }

  // Returns a new matrix holding the coefficient-wise sum of this matrix and B.
  //
  // Requires B to have the same dimensions as this matrix (checked by Assert).
  template <typename DataType2>
  PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
  operator+(const DenseMatrix<DataType2>& B) const
  {
    static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
                  "incompatible data types");
    Assert((m_nb_rows == B.numberOfRows()) and (m_nb_columns == B.numberOfColumns()),
           "cannot compute matrix sum: incompatible sizes");

    DenseMatrix<std::remove_const_t<DataType>> sum{B.numberOfRows(), B.numberOfColumns()};

    parallel_for(
      m_values.size(), PUGS_LAMBDA(index_type i) { sum.m_values[i] = m_values[i] + B.m_values[i]; });

    return sum;
  }

  // Returns a new matrix holding the coefficient-wise difference between this
  // matrix and B.
  //
  // Requires B to have the same dimensions as this matrix (checked by Assert).
  template <typename DataType2>
  PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
  operator-(const DenseMatrix<DataType2>& B) const
  {
    static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
                  "incompatible data types");
    Assert((m_nb_rows == B.numberOfRows()) and (m_nb_columns == B.numberOfColumns()),
           "cannot compute matrix difference: incompatible sizes");

    DenseMatrix<std::remove_const_t<DataType>> difference{B.numberOfRows(), B.numberOfColumns()};

    parallel_for(
      m_values.size(), PUGS_LAMBDA(index_type i) { difference.m_values[i] = m_values[i] - B.m_values[i]; });

    return difference;
  }

  // Access to coefficient (i, j), with i < numberOfRows() and
  // j < numberOfColumns() (checked by Assert).
  //
  // The returned reference is a DataType&: whether the coefficient can be
  // modified is governed by the constness of DataType, not by the constness
  // of the matrix object itself.
  PUGS_INLINE
  DataType&
  operator()(index_type i, index_type j) const noexcept(NO_ASSERT)
  {
    Assert(i < m_nb_rows and j < m_nb_columns, "cannot access element: invalid indices");
    return m_values[i * m_nb_columns + j];
  }

  // Number of rows of the matrix.
  PUGS_INLINE
  size_t
  numberOfRows() const noexcept
  {
    return m_nb_rows;
  }

  // Number of columns of the matrix.
  PUGS_INLINE
  size_t
  numberOfColumns() const noexcept
  {
    return m_nb_columns;
  }

  // Sets every coefficient to the given value.
  PUGS_INLINE void
  fill(const DataType& value) noexcept
  {
    m_values.fill(value);
  }

  // Sets every coefficient to 0. Dimensions are unchanged.
  PUGS_INLINE DenseMatrix& operator=(ZeroType) noexcept
  {
    m_values.fill(0);
    return *this;
  }

  // Turns the matrix into the identity: 1 on the diagonal, 0 elsewhere.
  //
  // Requires a square matrix (checked by Assert). Dimensions are unchanged.
  PUGS_INLINE DenseMatrix& operator=(IdentityType) noexcept(NO_ASSERT)
  {
    Assert(m_nb_rows == m_nb_columns, "identity must be a square matrix");

    m_values.fill(0);
    // The matrix is square, so m_nb_rows is also the row stride
    parallel_for(
      m_nb_rows, PUGS_LAMBDA(const index_type i) { m_values[i * m_nb_rows + i] = 1; });
    return *this;
  }

  // Assignment from a matrix whose data type differs at most by constness.
  //
  // Takes the dimensions of A and assigns A's Array to ours.
  // NOTE(review): presumably this shares A's storage rather than duplicating
  // it (a separate copy() function exists for that) -- confirm against Array.
  template <typename DataType2>
  PUGS_INLINE DenseMatrix&
  operator=(const DenseMatrix<DataType2>& A) noexcept
  {
    // ensures that DataType is the same as source DataType2
    static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(),
                  "Cannot assign DenseMatrix of different type");
    // ensures that const is not lost through copy
    static_assert(((std::is_const<DataType2>() and std::is_const<DataType>()) or not std::is_const<DataType2>()),
                  "Cannot assign DenseMatrix of const to DenseMatrix of non-const");

    m_nb_rows    = A.m_nb_rows;
    m_nb_columns = A.m_nb_columns;
    m_values     = A.m_values;
    return *this;
  }

  PUGS_INLINE
  DenseMatrix& operator=(const DenseMatrix&) = default;

  PUGS_INLINE
  DenseMatrix& operator=(DenseMatrix&&) = default;

  // Writes the matrix to the stream, one row per line, as
  // "<row>| <column>:<value> <column>:<value> ...". Values go through
  // NaNHelper before being written.
  friend std::ostream&
  operator<<(std::ostream& os, const DenseMatrix& A)
  {
    for (size_t i = 0; i < A.numberOfRows(); ++i) {
      os << i << '|';
      for (size_t j = 0; j < A.numberOfColumns(); ++j) {
        os << ' ' << j << ':' << NaNHelper(A(i, j));
      }
      os << '\n';
    }
    return os;
  }

  // Construction from a matrix whose data type differs at most by constness
  // (typically DenseMatrix<const T> from DenseMatrix<T>). All members are set
  // by the templated assignment operator.
  template <typename DataType2>
  DenseMatrix(const DenseMatrix<DataType2>& A)
  {
    // ensures that DataType is the same as source DataType2
    static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(),
                  "Cannot assign DenseMatrix of different type");
    // ensures that const is not lost through copy
    static_assert(((std::is_const<DataType2>() and std::is_const<DataType>()) or not std::is_const<DataType2>()),
                  "Cannot assign DenseMatrix of const to DenseMatrix of non-const");

    this->operator=(A);
  }

  DenseMatrix(const DenseMatrix&) = default;

  DenseMatrix(DenseMatrix&&) = default;

  // Builds an nb_rows x nb_columns matrix. Coefficients are not set here:
  // their initial values are whatever Array's sized constructor provides.
  explicit DenseMatrix(size_t nb_rows, size_t nb_columns) noexcept
    : m_nb_rows{nb_rows}, m_nb_columns{nb_columns}, m_values{nb_rows * nb_columns}
  {}

  // Builds a square nb_rows x nb_rows matrix. Coefficients are not set here:
  // their initial values are whatever Array's sized constructor provides.
  explicit DenseMatrix(size_t nb_rows) noexcept : m_nb_rows{nb_rows}, m_nb_columns{nb_rows}, m_values{nb_rows * nb_rows}
  {}

  // Builds an N x N matrix holding a copy of the coefficients of the
  // TinyMatrix M.
  template <size_t N>
  explicit DenseMatrix(const TinyMatrix<N, DataType>& M) noexcept : m_nb_rows{N}, m_nb_columns{N}, m_values{N * N}
  {
    parallel_for(
      N, PUGS_LAMBDA(const index_type i) {
        for (size_t j = 0; j < N; ++j) {
          m_values[i * N + j] = M(i, j);
        }
      });
  }

  // Builds a 0 x 0 matrix.
  DenseMatrix() noexcept : m_nb_rows{0}, m_nb_columns{0} {}

 private:
  // Builds a matrix on top of an existing Array of coefficients (row-major).
  // Requires values.size() == nb_rows * nb_columns (checked by Assert).
  DenseMatrix(size_t nb_rows, size_t nb_columns, const Array<DataType> values) noexcept(NO_ASSERT)
    : m_nb_rows{nb_rows}, m_nb_columns{nb_columns}, m_values{values}
  {
    Assert(m_values.size() == m_nb_rows * m_nb_columns, "cannot create matrix: incorrect number of values");
  }

 public:
  ~DenseMatrix() = default;
};

#endif   // DENSE_MATRIX_HPP