#ifndef TINY_MATRIX_HPP
#define TINY_MATRIX_HPP

#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>

#include <utils/Types.hpp>

#include <algebra/TinyVector.hpp>

#include <iostream>

/// Fixed-size dense square matrix of dimension N x N with entries of type T.
/// Entries are stored contiguously in a flat array of N * N values, in
/// row-major order (entry (i, j) is at position i * N + j).
template <size_t N, typename T = double>
class [[nodiscard]] TinyMatrix
{
 public:
  using data_type = T;

 private:
  T m_values[N * N];
  static_assert((N > 0), "TinyMatrix size must be strictly positive");

  /// Row-major flat position of entry (i, j). No bounds check is performed.
  PUGS_FORCEINLINE
  constexpr size_t _index(size_t i, size_t j) const noexcept   // LCOV_EXCL_LINE (due to forced inline)
  {
    return i * N + j;
  }

  /// Stores the variadic constructor arguments one after the other: the
  /// first argument lands in m_values[0], the last one in m_values[N*N-1].
  template <typename... Args>
  PUGS_FORCEINLINE constexpr void _unpackVariadicInput(const T& t, Args&&... args) noexcept
  {
    m_values[N * N - 1 - sizeof...(args)] = t;
    if constexpr (sizeof...(args) > 0) {
      this->_unpackVariadicInput(std::forward<Args>(args)...);
    }
  }

 public:
  /// Total number of entries, i.e. N * N.
  PUGS_INLINE
  constexpr size_t dimension() const
  {
    return N * N;
  }

  /// Number of rows (N).
  PUGS_INLINE
  constexpr size_t nbRows() const
  {
    return N;
  }

  /// Number of columns (N).
  PUGS_INLINE
  constexpr size_t nbColumns() const
  {
    return N;
  }

  /// Entry-wise opposite: returns the matrix -A.
  PUGS_INLINE
  constexpr TinyMatrix operator-() const
  {
    TinyMatrix opposed;
    for (size_t i = 0; i < N * N; ++i) {
      opposed.m_values[i] = -m_values[i];
    }
    return opposed;
  }

  /// Scalar multiplication t * A; A is left untouched.
  PUGS_INLINE
  constexpr friend TinyMatrix operator*(const T& t, const TinyMatrix& A)
  {
    TinyMatrix B = A;
    B *= t;
    // Return the named local (NRVO) rather than the reference produced by
    // '*=', which would force an additional copy.
    return B;
  }

  /// Scalar multiplication t * A reusing the storage of the temporary A.
  PUGS_INLINE
  constexpr friend TinyMatrix operator*(const T& t, TinyMatrix&& A)
  {
    return std::move(A *= t);
  }

  /// In-place scalar multiplication of every entry by t.
  PUGS_INLINE
  constexpr TinyMatrix& operator*=(const T& t)
  {
    for (size_t i = 0; i < N * N; ++i) {
      m_values[i] *= t;
    }
    return *this;
  }

  /// Matrix-matrix product (*this) * B.
  PUGS_INLINE
  constexpr TinyMatrix operator*(const TinyMatrix& B) const
  {
    const TinyMatrix& A = *this;
    TinyMatrix AB;
    for (size_t i = 0; i < N; ++i) {
      for (size_t j = 0; j < N; ++j) {
        // seed the sum with the first term so T needs no zero value
        T sum = A(i, 0) * B(0, j);
        for (size_t k = 1; k < N; ++k) {
          sum += A(i, k) * B(k, j);
        }
        AB(i, j) = sum;
      }
    }
    return AB;
  }

  /// Matrix-vector product (*this) * x.
  PUGS_INLINE
  constexpr TinyVector<N, T> operator*(const TinyVector<N, T>& x) const
  {
    const TinyMatrix& A = *this;
    TinyVector<N, T> Ax;
    for (size_t i = 0; i < N; ++i) {
      T sum = A(i, 0) * x[0];
      for (size_t j = 1; j < N; ++j) {
        sum += A(i, j) * x[j];
      }
      Ax[i] = sum;
    }
    return Ax;
  }

  /// Writes A row by row as "[(a00,a01,...)(a10,a11,...)...]".
  /// Not constexpr: stream output can never be evaluated at compile time.
  PUGS_INLINE
  friend std::ostream& operator<<(std::ostream& os, const TinyMatrix& A)
  {
    os << '[';
    for (size_t i = 0; i < N; ++i) {
      os << '(' << A(i, 0);
      for (size_t j = 1; j < N; ++j) {
        os << ',' << A(i, j);
      }
      os << ')';
    }
    os << ']';

    return os;
  }

  /// Exact entry-wise equality (uses T's operator!=).
  PUGS_INLINE
  constexpr bool operator==(const TinyMatrix& A) const
  {
    for (size_t i = 0; i < N * N; ++i) {
      if (m_values[i] != A.m_values[i]) {
        return false;
      }
    }
    return true;
  }

  /// Negation of operator==.
  PUGS_INLINE
  constexpr bool operator!=(const TinyMatrix& A) const
  {
    return not this->operator==(A);
  }

  /// Entry-wise sum (*this) + A.
  PUGS_INLINE
  constexpr TinyMatrix operator+(const TinyMatrix& A) const
  {
    TinyMatrix sum;
    for (size_t i = 0; i < N * N; ++i) {
      sum.m_values[i] = m_values[i] + A.m_values[i];
    }
    return sum;
  }

  /// Entry-wise sum (*this) + A, accumulating into the temporary A.
  PUGS_INLINE
  constexpr TinyMatrix operator+(TinyMatrix&& A) const
  {
    for (size_t i = 0; i < N * N; ++i) {
      A.m_values[i] += m_values[i];
    }
    // an rvalue-reference parameter is not implicitly moved on return in C++17
    return std::move(A);
  }

  /// Entry-wise difference (*this) - A.
  PUGS_INLINE
  constexpr TinyMatrix operator-(const TinyMatrix& A) const
  {
    TinyMatrix difference;
    for (size_t i = 0; i < N * N; ++i) {
      difference.m_values[i] = m_values[i] - A.m_values[i];
    }
    return difference;
  }

  /// Entry-wise difference (*this) - A, written into the temporary A.
  PUGS_INLINE
  constexpr TinyMatrix operator-(TinyMatrix&& A) const
  {
    for (size_t i = 0; i < N * N; ++i) {
      A.m_values[i] = m_values[i] - A.m_values[i];
    }
    return std::move(A);
  }

  /// In-place entry-wise addition.
  PUGS_INLINE
  constexpr TinyMatrix& operator+=(const TinyMatrix& A)
  {
    for (size_t i = 0; i < N * N; ++i) {
      m_values[i] += A.m_values[i];
    }
    return *this;
  }

  /// In-place entry-wise addition on volatile objects (returns nothing).
  PUGS_INLINE
  constexpr void operator+=(const volatile TinyMatrix& A) volatile
  {
    for (size_t i = 0; i < N * N; ++i) {
      m_values[i] += A.m_values[i];
    }
  }

  /// In-place entry-wise subtraction.
  PUGS_INLINE
  constexpr TinyMatrix& operator-=(const TinyMatrix& A)
  {
    for (size_t i = 0; i < N * N; ++i) {
      m_values[i] -= A.m_values[i];
    }
    return *this;
  }

  /// Mutable access to entry (i, j). Indices are checked with Assert; the
  /// function is noexcept(NO_ASSERT).
  PUGS_INLINE
  constexpr T& operator()(size_t i, size_t j) noexcept(NO_ASSERT)
  {
    Assert((i < N) and (j < N));
    return m_values[_index(i, j)];
  }

  /// Read-only access to entry (i, j). Indices are checked with Assert.
  PUGS_INLINE
  constexpr const T& operator()(size_t i, size_t j) const noexcept(NO_ASSERT)
  {
    Assert((i < N) and (j < N));
    return m_values[_index(i, j)];
  }

  /// Sets every entry to 0 (arithmetic T only).
  PUGS_INLINE
  constexpr TinyMatrix& operator=(ZeroType) noexcept
  {
    static_assert(std::is_arithmetic<T>(), "Cannot assign 'zero' value for non-arithmetic types");
    for (size_t i = 0; i < N * N; ++i) {
      m_values[i] = 0;
    }
    return *this;
  }

  /// Sets the matrix to the identity (arithmetic T only).
  PUGS_INLINE
  constexpr TinyMatrix& operator=(IdentityType) noexcept
  {
    static_assert(std::is_arithmetic<T>(), "Cannot assign 'identity' value for non-arithmetic types");
    for (size_t i = 0; i < N; ++i) {
      for (size_t j = 0; j < N; ++j) {
        m_values[_index(i, j)] = (i == j) ? 1 : 0;
      }
    }
    return *this;
  }

  PUGS_INLINE
  constexpr TinyMatrix& operator=(const TinyMatrix& A) noexcept = default;

  PUGS_INLINE
  constexpr TinyMatrix& operator=(TinyMatrix&& A) noexcept = default;

  /// Builds the matrix from exactly N * N values given in row-major order.
  template <typename... Args>
  PUGS_INLINE constexpr TinyMatrix(const T& t, Args&&... args) noexcept
  {
    static_assert(sizeof...(args) == N * N - 1, "wrong number of parameters");
    this->_unpackVariadicInput(t, std::forward<Args>(args)...);
  }

  // One does not use the '=default' constructor to avoid (unexpected)
  // performances issues
  PUGS_INLINE
  constexpr TinyMatrix() noexcept {}

  /// Builds the zero matrix (arithmetic T only).
  PUGS_INLINE
  constexpr TinyMatrix(ZeroType) noexcept
  {
    static_assert(std::is_arithmetic<T>(), "Cannot construct from 'zero' value "
                                           "for non-arithmetic types");
    for (size_t i = 0; i < N * N; ++i) {
      m_values[i] = 0;
    }
  }

  /// Builds the identity matrix (arithmetic T only).
  PUGS_INLINE
  constexpr TinyMatrix(IdentityType) noexcept
  {
    static_assert(std::is_arithmetic<T>(), "Cannot construct from 'identity' "
                                           "value for non-arithmetic types");
    for (size_t i = 0; i < N; ++i) {
      for (size_t j = 0; j < N; ++j) {
        m_values[_index(i, j)] = (i == j) ? 1 : 0;
      }
    }
  }

  PUGS_INLINE
  constexpr TinyMatrix(const TinyMatrix&) noexcept = default;

  // constexpr for consistency with the defaulted copy constructor above
  PUGS_INLINE
  constexpr TinyMatrix(TinyMatrix&& A) noexcept = default;

  PUGS_INLINE
  ~TinyMatrix() = default;
};

/// Outer (tensor) product of two vectors: the N x N matrix whose entry
/// (row, col) equals x[row] * y[col].
template <size_t N, typename T>
PUGS_INLINE constexpr TinyMatrix<N, T>
tensorProduct(const TinyVector<N, T>& x, const TinyVector<N, T>& y)
{
  TinyMatrix<N, T> xy;
  for (size_t row = 0; row < N; ++row) {
    const auto& x_row = x[row];
    for (size_t col = 0; col < N; ++col) {
      xy(row, col) = x_row * y[col];
    }
  }
  return xy;
}

/// Determinant of an N x N floating-point matrix computed by Gaussian
/// elimination with partial pivoting.
///
/// Rows are never physically exchanged: `index` holds the current row
/// permutation and every transposition flips the sign of the result.
/// Returns 0 as soon as a column has no non-zero pivot candidate.
template <size_t N, typename T>
PUGS_INLINE constexpr T
det(const TinyMatrix<N, T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types");
  static_assert(std::is_floating_point<T>::value, "determinent for arbitrary dimension N is defined for floating "
                                                  "point types only");
  // elimination is performed in place on a copy
  TinyMatrix<N, T> M = A;

  // index[i] is the row of M currently playing the role of row i
  TinyVector<N, size_t> index;
  for (size_t i = 0; i < N; ++i) {
    index[i] = i;
  }

  T determinent = 1;
  for (size_t i = 0; i < N; ++i) {
    // Partial pivoting: pick, among the not-yet-eliminated rows, the one with
    // the largest magnitude in column i (compare against the running best,
    // not against the initial candidate).
    size_t pivot = i;
    for (size_t k = i + 1; k < N; ++k) {
      if (std::abs(M(index[k], i)) > std::abs(M(index[pivot], i))) {
        pivot = k;
      }
    }
    if (pivot != i) {
      std::swap(index[pivot], index[i]);
      determinent = -determinent;   // a row transposition flips the sign
    }

    const size_t I = index[i];
    const T Mii    = M(I, i);
    if (Mii == 0) {
      // the largest remaining entry of column i is zero: singular matrix
      return 0;
    }

    const T inv_Mii = T{1} / Mii;
    for (size_t k = i + 1; k < N; ++k) {
      const size_t K = index[k];
      const T factor = M(K, i) * inv_Mii;
      // column i of row K becomes (implicitly) zero; only update the rest
      for (size_t l = i + 1; l < N; ++l) {
        M(K, l) -= factor * M(I, l);
      }
    }
  }

  // determinant of the (permuted) upper-triangular factor
  for (size_t i = 0; i < N; ++i) {
    determinent *= M(index[i], i);
  }
  return determinent;
}

/// Determinant of a 1 x 1 matrix: its only entry.
template <typename T>
PUGS_INLINE constexpr T
det(const TinyMatrix<1, T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types");
  const T& single_entry = A(0, 0);
  return single_entry;
}

/// Determinant of a 2 x 2 matrix: product of the main diagonal minus
/// product of the anti-diagonal.
template <typename T>
PUGS_INLINE constexpr T
det(const TinyMatrix<2, T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types");
  // 'auto' keeps the (possibly promoted) type of the products, as in a single expression
  const auto main_diagonal = A(0, 0) * A(1, 1);
  const auto anti_diagonal = A(1, 0) * A(0, 1);
  return main_diagonal - anti_diagonal;
}

/// Determinant of a 3 x 3 matrix by cofactor expansion along the first column.
template <typename T>
PUGS_INLINE constexpr T
det(const TinyMatrix<3, T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non-arithmetic types");
  // 2x2 minors obtained by deleting column 0 and row 0, 1, 2 respectively;
  // 'auto' keeps the (possibly promoted) intermediate type of the original expression
  const auto minor_0 = A(1, 1) * A(2, 2) - A(2, 1) * A(1, 2);
  const auto minor_1 = A(0, 1) * A(2, 2) - A(2, 1) * A(0, 2);
  const auto minor_2 = A(0, 1) * A(1, 2) - A(1, 1) * A(0, 2);

  return A(0, 0) * minor_0 - A(1, 0) * minor_1 + A(2, 0) * minor_2;
}

/// Returns the (N-1) x (N-1) submatrix of A obtained by deleting row I and
/// column J (both indices are checked with Assert).
template <size_t N, typename T>
PUGS_INLINE constexpr TinyMatrix<N - 1, T>
getMinor(const TinyMatrix<N, T>& A, size_t I, size_t J)
{
  static_assert(N >= 2, "minor calculation requires at least 2x2 matrices");
  Assert((I < N) and (J < N));
  TinyMatrix<N - 1, T> M;

  size_t target_row = 0;
  for (size_t i = 0; i < N; ++i) {
    if (i == I) {
      continue;   // deleted row
    }
    size_t target_col = 0;
    for (size_t j = 0; j < N; ++j) {
      if (j != J) {   // skip the deleted column
        M(target_row, target_col) = A(i, j);
        ++target_col;
      }
    }
    ++target_row;
  }
  return M;
}

// Generic matrix inverse: this is a declaration only. This header provides
// overloads for N = 1, 2 and 3 below; no definition for other sizes appears
// in this file, so (unless one is provided elsewhere — TODO confirm) calling
// inverse() on a larger matrix would fail at link time.
template <size_t N, typename T>
PUGS_INLINE constexpr TinyMatrix<N, T> inverse(const TinyMatrix<N, T>& A);

/// Inverse of a 1 x 1 floating-point matrix: the reciprocal of its entry.
/// No check is made for a zero entry.
template <typename T>
PUGS_INLINE constexpr TinyMatrix<1, T>
inverse(const TinyMatrix<1, T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types");
  static_assert(std::is_floating_point<T>::value, "inverse is defined for floating point types only");

  return TinyMatrix<1, T>(1. / A(0, 0));
}

/// Cofactor C(i, j) = (-1)^(i+j) * det(minor of A without row i and column j).
template <size_t N, typename T>
PUGS_INLINE constexpr T
cofactor(const TinyMatrix<N, T>& A, size_t i, size_t j)
{
  static_assert(std::is_arithmetic<T>::value, "cofactor is not defined for non-arithmetic types");
  const bool odd_position = ((i + j) % 2) != 0;
  const T sign            = odd_position ? -1 : 1;

  const auto minor_determinant = det(getMinor(A, i, j));
  return sign * minor_determinant;
}

/// Inverse of a 2 x 2 floating-point matrix: the adjugate (transposed
/// cofactor matrix) scaled by 1 / det(A). No singularity check is made.
template <typename T>
PUGS_INLINE constexpr TinyMatrix<2, T>
inverse(const TinyMatrix<2, T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types");
  static_assert(std::is_floating_point<T>::value, "inverse is defined for floating point types only");

  const T inv_det = 1. / det(A);

  // row-major: [[d, -b], [-c, a]] for A = [[a, b], [c, d]]
  TinyMatrix<2, T> adjugate(A(1, 1), -A(0, 1), -A(1, 0), A(0, 0));
  adjugate *= inv_det;
  return adjugate;
}

/// Inverse of a 3 x 3 floating-point matrix: the adjugate (transpose of the
/// cofactor matrix) scaled by 1 / det(A). No singularity check is made.
template <typename T>
PUGS_INLINE constexpr TinyMatrix<3, T>
inverse(const TinyMatrix<3, T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non-arithmetic types");
  static_assert(std::is_floating_point<T>::value, "inverse is defined for floating point types only");

  const T inv_det = 1. / det(A);

  // entry (r, c) of the adjugate is cofactor(A, c, r)
  TinyMatrix<3, T> adjugate(cofactor(A, 0, 0), cofactor(A, 1, 0), cofactor(A, 2, 0),   //
                            cofactor(A, 0, 1), cofactor(A, 1, 1), cofactor(A, 2, 1),   //
                            cofactor(A, 0, 2), cofactor(A, 1, 2), cofactor(A, 2, 2));

  adjugate *= inv_det;
  return adjugate;
}

#endif   // TINY_MATRIX_HPP