#ifndef POLYNOMIAL_BASIS_HPP
#define POLYNOMIAL_BASIS_HPP

#include <algebra/TinyVector.hpp>
#include <analysis/Polynomial.hpp>
#include <utils/Messenger.hpp>

/// Kind of polynomial basis a PolynomialBasis can hold.
///
/// The numeric values are spelled out; they are the same ones the compiler
/// assigns implicitly, so a value-initialized BasisType is `undefined`.
enum class BasisType
{
  undefined = 0,   ///< no basis has been built yet
  lagrange  = 1,   ///< basis built from a set of interpolation values
  canonical = 2,   ///< basis built from unit coefficient vectors
  taylor    = 3    ///< basis built from a shift value
};

/**
 * Set of N + 1 polynomials, each stored as a Polynomial<N>.
 *
 * A default-constructed basis is tagged BasisType::undefined; build() tags it
 * with the requested type and fills its elements accordingly.
 *
 * @tparam N parameter of the Polynomial<N> elements; the basis holds N + 1 of
 *           them. N is unsigned, so no lower-bound check is required.
 */
template <size_t N>
class PolynomialBasis
{
 private:
  TinyVector<N + 1, Polynomial<N>> m_elements;
  BasisType m_basis_type;

  /**
   * Fills element i with the Polynomial<N> whose coefficient vector is 1 at
   * index i and zero elsewhere.
   *
   * NOTE(review): this is the monomial x^i provided Polynomial stores its
   * coefficients by increasing degree -- confirm against Polynomial.hpp.
   *
   * @return *this
   */
  PUGS_INLINE
  constexpr PolynomialBasis<N>&
  _buildCanonicalBasis()
  {
    for (size_t i = 0; i <= N; i++) {
      TinyVector<N + 1> coefficients(zero);
      coefficients[i] = 1;
      m_elements[i]   = Polynomial<N>(coefficients);
    }
    return *this;
  }

  /**
   * Fills element 0 with the constant 1 and every following element with the
   * previous one multiplied by Polynomial<1>{-shift, 1}.
   *
   * NOTE(review): presumably element i is (x - shift)^i; this relies on the
   * coefficient ordering of Polynomial -- confirm against Polynomial.hpp.
   *
   * @param shift value whose opposite is the first coefficient of the
   *              degree-one factor
   * @return *this
   */
  PUGS_INLINE
  constexpr PolynomialBasis<N>&
  _buildTaylorBasis(const double& shift)
  {
    // start from a zero-initialized Polynomial<N>, then add the constant 1
    TinyVector<N + 1> coefficients(zero);
    m_elements[0] = Polynomial<N>(coefficients);
    m_elements[0] += Polynomial<0>(TinyVector<1>{1});

    Polynomial<1> unit(TinyVector<2>{-shift, 1});
    for (size_t i = 1; i <= N; i++) {
      m_elements[i] = m_elements[i - 1] * unit;
    }
    return *this;
  }

  /**
   * Fills element i with the product, over all j != i, of the degree-one
   * factors Polynomial<1>{-zeros[j] / (zeros[i] - zeros[j]), 1 / (zeros[i] - zeros[j])}.
   *
   * NOTE(review): presumably element i is the Lagrange polynomial equal to 1
   * at zeros[i] and 0 at the other values -- confirm the coefficient ordering
   * of Polynomial.
   *
   * @param zeros interpolation values; Assert checks that they are strictly
   *              increasing (what Assert does when the check fails, and
   *              whether it is active in every build, is not visible here)
   * @return *this
   */
  PUGS_INLINE
  constexpr PolynomialBasis<N>&
  _buildLagrangeBasis(const TinyVector<N + 1>& zeros)
  {
    for (size_t i = 0; i < N; ++i) {
      Assert(zeros[i] < zeros[i + 1], "Interpolation values must be strictly increasing in Lagrange polynomials");
    }
    if constexpr (N == 0) {
      // a single element: the constant 1, whatever the interpolation value is
      m_elements[0] = Polynomial<0>(TinyVector<1>{1});
      return *this;
    } else {
      for (size_t i = 0; i <= N; ++i) {
        // start from the constant 1 stored in a Polynomial<N>
        TinyVector<N + 1> coefficients(zero);
        m_elements[i]                   = Polynomial<N>(coefficients);
        m_elements[i].coefficients()[0] = 1;
        for (size_t j = 0; j < N + 1; ++j) {
          if (i == j)
            continue;
          // distinct values are required here, otherwise this divides by zero
          double adim = 1. / (zeros[i] - zeros[j]);
          m_elements[i] *= Polynomial<1>{{-zeros[j] * adim, adim}};
        }
      }
      return *this;
    }
  }

 public:
  /// @return the number of elements of the basis, N + 1
  PUGS_INLINE
  constexpr size_t
  size() const
  {
    return N + 1;
  }

  /// @return the template parameter N
  PUGS_INLINE
  constexpr size_t
  degree() const
  {
    return N;
  }

  /// @return a mutable reference to the stored basis type tag
  ///
  /// Writing through this reference only changes the tag; the elements are
  /// not rebuilt.
  PUGS_INLINE
  constexpr BasisType&
  type()
  {
    return m_basis_type;
  }

  /// @return the stored basis type tag (read-only access for const objects)
  PUGS_INLINE
  constexpr const BasisType&
  type() const
  {
    return m_basis_type;
  }

  /// @return the name of the stored basis type: "lagrange", "canonical",
  ///         "taylor", "undefined", or "unknown basis type" for any other value
  PUGS_INLINE
  std::string_view
  displayType() const
  {
    switch (m_basis_type) {
    case BasisType::lagrange:
      return "lagrange";
    case BasisType::canonical:
      return "canonical";
    case BasisType::taylor:
      return "taylor";
    case BasisType::undefined:
      return "undefined";
    default:
      return "unknown basis type";
    }
  }

  /// @return read-only access to the basis elements
  PUGS_INLINE
  constexpr const TinyVector<N + 1, Polynomial<N>>&
  elements() const
  {
    return m_elements;
  }

  /// @return mutable access to the basis elements
  PUGS_INLINE
  constexpr TinyVector<N + 1, Polynomial<N>>&
  elements()
  {
    return m_elements;
  }

  /**
   * Tags the basis with @p basis_type and fills its elements accordingly.
   *
   * @param basis_type basis to build: lagrange, canonical or taylor
   * @param shift      only used for BasisType::taylor
   * @param zeros      only used for BasisType::lagrange; the default value
   *                   (built from `zero`, presumably all zeros) does not
   *                   satisfy the strictly-increasing check when N > 0, so
   *                   callers must provide it in that case
   * @return *this
   * @throws UnexpectedError for any other @p basis_type, BasisType::undefined
   *         included; the type tag has already been overwritten at that point
   */
  PUGS_INLINE
  constexpr PolynomialBasis<N>&
  build(BasisType basis_type, const double& shift = 0, const TinyVector<N + 1>& zeros = TinyVector<N + 1>(zero))
  {
    m_basis_type = basis_type;
    switch (basis_type) {
    case BasisType::lagrange: {
      return _buildLagrangeBasis(zeros);
    }
    case BasisType::canonical: {
      return _buildCanonicalBasis();
    }
    case BasisType::taylor: {
      return _buildTaylorBasis(shift);
    }
      // LCOV_EXCL_START
    default: {
      throw UnexpectedError("unknown basis type");
    }
      // LCOV_EXCL_STOP
    }
  }

  /// Creates a basis tagged BasisType::undefined; the elements are left
  /// default-initialized until build() is called.
  PUGS_INLINE
  constexpr PolynomialBasis() noexcept : m_basis_type{BasisType::undefined} {}

  ~PolynomialBasis() = default;
};
#endif   // POLYNOMIAL_BASIS_HPP