#ifndef TINY_MATRIX_HPP
#define TINY_MATRIX_HPP

#include <PastisMacros.hpp>
#include <PastisAssert.hpp>

#include <Types.hpp>
#include <TinyVector.hpp>

#include <cmath>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>

template <size_t N, typename T=double>
class TinyMatrix
{
private:
  T m_values[N*N];
  static_assert((N>0),"TinyMatrix size must be strictly positive");

  PASTIS_FORCEINLINE
  constexpr size_t _index(const size_t& i, const size_t& j) const noexcept
  {
    return std::move(i*N+j);
  }

  template <typename... Args>
  PASTIS_FORCEINLINE
  constexpr void _unpackVariadicInput(const T& t, Args&&... args) noexcept
  {
    m_values[N*N-1-sizeof...(args)] = t;
    if constexpr (sizeof...(args) >0) {
        this->_unpackVariadicInput(args...);
      }
  }

public:
  PASTIS_INLINE
  constexpr TinyMatrix operator-() const
  {
    TinyMatrix opposed{no_init};
    for (size_t i=0; i<N*N; ++i) {
      opposed.m_values[i] = -m_values[i];
    }
    return std::move(opposed);
  }

  PASTIS_INLINE
  constexpr friend TinyMatrix operator*(const T& t, const TinyMatrix& A)
  {
    TinyMatrix B = A;
    return std::move(B *= t);
  }

  PASTIS_INLINE
  constexpr TinyMatrix& operator*=(const T& t)
  {
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] *= t;
    }
    return *this;
  }

  PASTIS_INLINE
  constexpr TinyMatrix operator*(const TinyMatrix& B) const
  {
    const TinyMatrix& A = *this;
    TinyMatrix AB{no_init};
    for (size_t i=0; i<N; ++i) {
      for (size_t j=0; j<N; ++j) {
        T sum = A(i,0)*B(0,j);
        for (size_t k=1; k<N; ++k) {
          sum += A(i,k)*B(k,j);
        }
        AB(i,j) = sum;
      }
    }
    return std::move(AB);
  }

  PASTIS_INLINE
  constexpr TinyVector<N,T> operator*(const TinyVector<N,T>& x) const
  {
    const TinyMatrix& A = *this;
    TinyVector<N,T> Ax{no_init};
    for (size_t i=0; i<N; ++i) {
      T sum = A(i,0)*x[0];
      for (size_t j=1; j<N; ++j) {
        sum += A(i,j)*x[j];
      }
      Ax[i] = sum;
    }
    return std::move(Ax);
  }

  PASTIS_INLINE
  constexpr friend std::ostream& operator<<(std::ostream& os, const TinyMatrix& A)
  {
    if constexpr(N==1) {
      os << A(0,0);
      }
    else {
      os << '[';
      for (size_t i=0; i<N; ++i) {
        os << '(' << A(i,0);
        for (size_t j=1; j<N; ++j) {
          os << ',' << A(i,j);
        }
        os << ')';
      }
      os << ']';
    }
    return os;
  }

  PASTIS_INLINE
  constexpr bool operator==(const TinyMatrix& A) const
  {
    for (size_t i=0; i<N*N; ++i) {
      if (m_values[i] != A.m_values[i]) return false;
    }
    return true;
  }

  PASTIS_INLINE
  constexpr bool operator!=(const TinyMatrix& A) const
  {
    return not this->operator==(A);
  }

  PASTIS_INLINE
  constexpr TinyMatrix operator+(const TinyMatrix& A) const
  {
    TinyMatrix sum{no_init};
    for (size_t i=0; i<N*N; ++i) {
      sum.m_values[i] = m_values[i]+A.m_values[i];
    }
    return std::move(sum);
  }

  PASTIS_INLINE
  constexpr TinyMatrix operator+(TinyMatrix&& A) const
  {
    for (size_t i=0; i<N*N; ++i) {
      A.m_values[i] += m_values[i];
    }
    return std::move(A);
  }

  PASTIS_INLINE
  constexpr TinyMatrix operator-(const TinyMatrix& A) const
  {
    TinyMatrix difference{no_init};
    for (size_t i=0; i<N*N; ++i) {
      difference.m_values[i] = m_values[i]-A.m_values[i];
    }
    return std::move(difference);
  }

  PASTIS_INLINE
  constexpr TinyMatrix operator-(TinyMatrix&& A) const
  {
    for (size_t i=0; i<N*N; ++i) {
      A.m_values[i] = m_values[i]-A.m_values[i];
    }
    return std::move(A);
  }

  PASTIS_INLINE
  constexpr TinyMatrix& operator+=(const TinyMatrix& A)
  {
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] += A.m_values[i];
    }
    return *this;
  }

  PASTIS_INLINE
  constexpr TinyMatrix& operator-=(const TinyMatrix& A)
  {
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] -= A.m_values[i];
    }
    return *this;
  }

  PASTIS_INLINE
  constexpr T& operator()(const size_t& i, const size_t& j) noexcept(NO_ASSERT)
  {
    Assert((i<N) and (j<N));
    return m_values[_index(i,j)];
  }

  PASTIS_INLINE
  constexpr const T& operator()(const size_t& i, const size_t& j) const noexcept(NO_ASSERT)
  {
    Assert((i<N) and (j<N));
    return m_values[_index(i,j)];
  }

  PASTIS_INLINE
  constexpr TinyMatrix& operator=(const ZeroType& z) noexcept
  {
    static_assert(std::is_arithmetic<T>(),"Cannot assign 'zero' value for non-arithmetic types");
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] = 0;
    }
    return *this;
  }

  PASTIS_INLINE
  constexpr TinyMatrix& operator=(const IdentityType& I) noexcept
  {
    static_assert(std::is_arithmetic<T>(),"Cannot assign 'identity' value for non-arithmetic types");
    for (size_t i=0; i<N; ++i) {
      for (size_t j=0; j<N; ++j) {
        m_values[_index(i,j)] = (i==j) ? 1 : 0;
      }
    }
    return *this;
  }

  PASTIS_INLINE
  constexpr TinyMatrix& operator=(const TinyMatrix& A) noexcept = default;

  PASTIS_INLINE
  constexpr TinyMatrix& operator=(TinyMatrix&& A) noexcept = default;

  template <typename... Args>
  PASTIS_INLINE
  constexpr TinyMatrix(const T& t, Args&&... args) noexcept
  {
    static_assert(sizeof...(args)==N*N-1, "wrong number of parameters");
    this->_unpackVariadicInput(t, args...);
  }

  PASTIS_INLINE
  constexpr TinyMatrix() noexcept = default;

  PASTIS_INLINE
  constexpr TinyMatrix(const ZeroType& z) noexcept
  {
    static_assert(std::is_arithmetic<T>(),"Cannot construct from 'zero' value for non-arithmetic types");
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] = 0;
    }
  }

  PASTIS_INLINE
  constexpr TinyMatrix(const IdentityType& I) noexcept
  {
    static_assert(std::is_arithmetic<T>(),"Cannot construct from 'identity' value for non-arithmetic types");
    for (size_t i=0; i<N; ++i) {
      for (size_t j=0; j<N; ++j) {
        m_values[_index(i,j)] = (i==j) ? 1 : 0;
      }
    }
  }

  // Default constructor helper. Used to define uninitialized TinyMatrix
  // avoiding compiler's false uninitialized warning.
  PASTIS_INLINE
  constexpr TinyMatrix(const NoInitType&) noexcept {}

  PASTIS_INLINE
  constexpr TinyMatrix(const TinyMatrix&) noexcept = default;

  PASTIS_INLINE
  TinyMatrix(TinyMatrix&& A) noexcept = default;

  PASTIS_INLINE
  ~TinyMatrix() = default;
};

template <size_t N, typename T>
PASTIS_INLINE
constexpr TinyMatrix<N,T> tensorProduct(const TinyVector<N,T>& x,
                                        const TinyVector<N,T>& y)
{
  TinyMatrix<N,T> A{no_init};
  for (size_t i=0; i<N; ++i) {
    for (size_t j=0; j<N; ++j) {
      A(i,j) = x[i]*y[j];
    }
  }
  return std::move(A);
}

template <size_t N, typename T>
PASTIS_INLINE
constexpr T det(const TinyMatrix<N,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");
  static_assert(std::is_floating_point<T>::value, "determinent for arbitrary dimension N is defined for floating point types only");
  TinyMatrix<N,T> M = A;

  TinyVector<N, size_t> index{no_init};
  for (size_t i=0; i<N; ++i) index[i]=i;

  T determinent = 1;
  for (size_t i=0; i<N; ++i) {
    for (size_t j=i; j<N; ++j) {
      size_t l = j;
      const size_t J = index[j];
      for (size_t k=j+1; k<N; ++k) {
        if (std::abs(M(index[k],i)) > std::abs(M(J,i))) {
          l=k;
        }
      }
      if (l != j) {
        std::swap(index[l], index[j]);
        determinent *= -1;
      }
    }
    const size_t I = index[i];
    if (M(I,i)==0) return 0;
    const T inv_Mii = 1./M(I,i);
    for (size_t k=i+1; k<N; ++k) {
      const size_t K = index[k];
      const T factor = M(K,i)*inv_Mii;
      for (size_t l=i+1; l<N; ++l) {
        M(K,l) -= factor*M(I,l);
      }
    }
  }

  for (size_t i=0; i<N; ++i) {
    determinent *= M(index[i],i);
  }
  return determinent;
}

template <typename T>
PASTIS_INLINE
constexpr T det(const TinyMatrix<1,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");
  return A(0,0);
}

template <typename T>
PASTIS_INLINE
constexpr T det(const TinyMatrix<2,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");
  return A(0,0)*A(1,1)-A(1,0)*A(0,1);
}

template <typename T>
PASTIS_INLINE
constexpr T det(const TinyMatrix<3,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");
  return
    A(0,0)*(A(1,1)*A(2,2)-A(2,1)*A(1,2))
    -A(1,0)*(A(0,1)*A(2,2)-A(2,1)*A(0,2))
    +A(2,0)*(A(0,1)*A(1,2)-A(1,1)*A(0,2));
}

template <size_t N, typename T>
PASTIS_INLINE
constexpr TinyMatrix<N-1,T> getMinor(const TinyMatrix<N,T>& A,
                                     const size_t& I,
                                     const size_t& J)
{
  static_assert(N>=2, "minor calculation requires at least 2x2 matrices");
  Assert((I<N) and (J<N));
  TinyMatrix<N-1,T> M{no_init};
  for (size_t i=0; i<I; ++i) {
    for (size_t j=0; j<J; ++j) {
      M(i,j)=A(i,j);
    }
    for (size_t j=J+1; j<N; ++j) {
      M(i,j-1)=A(i,j);
    }
  }
  for (size_t i=I+1; i<N; ++i) {
    for (size_t j=0; j<J; ++j) {
      M(i-1,j)=A(i,j);
    }
    for (size_t j=J+1; j<N; ++j) {
      M(i-1,j-1)=A(i,j);
    }
  }
  return std::move(M);
}

template <size_t N, typename T>
PASTIS_INLINE
constexpr TinyMatrix<N,T> inverse(const TinyMatrix<N,T>& A);

template <typename T>
PASTIS_INLINE
constexpr TinyMatrix<1,T> inverse(const TinyMatrix<1,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non arithmetic types");
  static_assert(std::is_floating_point<T>::value, "inverse is defined for floating point types only");

  TinyMatrix<1,T> A_1(1./A(0,0));
  return std::move(A_1);
}

template <size_t N, typename T>
PASTIS_INLINE
constexpr T cofactor(const TinyMatrix<N,T>& A,
                     const size_t& i,
                     const size_t& j)
{
  static_assert(std::is_arithmetic<T>::value, "cofactor is not defined for non arithmetic types");
  const T sign = ((i+j)%2) ? -1 : 1;

  return sign * det(getMinor(A, i, j));
}

template <typename T>
PASTIS_INLINE
constexpr TinyMatrix<2,T> inverse(const TinyMatrix<2,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non arithmetic types");
  static_assert(std::is_floating_point<T>::value, "inverse is defined for floating point types only");

  const T determinent = det(A);
  const T inv_determinent = 1./determinent;

  TinyMatrix<2,T> A_cofactors_T(A(1,1), -A(0,1),
                                -A(1,0), A(0,0));
  return std::move(A_cofactors_T *= inv_determinent);
}


template <typename T>
PASTIS_INLINE
constexpr TinyMatrix<3,T> inverse(const TinyMatrix<3,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "inverse is not defined for non arithmetic types");
  static_assert(std::is_floating_point<T>::value, "inverse is defined for floating point types only");

  const T determinent = det(A);

  TinyMatrix<3,T> A_cofactors_T(cofactor(A,0,0), cofactor(A,1,0), cofactor(A,2,0),
                                cofactor(A,0,1), cofactor(A,1,1), cofactor(A,2,1),
                                cofactor(A,0,2), cofactor(A,1,2), cofactor(A,2,2));

  return std::move(A_cofactors_T *= 1./determinent);
}

#endif // TINY_MATRIX_HPP
