#include <catch2/catch_test_macros.hpp>

#include <Kokkos_Core.hpp>

#include <utils/PugsAssert.hpp>
#include <utils/Types.hpp>

#include <algebra/TinyMatrix.hpp>
#include <analysis/Polynomial.hpp>

// Explicitly instantiate every Polynomial degree used by the tests below.
// An explicit instantiation definition compiles all (non-template) member
// functions of each specialization, not only those the tests happen to call,
// so that coverage reports account for the whole class template.
template class Polynomial<0>;
template class Polynomial<1>;
template class Polynomial<2>;
template class Polynomial<3>;
template class Polynomial<4>;
template class Polynomial<5>;

// clazy:excludeall=non-pod-global-static

// Unit tests for the Polynomial<N> class template and its associated free
// functions (divide, primitive, integrate, symmetricIntegrate, derivative,
// composition, Lagrange interpolation).
//
// Coefficients are given in increasing-degree order:
// Polynomial<2>(a, b, c) represents a + b*x + c*x^2 (see "evaluation":
// (2, -3, 4) evaluated at 3 gives 2 - 9 + 36 = 29).
TEST_CASE("Polynomial", "[analysis]")
{
  SECTION("construction")
  {
    REQUIRE_NOTHROW(Polynomial<2>(2, 3, 4));
  }

  SECTION("degree")
  {
    Polynomial<2> P(2, 3, 4);
    REQUIRE(P.degree() == 2);
  }

  SECTION("equality")
  {
    Polynomial<2> P(2, 3, 4);
    Polynomial<2> Q(2, 3, 4);
    Polynomial<2> S(2, 3, 5);

    REQUIRE(P == Q);
    REQUIRE(P != S);
  }

  SECTION("addition")
  {
    Polynomial<2> P(2, 3, 4);
    Polynomial<2> Q(-1, -3, 2);
    Polynomial<2> S(1, 0, 6);
    Polynomial<3> T(0, 3, 1, -2);
    Polynomial<3> U(2, 6, 5, -2);
    REQUIRE(S == (P + Q));
    // Mixed degrees: the sum carries the larger degree.
    REQUIRE((T + P) == U);
  }

  SECTION("opposed")
  {
    Polynomial<2> P(2, 3, 4);
    Polynomial<2> Q = -P;
    REQUIRE(Q == Polynomial<2>(-2, -3, -4));
  }

  SECTION("difference")
  {
    Polynomial<2> P(2, 3, 4);
    Polynomial<2> Q(3, 4, 5);
    Polynomial<2> D(-1, -1, -1);
    REQUIRE(D == (P - Q));
    // Subtraction is anti-commutative: Q - P == -(P - Q).
    REQUIRE((Q - P) == -D);
    Polynomial<3> R(2, 3, 4, 1);
    // Mixed degrees: the difference carries the larger degree.
    REQUIRE((P - R) == Polynomial<3>{0, 0, 0, -1});
    R -= P;
    REQUIRE(R == Polynomial<3>(0, 0, 0, 1));
  }

  SECTION("product_by_scalar")
  {
    Polynomial<2> P(2, 3, 4);
    Polynomial<2> M(6, 9, 12);
    REQUIRE(M == (P * 3));
    REQUIRE(M == (3 * P));
  }

  SECTION("product")
  {
    Polynomial<2> P(2, 3, 4);
    Polynomial<3> Q(1, 2, -1, 1);
    Polynomial<4> R;
    Polynomial<5> S;
    // Assigning a lower-degree polynomial zero-fills the high-order terms.
    R = P;
    REQUIRE(R == Polynomial<4>(2, 3, 4, 0, 0));
    S = P;
    // S has room for degree 2 + 3 = 5, so the in-place product fits.
    S *= Q;
    REQUIRE(Polynomial<5>(2, 7, 8, 7, -1, 4) == (P * Q));
    REQUIRE(Polynomial<5>(2, 7, 8, 7, -1, 4) == S);
    // R (degree 4) cannot hold R * Q (degree 5); this is expected to trip an
    // assertion. NOTE(review): kept disabled, presumably because AssertError
    // is not raised in every build configuration — TODO confirm.
    // REQUIRE_THROWS_AS(R *= Q, AssertError);
  }

  SECTION("divide")
  {
    Polynomial<2> P(1, 0, 1);
    Polynomial<1> Q(0, 1);
    Polynomial<1> Q1(0, 1);

    Polynomial<2> R;
    Polynomial<2> S;
    REQUIRE(P.realDegree() == 2);
    REQUIRE(Q.realDegree() == 1);
    REQUIRE(Q1.realDegree() == 1);

    // (1 + x^2) / x: quotient x is written to R, remainder 1 to S.
    divide(P, Q1, R, S);
    REQUIRE(Polynomial<2>{1, 0, 0} == S);
    REQUIRE(Polynomial<2>{0, 1, 0} == R);
  }

  SECTION("evaluation")
  {
    Polynomial<2> P(2, -3, 4);
    REQUIRE(P(3) == 29);
  }

  SECTION("primitive")
  {
    Polynomial<2> P(2, -3, 4);
    TinyVector<4> coefs = zero;
    Polynomial<3> Q(coefs);
    // The primitive is the antiderivative with zero constant term.
    Q = primitive(P);
    Polynomial<3> R(0, 2, -3. / 2, 4. / 3);
    REQUIRE(Q == R);
  }

  SECTION("integrate")
  {
    Polynomial<2> P(2, -3, 3);
    double xinf   = -1;
    double xsup   = 1;
    double result = integrate(P, xinf, xsup);
    REQUIRE(result == 6);
    // Integral over [-2, 2]; the odd term -3x integrates to zero.
    result = symmetricIntegrate(P, 2);
    REQUIRE(result == 24);
  }

  SECTION("derivative")
  {
    Polynomial<2> P(2, -3, 3);
    Polynomial<1> Q = derivative(P);
    REQUIRE(Q == Polynomial<1>(-3, 6));

    // The derivative of a constant is the zero polynomial.
    Polynomial<0> P2(3);

    Polynomial<0> R(0);
    REQUIRE(derivative(P2) == R);
  }

  SECTION("affectation")
  {
    Polynomial<2> Q(2, -3, 3);
    Polynomial<4> R(2, -3, 3, 0, 0);
    Polynomial<4> P(0, 1, 2, 3, 3);
    // Assignment overwrites every coefficient, zeroing those beyond deg(Q).
    P = Q;
    REQUIRE(P == R);
  }

  SECTION("affectation addition")
  {
    Polynomial<2> Q(2, -3, 3);
    Polynomial<4> R(2, -2, 5, 3, 3);
    Polynomial<4> P(0, 1, 2, 3, 3);
    P += Q;
    REQUIRE(P == R);
  }

  SECTION("power")
  {
    Polynomial<2> P(2, -3, 3);
    Polynomial<4> R(4, -12, 21, -18, 9);
    Polynomial<1> Q(0, 2);
    Polynomial<2> S = Q.pow<2>(2);
    REQUIRE(P.pow<2>(2) == R);
    REQUIRE(S == Polynomial<2>(0, 0, 4));
  }

  SECTION("composition")
  {
    Polynomial<2> P(2, -3, 3);
    Polynomial<1> Q(0, 2);
    Polynomial<2> R(2, -1, 3);
    Polynomial<2> S(1, 2, 2);
    // P(2x) = 2 - 6x + 12x^2, checked through both composition entry points.
    REQUIRE(P.compose(Q) == Polynomial<2>(2, -6, 12));
    REQUIRE(P.compose2(Q) == Polynomial<2>(2, -6, 12));
    // Calling a polynomial on a polynomial composes them: deg = 2 * 2.
    REQUIRE(R(S) == Polynomial<4>(4, 10, 22, 24, 12));
  }

  SECTION("Lagrange polynomial")
  {
    // L_0 on nodes {-1, 1}: (x - 1) / (-1 - 1) = 0.5 - 0.5x.
    Polynomial<1> S(0.5, -0.5);
    Polynomial<1> Q;
    Q = lagrangePolynomial<1>(TinyVector<2>{-1, 1}, 0);
    REQUIRE(S == Q);
    // L_0 on nodes {-1, 0, 1}: x(x - 1) / 2 = -0.5x + 0.5x^2.
    Polynomial<2> P(0, -0.5, 0.5);
    Polynomial<2> R;
    R = lagrangePolynomial<2>(TinyVector<3>{-1, 0, 1}, 0);
    REQUIRE(R == P);
    // Interpolating the values {1, 0, 1} on {-1, 0, 1} recovers x^2.
    const std::array<Polynomial<2>, 3> basis = lagrangeBasis(TinyVector<3>{-1, 0, 1});
    REQUIRE(lagrangeToCanonical(TinyVector<3>{1, 0, 1}, basis) == Polynomial<2>(TinyVector<3>{0, 0, 1}));
  }
}