#include <catch2/catch.hpp>

#include <Kokkos_Core.hpp>

#include <utils/PugsAssert.hpp>
#include <utils/Types.hpp>

#include <algebra/TinyMatrix.hpp>
#include <analysis/Polynomial.hpp>

// Explicit instantiation definitions for degrees 0 through 5.
// Explicitly instantiating a class template instantiates all of its
// (non-template) member functions, so members that the tests below never call
// are still compiled and appear in coverage reports rather than being omitted.
template class Polynomial<0>;
template class Polynomial<1>;
template class Polynomial<2>;
template class Polynomial<3>;
template class Polynomial<4>;
template class Polynomial<5>;

// clazy:excludeall=non-pod-global-static

// Unit tests for Polynomial<N> and its free functions (primitive, integrate,
// derivative).
//
// A Polynomial<N> is built from its N + 1 coefficients, listed from the
// constant term up to the degree-N term (e.g. {2, -3, 4} is 2 - 3x + 4x^2);
// every expected value below follows that order.
//
// Exact floating-point comparisons are used deliberately: evaluation points
// and integration bounds are chosen so the expected results are exactly
// representable doubles. The primitive test additionally relies on the
// expected coefficients (e.g. 4. / 3) being computed by the same division as
// the implementation. NOTE(review): switch to Approx if that ever changes.
TEST_CASE("Polynomial", "[analysis]")
{
  SECTION("construction")
  {
    REQUIRE_NOTHROW(Polynomial<2>{{2, 3, 4}});

    // Construction from an explicit coefficient vector.
    TinyVector<4> coefs = zero;
    REQUIRE_NOTHROW(Polynomial<3>{coefs});
  }
  SECTION("equality")
  {
    Polynomial<2> P({2, 3, 4});
    Polynomial<2> Q({2, 3, 4});
    Polynomial<2> S({2, 3, 5});

    REQUIRE(P == Q);
    REQUIRE(P != S);
  }
  SECTION("addition")
  {
    Polynomial<2> P({2, 3, 4});
    Polynomial<2> Q({-1, -3, 2});
    Polynomial<2> S({1, 0, 6});
    REQUIRE(S == (P + Q));
    REQUIRE(S == (Q + P));

    // Mixed degrees: the result has the larger degree.
    Polynomial<3> T({0, 3, 1, -2});
    Polynomial<3> U({2, 6, 5, -2});
    REQUIRE((T + P) == U);
  }
  SECTION("difference")
  {
    Polynomial<2> P({2, 3, 4});
    Polynomial<2> Q({3, 4, 5});
    Polynomial<2> D({-1, -1, -1});
    REQUIRE(D == (P - Q));
    // Swapping the operands negates the difference.
    REQUIRE(Polynomial<2>({1, 1, 1}) == (Q - P));

    // Mixed degrees: missing high-order coefficients act as zeros.
    Polynomial<3> R({2, 3, 4, 1});
    REQUIRE((P - R) == Polynomial<3>({0, 0, 0, -1}));
  }
  SECTION("product_by_scalar")
  {
    Polynomial<2> P({2, 3, 4});
    Polynomial<2> M({6, 9, 12});
    REQUIRE(M == (P * 3));
    REQUIRE(M == (3 * P));
  }
  SECTION("product")
  {
    // (2 + 3x + 4x^2)(1 + 2x - x^2 + x^3) = 2 + 7x + 8x^2 + 7x^3 - x^4 + 4x^5
    Polynomial<2> P({2, 3, 4});
    Polynomial<3> Q({1, 2, -1, 1});
    REQUIRE(Polynomial<5>({2, 7, 8, 7, -1, 4}) == (P * Q));
  }
  SECTION("evaluation")
  {
    // P(3) = 2 - 3 * 3 + 4 * 3^2 = 29, via both evaluate() and operator().
    Polynomial<2> P({2, -3, 4});
    double result = P.evaluate(3);
    REQUIRE(result == 29);
    result = P(3);
    REQUIRE(result == 29);
  }
  SECTION("primitive")
  {
    // The primitive of 2 - 3x + 4x^2 with zero constant term is
    // 2x - (3/2)x^2 + (4/3)x^3.
    Polynomial<2> P({2, -3, 4});
    Polynomial<3> Q = primitive(P);
    Polynomial<3> R({0, 2, -3. / 2, 4. / 3});
    REQUIRE(Q == R);
  }
  SECTION("integrate")
  {
    // Integral over [-1, 1] of 2 - 3x + 3x^2 = 4 + 0 + 2 = 6.
    Polynomial<2> P({2, -3, 3});
    double xinf   = -1;
    double xsup   = 1;
    double result = integrate(P, xinf, xsup);
    REQUIRE(result == 6);
  }
  SECTION("derivative")
  {
    // d/dx (2 - 3x + 3x^2) = -3 + 6x
    Polynomial<2> P({2, -3, 3});
    Polynomial<1> Q = derivative(P);
    REQUIRE(Q == Polynomial<1>({-3, 6}));

    // The derivative of a constant polynomial is the zero polynomial.
    Polynomial<0> P2(3);

    Polynomial<0> R(0);
    REQUIRE(derivative(P2) == R);
  }
}