#ifndef TINY_MATRIX_HPP
#define TINY_MATRIX_HPP

#include <cassert>
#include <iostream>
#include <type_traits>

#include <Types.hpp>
#include <TinyVector.hpp>

/// Fixed-size dense N x N matrix with entries of type T (double by default).
///
/// Entries are stored contiguously in row-major order: entry (i,j) lives at
/// m_values[i*N+j]. No memory is allocated; every operation is a plain loop
/// over the fixed-size array.
template <size_t N, typename T=double>
class TinyMatrix
{
private:
  T m_values[N*N];
  static_assert((N>0),"TinyMatrix size must be strictly positive");

  /// Row-major offset of entry (i,j) in m_values.
  KOKKOS_INLINE_FUNCTION
  size_t _index(const size_t i, const size_t j) const
  {
    return i*N+j;
  }

  /// Last step of the variadic unpacking: t is the final entry (N-1,N-1).
  /// Annotated like its caller (the variadic constructor) so that it is
  /// callable wherever that constructor is.
  KOKKOS_INLINE_FUNCTION
  void _unpackVariadicInput(const T& t)
  {
    m_values[N*N-1] = t;
  }

  /// Stores t at the row-major slot determined by how many values are still
  /// pending in args, then recurses on the remaining values.
  template <typename... Args>
  KOKKOS_INLINE_FUNCTION
  void _unpackVariadicInput(const T& t, Args&&... args)
  {
    m_values[N*N-1-sizeof...(args)] = t;
    this->_unpackVariadicInput(args...);
  }

public:

  /// Scalar-matrix product: every entry of A multiplied (on the left) by t.
  KOKKOS_INLINE_FUNCTION
  friend TinyMatrix operator*(const T& t, const TinyMatrix& A)
  {
    TinyMatrix tA;
    for (size_t i=0; i<N*N; ++i) {
      tA.m_values[i] = t * A.m_values[i];
    }
    return tA;
  }

  /// Matrix-matrix product (*this)*B, computed with the textbook triple loop.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix operator*(const TinyMatrix& B) const
  {
    const TinyMatrix& A = *this;
    TinyMatrix AB;
    for (size_t i=0; i<N; ++i) {
      for (size_t j=0; j<N; ++j) {
        // Seeding with the k=0 term avoids having to build a zero of type T.
        T sum = A(i,0)*B(0,j);
        for (size_t k=1; k<N; ++k) {
          sum += A(i,k)*B(k,j);
        }
        AB(i,j) = sum;
      }
    }
    return AB;
  }

  /// Matrix-vector product (*this)*x.
  KOKKOS_INLINE_FUNCTION
  TinyVector<N,T> operator*(const TinyVector<N,T>& x) const
  {
    const TinyMatrix& A = *this;
    TinyVector<N,T> Ax;
    for (size_t i=0; i<N; ++i) {
      T sum = A(i,0)*x[0];
      for (size_t j=1; j<N; ++j) {
        sum += A(i,j)*x[j];
      }
      Ax[i] = sum;
    }
    return Ax;
  }

  /// Writes A to os: the single entry when N==1, otherwise the rows as
  /// "[(a00,a01,...)(a10,a11,...)...]".
  /// NOTE(review): relies on std::ostream, so it is only meaningful on the host.
  KOKKOS_INLINE_FUNCTION
  friend std::ostream& operator<<(std::ostream& os, const TinyMatrix& A)
  {
    if constexpr(N==1) {
      os << A(0,0);
    }
    else {
      os << '[';
      for (size_t i=0; i<N; ++i) {
        os << '(' << A(i,0);
        for (size_t j=1; j<N; ++j) {
          os << ',' << A(i,j);
        }
        os << ')';
      }
      os << ']';
    }
    return os;
  }

  /// Entrywise sum (*this)+A.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix operator+(const TinyMatrix& A) const
  {
    TinyMatrix sum;
    for (size_t i=0; i<N*N; ++i) {
      sum.m_values[i] = m_values[i]+A.m_values[i];
    }
    return sum;
  }

  /// Entrywise difference (*this)-A.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix operator-(const TinyMatrix& A) const
  {
    TinyMatrix difference;
    for (size_t i=0; i<N*N; ++i) {
      difference.m_values[i] = m_values[i]-A.m_values[i];
    }
    return difference;
  }

  /// In-place entrywise addition of A.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix& operator+=(const TinyMatrix& A)
  {
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] += A.m_values[i];
    }
    return *this;
  }

  /// In-place entrywise subtraction of A.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix& operator-=(const TinyMatrix& A)
  {
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] -= A.m_values[i];
    }
    return *this;
  }

  /// Mutable access to entry (i,j); indices are checked by assert only.
  KOKKOS_INLINE_FUNCTION
  T& operator()(const size_t& i, const size_t& j)
  {
    assert((i<N) and (j<N));
    return m_values[_index(i,j)];
  }

  /// Read-only access to entry (i,j); indices are checked by assert only.
  KOKKOS_INLINE_FUNCTION
  const T& operator()(const size_t& i, const size_t& j) const
  {
    assert((i<N) and (j<N));
    return m_values[_index(i,j)];
  }

  /// Sets every entry to 0.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix& operator=(const ZeroType&)
  {
    static_assert(std::is_arithmetic<T>(),"Cannot assign 'zero' value for non-arithmetic types");
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] = 0;
    }
    return *this;
  }

  /// Sets the matrix to the identity (1 on the diagonal, 0 elsewhere).
  KOKKOS_INLINE_FUNCTION
  TinyMatrix& operator=(const IdentityType&)
  {
    static_assert(std::is_arithmetic<T>(),"Cannot assign 'identity' value for non-arithmetic types");
    for (size_t i=0; i<N; ++i) {
      for (size_t j=0; j<N; ++j) {
        m_values[_index(i,j)] = (i==j) ? 1 : 0;
      }
    }
    return *this;
  }

  /// Entrywise copy. Defaulted so that TinyMatrix is trivially copyable
  /// whenever T is; returns a non-const reference like built-in assignment.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix& operator=(const TinyMatrix&) = default;

  KOKKOS_INLINE_FUNCTION
  TinyMatrix& operator=(TinyMatrix&&) = default;

  /// Builds the matrix from exactly N*N entries given in row-major order:
  /// the first argument becomes (0,0), the last one (N-1,N-1).
  template <typename... Args>
  KOKKOS_INLINE_FUNCTION
  TinyMatrix(const T& t, Args&&... args)
  {
    static_assert(sizeof...(args)==N*N-1, "wrong number of parameters");
    this->_unpackVariadicInput(t, args...);
  }

  /// Intentionally leaves the entries uninitialized (no work is done).
  KOKKOS_INLINE_FUNCTION
  TinyMatrix()
  {
  }

  /// Builds the zero matrix.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix(const ZeroType&)
  {
    static_assert(std::is_arithmetic<T>(),"Cannot construct from 'zero' value for non-arithmetic types");
    for (size_t i=0; i<N*N; ++i) {
      m_values[i] = 0;
    }
  }

  /// Builds the identity matrix.
  KOKKOS_INLINE_FUNCTION
  TinyMatrix(const IdentityType&)
  {
    static_assert(std::is_arithmetic<T>(),"Cannot construct from 'identity' value for non-arithmetic types");
    for (size_t i=0; i<N; ++i) {
      for (size_t j=0; j<N; ++j) {
        m_values[_index(i,j)] = (i==j) ? 1 : 0;
      }
    }
  }

  /// Entrywise copy (defaulted, see copy assignment).
  KOKKOS_INLINE_FUNCTION
  TinyMatrix(const TinyMatrix&) = default;

  KOKKOS_INLINE_FUNCTION
  TinyMatrix(TinyMatrix&&) = default;

  KOKKOS_INLINE_FUNCTION
  ~TinyMatrix() = default;
};

/// Outer (tensor) product of x and y: the N x N matrix A with
/// A(i,j) = x[i]*y[j].
template <size_t N, typename T>
KOKKOS_INLINE_FUNCTION
TinyMatrix<N,T> tensorProduct(const TinyVector<N,T>& x,
                              const TinyVector<N,T>& y)
{
  TinyMatrix<N,T> A;
  for (size_t i=0; i<N; ++i) {
    for (size_t j=0; j<N; ++j) {
      A(i,j) = x[i]*y[j];
    }
  }
  // Plain return: allows NRVO (std::move(A) would prevent it).
  return A;
}

/// Determinant of an N x N matrix, general case. Overload resolution picks
/// the more specialized N=1, 2 and 3 overloads when they apply, so this is
/// used for N > 3.
///
/// - Floating-point T: Gaussian elimination with partial pivoting on a local
///   copy of A, O(N^3).
/// - Other arithmetic T (integers, bool): Bareiss fraction-free elimination
///   carried out in long long. Every division in it is exact, so the result
///   is exact as long as no intermediate product overflows long long. The
///   result is converted back with static_cast<T>.
template <size_t N, typename T>
KOKKOS_INLINE_FUNCTION
T det(const TinyMatrix<N,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");

  if constexpr (std::is_floating_point<T>::value) {
    TinyMatrix<N,T> M = A;
    T determinant = 1;
    for (size_t k=0; k<N; ++k) {
      // Partial pivoting: choose the row i>=k with the largest |M(i,k)|.
      size_t pivot_row = k;
      T pivot_magnitude = (M(k,k) < 0) ? -M(k,k) : M(k,k);
      for (size_t i=k+1; i<N; ++i) {
        const T magnitude = (M(i,k) < 0) ? -M(i,k) : M(i,k);
        if (magnitude > pivot_magnitude) {
          pivot_row = i;
          pivot_magnitude = magnitude;
        }
      }
      if (pivot_magnitude == 0) {
        // Column k is zero on and below the diagonal: the matrix is singular.
        return 0;
      }
      if (pivot_row != k) {
        // Columns < k are already eliminated in rows >= k; only j >= k matter.
        for (size_t j=k; j<N; ++j) {
          const T tmp = M(k,j);
          M(k,j) = M(pivot_row,j);
          M(pivot_row,j) = tmp;
        }
        determinant = -determinant; // a row swap flips the sign
      }
      const T pivot = M(k,k);
      determinant *= pivot;
      for (size_t i=k+1; i<N; ++i) {
        const T factor = M(i,k)/pivot;
        for (size_t j=k+1; j<N; ++j) {
          M(i,j) -= factor*M(k,j);
        }
      }
    }
    return determinant;
  }
  else {
    long long M[N][N];
    for (size_t i=0; i<N; ++i) {
      for (size_t j=0; j<N; ++j) {
        M[i][j] = static_cast<long long>(A(i,j));
      }
    }

    long long sign = 1;
    long long previous_pivot = 1;
    for (size_t k=0; k+1<N; ++k) {
      if (M[k][k] == 0) {
        // Find a row below with a nonzero entry in column k and swap it in.
        size_t swap_row = k+1;
        while ((swap_row < N) and (M[swap_row][k] == 0)) {
          ++swap_row;
        }
        if (swap_row == N) {
          return 0; // singular
        }
        for (size_t j=k; j<N; ++j) {
          const long long tmp = M[k][j];
          M[k][j] = M[swap_row][j];
          M[swap_row][j] = tmp;
        }
        sign = -sign;
      }
      // Bareiss update: the division by the previous pivot is always exact.
      for (size_t i=k+1; i<N; ++i) {
        for (size_t j=k+1; j<N; ++j) {
          M[i][j] = (M[i][j]*M[k][k] - M[i][k]*M[k][j]) / previous_pivot;
        }
      }
      previous_pivot = M[k][k];
    }
    return static_cast<T>(sign*M[N-1][N-1]);
  }
}

/// Determinant of a 1 x 1 matrix: its single entry.
template <typename T>
KOKKOS_INLINE_FUNCTION
T det(const TinyMatrix<1,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");
  return A(0,0);
}

/// Determinant of a 2 x 2 matrix: a00*a11 - a10*a01.
template <typename T>
KOKKOS_INLINE_FUNCTION
T det(const TinyMatrix<2,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");
  return A(0,0)*A(1,1)-A(1,0)*A(0,1);
}

/// Determinant of a 3 x 3 matrix, by cofactor expansion along the first
/// column: each A(i,0) times the 2 x 2 minor obtained by removing row i and
/// column 0, with alternating signs.
template <typename T>
KOKKOS_INLINE_FUNCTION
T det(const TinyMatrix<3,T>& A)
{
  static_assert(std::is_arithmetic<T>::value, "determinent is not defined for non arithmetic types");
  return
    A(0,0)*(A(1,1)*A(2,2)-A(1,2)*A(2,1))
    -A(1,0)*(A(0,1)*A(2,2)-A(0,2)*A(2,1))
    +A(2,0)*(A(1,2)*A(0,1)-A(1,1)*A(0,2));
}

#endif // TINYMATRIX_HPP