#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>

#include <utils/PugsAssert.hpp>

#include <algebra/TinyVector.hpp>

#include <cmath>
#include <limits>
#include <sstream>

// Explicitly instantiate these specializations so that all of their member
// functions are compiled into this test, letting coverage tools account for
// every member, even those no test below calls.
template class TinyVector<3, int>;
template class TinyVector<1, int>;

// clazy:excludeall=non-pod-global-static

// Unit tests for TinyVector: construction, element access, unary minus,
// dimension, zero initialization, stream output, dot product, arithmetic and
// compound-assignment operators, (in)equality, cross product and l2 norm.
// The NDEBUG-guarded sections check debug-only behavior: bounds assertions
// and the sentinel values held by default-constructed vectors.
TEST_CASE("TinyVector", "[algebra]")
{
  TinyVector<3, int> v(1, 2, 3);
  REQUIRE(((v[0] == 1) and (v[1] == 2) and (v[2] == 3)));
  REQUIRE(-v == TinyVector<3, int>(-1, -2, -3));

  REQUIRE(v.dimension() == 3);

  const TinyVector<3, int> z = zero;
  REQUIRE(((z[0] == 0) and (z[1] == 0) and (z[2] == 0)));

  v = TinyVector<3, int>(3, 2, 4);
  REQUIRE(Catch::Detail::stringify(v) == "(3,2,4)");

  TinyVector<3, int> w(1, 2, 6);
  // 3*1 + 2*2 + 4*6 = 31
  REQUIRE(dot(v, w) == 31);

  // From here on w == 2 * v, which the expected values below rely on.
  w = 2 * v;
  REQUIRE(w == TinyVector<3, int>(6, 4, 8));

  TinyVector<3, int> x = v;
  REQUIRE(x == v);

  x = TinyVector<3, int>(6, 4, 8);
  REQUIRE(x == w);
  REQUIRE_FALSE(x == v);
  REQUIRE(x != v);
  REQUIRE_FALSE(x != w);

  x = v;
  REQUIRE(x == v);

  x += w;
  REQUIRE(x == 3 * v);

  x = v + w;
  REQUIRE(x == 3 * v);

  x = 2 * (v + v);
  REQUIRE(x == 4 * v);

  // v + (2v + 4v)
  x = v + (w + x);
  REQUIRE(x == 7 * v);

  // 7v - 3v
  x = x - (v + w);
  REQUIRE(x == 4 * v);

  // 4v - 3v
  x -= v + w;
  REQUIRE(x == v);

  x = w - v;
  REQUIRE(x == v);

  x = v - 2 * v;
  REQUIRE(x == -v);

  TinyVector<3, int> z1;
  z1 = zero;
  REQUIRE(((z1[0] == 0) and (z1[1] == 0) and (z1[2] == 0)));

  REQUIRE(l2Norm(TinyVector<2, double>(3, 4)) == Catch::Approx(5).epsilon(1E-14));

  SECTION("checking for cross product")
  {
    const TinyVector<3, int> a(1, -2, 4);
    const TinyVector<3, int> b(3, 1, 6);
    REQUIRE(crossProduct(a, b) == TinyVector<3, int>(-16, 6, 7));
  }

#ifndef NDEBUG
  SECTION("output with signaling NaN")
  {
    // Default-constructed TinyVector<3> (default component type, presumably
    // double) is expected to hold NaNs in debug builds; only the explicitly
    // assigned component should print as a number.
    TinyVector<3> nan_vector;
    nan_vector[1] = 1;
    std::ostringstream nan_vector_ost;
    nan_vector_ost << nan_vector;

    REQUIRE(nan_vector_ost.str() == "(nan,1,nan)");
  }

  SECTION("checking for bounds validation")
  {
    // Index 3 is the first index past the end of a 3-component vector; testing
    // it catches off-by-one bound checks that larger indices would miss.
    REQUIRE_THROWS_AS(x[3] = 0, AssertError);
    REQUIRE_THROWS_AS(x[4] = 0, AssertError);
    const TinyVector<3, int>& const_x = x;
    REQUIRE_THROWS_AS(const_x[3], AssertError);
    // A negative index must be rejected as well.
    REQUIRE_THROWS_AS(const_x[-1], AssertError);
  }

  SECTION("checking for nan initialization")
  {
    TinyVector<3, double> y;

    for (size_t i = 0; i < y.dimension(); ++i) {
      REQUIRE(std::isnan(y[i]));
    }
  }

  SECTION("checking for bad initialization")
  {
    // In debug builds, default-constructed int vectors are expected to be
    // filled with the sentinel std::numeric_limits<int>::max() / 2.
    TinyVector<3, int> y;

    for (size_t i = 0; i < y.dimension(); ++i) {
      REQUIRE(y[i] == std::numeric_limits<int>::max() / 2);
    }
  }

#endif   // NDEBUG
}