#include <catch2/catch.hpp>

#include <algebra/CRSMatrix.hpp>

// Explicit instantiation definition: forces every (non-template) member of
// SparseMatrixDescriptor<int, uint8_t> to be compiled in this translation unit,
// not only the members the tests below happen to call, so that coverage
// reports account for the whole class.
template class SparseMatrixDescriptor<int, uint8_t>;

// clazy:excludeall=non-pod-global-static

TEST_CASE("CRSMatrix", "[algebra]")
{
  // Writes the sparse 5x5 pattern shared by the "matrix size" and
  // "matrix vector product (simple)" sections into `descriptor`:
  //
  //   [ 0  0  3  0  0 ]
  //   [ 0  1 11  0  0 ]
  //   [ 0  0  4  0  0 ]
  //   [ 5 -3  0  0  0 ]
  //   [ 0  1  0  2 -2 ]
  auto fill_sparse_pattern = [](SparseMatrixDescriptor<int, uint8_t>& descriptor) {
    descriptor(0, 2) = 3;
    descriptor(1, 2) = 11;
    descriptor(1, 1) = 1;
    descriptor(3, 0) = 5;
    descriptor(3, 1) = -3;
    descriptor(4, 1) = 1;
    descriptor(2, 2) = 4;
    descriptor(4, 3) = 2;
    descriptor(4, 4) = -2;
  };

  SECTION("matrix size")
  {
    SparseMatrixDescriptor<int, uint8_t> descriptor{5};
    fill_sparse_pattern(descriptor);

    CRSMatrix<int, uint8_t> matrix{descriptor};

    // Building the CRS matrix must preserve the descriptor's row count.
    REQUIRE(matrix.numberOfRows() == descriptor.numberOfRows());
  }

  SECTION("matrix vector product (simple)")
  {
    SparseMatrixDescriptor<int, uint8_t> descriptor{5};
    fill_sparse_pattern(descriptor);

    CRSMatrix<int, uint8_t> matrix{descriptor};

    // Every input below is a scaled unit vector c * e_k, so each product is
    // expected to be c times column k of the matrix.
    Vector<int> u{matrix.numberOfRows()};

    // 1 * e_0 -> column 0
    u    = 0;
    u[0] = 1;

    Vector<int> v = matrix * u;
    REQUIRE(v[0] == 0);
    REQUIRE(v[1] == 0);
    REQUIRE(v[2] == 0);
    REQUIRE(v[3] == 5);
    REQUIRE(v[4] == 0);

    // 2 * e_1 -> 2 * column 1
    u    = 0;
    u[1] = 2;

    v = matrix * u;
    REQUIRE(v[0] == 0);
    REQUIRE(v[1] == 2);
    REQUIRE(v[2] == 0);
    REQUIRE(v[3] == -6);
    REQUIRE(v[4] == 2);

    // -1 * e_2 -> -column 2
    u    = 0;
    u[2] = -1;

    v = matrix * u;
    REQUIRE(v[0] == -3);
    REQUIRE(v[1] == -11);
    REQUIRE(v[2] == -4);
    REQUIRE(v[3] == 0);
    REQUIRE(v[4] == 0);

    // 3 * e_3 -> 3 * column 3
    u    = 0;
    u[3] = 3;

    v = matrix * u;
    REQUIRE(v[0] == 0);
    REQUIRE(v[1] == 0);
    REQUIRE(v[2] == 0);
    REQUIRE(v[3] == 0);
    REQUIRE(v[4] == 6);

    // 1 * e_4 -> column 4
    u    = 0;
    u[4] = 1;

    v = matrix * u;
    REQUIRE(v[0] == 0);
    REQUIRE(v[1] == 0);
    REQUIRE(v[2] == 0);
    REQUIRE(v[3] == 0);
    REQUIRE(v[4] == -2);
  }

  SECTION("matrix vector product (complet)")
  {
    // Fully populated 4x4 matrix holding 1..16 in row-major order.
    SparseMatrixDescriptor<int, uint8_t> descriptor{4};
    descriptor(0, 0) = 1;
    descriptor(0, 1) = 2;
    descriptor(0, 2) = 3;
    descriptor(0, 3) = 4;
    descriptor(1, 0) = 5;
    descriptor(1, 1) = 6;
    descriptor(1, 2) = 7;
    descriptor(1, 3) = 8;
    descriptor(2, 0) = 9;
    descriptor(2, 1) = 10;
    descriptor(2, 2) = 11;
    descriptor(2, 3) = 12;
    descriptor(3, 0) = 13;
    descriptor(3, 1) = 14;
    descriptor(3, 2) = 15;
    descriptor(3, 3) = 16;

    CRSMatrix<int, uint8_t> matrix{descriptor};

    // Every component of u is assigned explicitly, so no zero-fill is needed.
    Vector<int> u{matrix.numberOfRows()};
    u[0] = 1;
    u[1] = 2;
    u[2] = 3;
    u[3] = 4;

    // Row i evaluates to sum_j (4i + j + 1) * (j + 1).
    Vector<int> v = matrix * u;
    REQUIRE(v[0] == 30);
    REQUIRE(v[1] == 70);
    REQUIRE(v[2] == 110);
    REQUIRE(v[3] == 150);
  }

#ifndef NDEBUG
  SECTION("incompatible runtime matrix/vector product")
  {
    // Only built without NDEBUG: multiplying a matrix built from a size-4
    // descriptor by a size-2 vector is expected to raise AssertError.
    CRSMatrix<int, uint8_t> matrix{SparseMatrixDescriptor<int, uint8_t>{4}};
    Vector<int> u{2};
    REQUIRE_THROWS_AS(matrix * u, AssertError);
  }
#endif   // NDEBUG
}
