#ifndef Q1_TRANSFORMATION_HPP
#define Q1_TRANSFORMATION_HPP

#include <algebra/TinyMatrix.hpp>
#include <algebra/TinyVector.hpp>

#include <array>

// Q1 (multilinear) mapping F from the reference element [-1,1]^Dimension onto
// the element whose vertices are passed to the constructor.
//
// The mapping reads F(x) = T * X(x) + shift, where T is m_coefficients and
// X(x) is the vector of non-constant Q1 monomials, in this exact order:
//   - 1D: X = (x)
//   - 2D: X = (x, y, xy)
//   - 3D: X = (x, y, z, xy, yz, xz, xyz)
//
// With the coefficients computed by the constructor, the reference vertices
// are mapped onto the image nodes as follows:
//   - 1D: (-1) -> a, (+1) -> b
//   - 2D: (-1,-1) -> a, (1,-1) -> b, (1,1) -> c, (-1,1) -> d
//   - 3D: the 2D ordering in the plane z=-1 gives a, b, c, d and in the
//         plane z=+1 gives e, f, g, h
template <size_t Dimension>
class Q1Transformation
{
  static_assert(Dimension >= 1 and Dimension <= 3, "invalid dimension");

 private:
  // Number of vertices of the reference element (2^Dimension)
  constexpr static size_t NumberOfPoints = 1 << Dimension;

  // Column j holds, for each component, the coefficient of the j-th
  // non-constant monomial (see the ordering above)
  TinyMatrix<Dimension, NumberOfPoints - 1> m_coefficients;
  // Constant term of the mapping, i.e. the image of the reference center F(0)
  TinyVector<Dimension> m_shift;

 public:
  // Returns the image F(x) of the reference point x.
  PUGS_INLINE
  TinyVector<Dimension>
  operator()(const TinyVector<Dimension>& x) const
  {
    if constexpr (Dimension == 1) {
      return m_coefficients * x + m_shift;
    } else if constexpr (Dimension == 2) {
      // Monomials (x, y, xy): the order must match the columns of m_coefficients
      const TinyVector<NumberOfPoints - 1> X = {x[0], x[1], x[0] * x[1]};
      return m_coefficients * X + m_shift;
    } else {
      static_assert(Dimension == 3, "invalid dimension");
      // Monomials (x, y, z, xy, yz, xz, xyz): the order must match the columns
      // of m_coefficients
      const TinyVector<NumberOfPoints - 1> X =
        {x[0], x[1], x[2], x[0] * x[1], x[1] * x[2], x[0] * x[2], x[0] * x[1] * x[2]};
      return m_coefficients * X + m_shift;
    }
  }

  // Returns det(DF(X)), the determinant of the Jacobian matrix of the mapping
  // evaluated at the reference point X (X is unused in 1D, where the mapping
  // is affine).
  //
  // Row i of the Jacobian holds (dF_i/dx, dF_i/dy[, dF_i/dz]). Differentiating
  // the monomial vector gives:
  //   2D, (x, y, xy):
  //     d/dx -> (1, 0, y)          d/dy -> (0, 1, x)
  //   3D, (x, y, z, xy, yz, xz, xyz):
  //     d/dx -> (1, 0, 0, y, 0, z, yz)
  //     d/dy -> (0, 1, 0, x, z, 0, xz)
  //     d/dz -> (0, 0, 1, 0, y, x, xy)
  PUGS_INLINE double
  jacobianDeterminant([[maybe_unused]] const TinyVector<Dimension>& X) const
  {
    if constexpr (Dimension == 1) {
      return m_coefficients(0, 0);
    } else if constexpr (Dimension == 2) {
      const auto& T   = m_coefficients;
      const double& x = X[0];
      const double& y = X[1];

      const TinyMatrix<Dimension, Dimension> J = {T(0, 0) + T(0, 2) * y, T(0, 1) + T(0, 2) * x,   //
                                                  T(1, 0) + T(1, 2) * y, T(1, 1) + T(1, 2) * x};
      return det(J);
    } else {
      static_assert(Dimension == 3, "invalid dimension");
      const auto& T   = m_coefficients;
      const double& x = X[0];
      const double& y = X[1];
      const double& z = X[2];

      // Note: d(xyz)/dy = xz (not xy), hence the T(i, 6) * x * z terms in the
      // second column
      const TinyMatrix<Dimension, Dimension> J = {T(0, 0) + T(0, 3) * y + T(0, 5) * z + T(0, 6) * y * z,
                                                  T(0, 1) + T(0, 3) * x + T(0, 4) * z + T(0, 6) * x * z,
                                                  T(0, 2) + T(0, 4) * y + T(0, 5) * x + T(0, 6) * x * y,
                                                  //
                                                  T(1, 0) + T(1, 3) * y + T(1, 5) * z + T(1, 6) * y * z,
                                                  T(1, 1) + T(1, 3) * x + T(1, 4) * z + T(1, 6) * x * z,
                                                  T(1, 2) + T(1, 4) * y + T(1, 5) * x + T(1, 6) * x * y,
                                                  //
                                                  T(2, 0) + T(2, 3) * y + T(2, 5) * z + T(2, 6) * y * z,
                                                  T(2, 1) + T(2, 3) * x + T(2, 4) * z + T(2, 6) * x * z,
                                                  T(2, 2) + T(2, 4) * y + T(2, 5) * x + T(2, 6) * x * y};
      return det(J);
    }
  }

  // Builds the mapping sending the reference vertices onto image_nodes (see
  // the class comment for the vertex ordering).
  //
  // Each column j of m_coefficients is 2^-Dimension times the sum, over all
  // vertices k, of the j-th monomial evaluated at reference vertex k times
  // image_nodes[k]. m_shift is the average of the image nodes.
  PUGS_INLINE
  Q1Transformation(const std::array<TinyVector<Dimension>, NumberOfPoints>& image_nodes)
  {
    if constexpr (Dimension == 1) {
      const TinyVector<Dimension>& a = image_nodes[0];
      const TinyVector<Dimension>& b = image_nodes[1];

      m_coefficients(0, 0) = 0.5 * (b[0] - a[0]);

      m_shift[0] = 0.5 * (a[0] + b[0]);
    } else if constexpr (Dimension == 2) {
      const TinyVector<Dimension>& a = image_nodes[0];
      const TinyVector<Dimension>& b = image_nodes[1];
      const TinyVector<Dimension>& c = image_nodes[2];
      const TinyVector<Dimension>& d = image_nodes[3];

      for (size_t i = 0; i < Dimension; ++i) {
        m_coefficients(i, 0) = 0.25 * (-a[i] + b[i] + c[i] - d[i]);   // x
        m_coefficients(i, 1) = 0.25 * (-a[i] - b[i] + c[i] + d[i]);   // y
        m_coefficients(i, 2) = 0.25 * (+a[i] - b[i] + c[i] - d[i]);   // xy

        m_shift[i] = 0.25 * (a[i] + b[i] + c[i] + d[i]);
      }
    } else {
      static_assert(Dimension == 3);

      const TinyVector<Dimension>& a = image_nodes[0];
      const TinyVector<Dimension>& b = image_nodes[1];
      const TinyVector<Dimension>& c = image_nodes[2];
      const TinyVector<Dimension>& d = image_nodes[3];
      const TinyVector<Dimension>& e = image_nodes[4];
      const TinyVector<Dimension>& f = image_nodes[5];
      const TinyVector<Dimension>& g = image_nodes[6];
      const TinyVector<Dimension>& h = image_nodes[7];

      for (size_t i = 0; i < Dimension; ++i) {
        m_coefficients(i, 0) = 0.125 * (-a[i] + b[i] + c[i] - d[i] - e[i] + f[i] + g[i] - h[i]);   // x
        m_coefficients(i, 1) = 0.125 * (-a[i] - b[i] + c[i] + d[i] - e[i] - f[i] + g[i] + h[i]);   // y
        m_coefficients(i, 2) = 0.125 * (-a[i] - b[i] - c[i] - d[i] + e[i] + f[i] + g[i] + h[i]);   // z
        m_coefficients(i, 3) = 0.125 * (+a[i] - b[i] + c[i] - d[i] + e[i] - f[i] + g[i] - h[i]);   // xy
        m_coefficients(i, 4) = 0.125 * (+a[i] + b[i] - c[i] - d[i] - e[i] - f[i] + g[i] + h[i]);   // yz
        m_coefficients(i, 5) = 0.125 * (+a[i] - b[i] - c[i] + d[i] - e[i] + f[i] + g[i] - h[i]);   // xz
        m_coefficients(i, 6) = 0.125 * (-a[i] + b[i] - c[i] + d[i] + e[i] - f[i] + g[i] - h[i]);   // xyz

        m_shift[i] = 0.125 * (a[i] + b[i] + c[i] + d[i] + e[i] + f[i] + g[i] + h[i]);
      }
    }
  }

  ~Q1Transformation() = default;
};

#endif   // Q1_TRANSFORMATION_HPP
