#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <utils/PugsAssert.hpp>

#include <algebra/DenseMatrix.hpp>
#include <algebra/Vector.hpp>

// Explicitly instantiate DenseMatrix<int> so that all of its (non-template)
// member functions are emitted in this translation unit. Without this, only
// the members actually called by the tests below would be instantiated, and
// coverage reports would silently ignore the rest.
template class DenseMatrix<int>;

// clazy:excludeall=non-pod-global-static

// Unit tests for DenseMatrix: construction, element access, fill, copy
// semantics (shallow copy constructor vs. deep copy()), scalar operations,
// element-wise sum/difference, transpose, and matrix-vector / matrix-matrix
// products. Each SECTION builds its own small integer matrices so that
// sections are independent of each other.
TEST_CASE("DenseMatrix", "[algebra]")
{
  SECTION("size")
  {
    DenseMatrix<int> A{2, 3};
    REQUIRE(A.nbRows() == 2);
    REQUIRE(A.nbColumns() == 3);
  }

  SECTION("write access")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;
    REQUIRE(A(0, 0) == 0);
    REQUIRE(A(0, 1) == 1);
    REQUIRE(A(0, 2) == 2);
    REQUIRE(A(1, 0) == 3);
    REQUIRE(A(1, 1) == 4);
    REQUIRE(A(1, 2) == 5);

    // A matrix of non-const values must be convertible to a read-only view.
    DenseMatrix<const int> const_A = A;
    REQUIRE(const_A(0, 0) == 0);
    REQUIRE(const_A(0, 1) == 1);
    REQUIRE(const_A(0, 2) == 2);
    REQUIRE(const_A(1, 0) == 3);
    REQUIRE(const_A(1, 1) == 4);
    REQUIRE(const_A(1, 2) == 5);
  }

  SECTION("fill")
  {
    DenseMatrix<int> A{2, 3};
    A.fill(2);
    REQUIRE(A(0, 0) == 2);
    REQUIRE(A(0, 1) == 2);
    REQUIRE(A(0, 2) == 2);
    REQUIRE(A(1, 0) == 2);
    REQUIRE(A(1, 1) == 2);
    REQUIRE(A(1, 2) == 2);
  }

  SECTION("copy constructor (shallow)")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    const DenseMatrix<int> B = A;
    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 1);
    REQUIRE(B(0, 2) == 2);
    REQUIRE(B(1, 0) == 3);
    REQUIRE(B(1, 1) == 4);
    REQUIRE(B(1, 2) == 5);

    // Value equality alone would also hold for a deep copy: check that the
    // copy actually shares storage with the original (writes show through).
    A(0, 0) = 10;
    A(1, 2) = 15;
    REQUIRE(B(0, 0) == 10);
    REQUIRE(B(1, 2) == 15);
  }

  SECTION("copy (deep)")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    const DenseMatrix<int> B = copy(A);

    // Overwrite the original: a deep copy must keep the old values.
    A(0, 0) = 10;
    A(0, 1) = 11;
    A(0, 2) = 12;
    A(1, 0) = 13;
    A(1, 1) = 14;
    A(1, 2) = 15;

    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 1);
    REQUIRE(B(0, 2) == 2);
    REQUIRE(B(1, 0) == 3);
    REQUIRE(B(1, 1) == 4);
    REQUIRE(B(1, 2) == 5);

    REQUIRE(A(0, 0) == 10);
    REQUIRE(A(0, 1) == 11);
    REQUIRE(A(0, 2) == 12);
    REQUIRE(A(1, 0) == 13);
    REQUIRE(A(1, 1) == 14);
    REQUIRE(A(1, 2) == 15);
  }

  SECTION("self scalar multiplication")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    A *= 2;
    REQUIRE(A(0, 0) == 0);
    REQUIRE(A(0, 1) == 2);
    REQUIRE(A(0, 2) == 4);
    REQUIRE(A(1, 0) == 6);
    REQUIRE(A(1, 1) == 8);
    REQUIRE(A(1, 2) == 10);
  }

  SECTION("left scalar multiplication")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    const DenseMatrix<int> B = 2 * A;

    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 2);
    REQUIRE(B(0, 2) == 4);
    REQUIRE(B(1, 0) == 6);
    REQUIRE(B(1, 1) == 8);
    REQUIRE(B(1, 2) == 10);
  }

  SECTION("product matrix vector")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 6;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    Vector<int> x{3};
    x[0] = 7;
    x[1] = 3;
    x[2] = 4;

    // y0 = 6*7 + 1*3 + 2*4 = 53, y1 = 3*7 + 4*3 + 5*4 = 53
    Vector y = A * x;
    REQUIRE(y[0] == 53);
    REQUIRE(y[1] == 53);
  }

  SECTION("self scalar division")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 2;
    A(0, 2) = 4;
    A(1, 0) = 6;
    A(1, 1) = 8;
    A(1, 2) = 10;

    A /= 2;

    REQUIRE(A(0, 0) == 0);
    REQUIRE(A(0, 1) == 1);
    REQUIRE(A(0, 2) == 2);
    REQUIRE(A(1, 0) == 3);
    REQUIRE(A(1, 1) == 4);
    REQUIRE(A(1, 2) == 5);
  }

  SECTION("self minus")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 5;
    B(0, 1) = 6;
    B(0, 2) = 4;
    B(1, 0) = 2;
    B(1, 1) = 1;
    B(1, 2) = 3;

    A -= B;

    REQUIRE(A(0, 0) == -5);
    REQUIRE(A(0, 1) == -5);
    REQUIRE(A(0, 2) == -2);
    REQUIRE(A(1, 0) == 1);
    REQUIRE(A(1, 1) == 3);
    REQUIRE(A(1, 2) == 2);
  }

  SECTION("self sum")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 5;
    B(0, 1) = 6;
    B(0, 2) = 4;
    B(1, 0) = 2;
    B(1, 1) = 1;
    B(1, 2) = 3;

    A += B;

    REQUIRE(A(0, 0) == 5);
    REQUIRE(A(0, 1) == 7);
    REQUIRE(A(0, 2) == 6);
    REQUIRE(A(1, 0) == 5);
    REQUIRE(A(1, 1) == 5);
    REQUIRE(A(1, 2) == 8);
  }

  SECTION("sum")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 6;
    A(0, 1) = 5;
    A(0, 2) = 4;
    A(1, 0) = 3;
    A(1, 1) = 2;
    A(1, 2) = 1;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 0;
    B(0, 1) = 1;
    B(0, 2) = 2;
    B(1, 0) = 3;
    B(1, 1) = 4;
    B(1, 2) = 5;

    DenseMatrix C = A + B;
    REQUIRE(C(0, 0) == 6);
    REQUIRE(C(0, 1) == 6);
    REQUIRE(C(0, 2) == 6);
    REQUIRE(C(1, 0) == 6);
    REQUIRE(C(1, 1) == 6);
    REQUIRE(C(1, 2) == 6);
  }

  SECTION("difference")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 6;
    A(0, 1) = 5;
    A(0, 2) = 4;
    A(1, 0) = 3;
    A(1, 1) = 2;
    A(1, 2) = 1;

    DenseMatrix<int> B{2, 3};
    B(0, 0) = 0;
    B(0, 1) = 1;
    B(0, 2) = 2;
    B(1, 0) = 3;
    B(1, 1) = 4;
    B(1, 2) = 5;

    DenseMatrix C = A - B;
    REQUIRE(C(0, 0) == 6);
    REQUIRE(C(0, 1) == 4);
    REQUIRE(C(0, 2) == 2);
    REQUIRE(C(1, 0) == 0);
    REQUIRE(C(1, 1) == -2);
    REQUIRE(C(1, 2) == -4);
  }

  SECTION("transpose")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 0;
    A(0, 1) = 1;
    A(0, 2) = 2;
    A(1, 0) = 3;
    A(1, 1) = 4;
    A(1, 2) = 5;

    DenseMatrix B = transpose(A);

    // Transposing a 2x3 matrix must yield a 3x2 matrix.
    REQUIRE(B.nbRows() == 3);
    REQUIRE(B.nbColumns() == 2);

    REQUIRE(B(0, 0) == 0);
    REQUIRE(B(0, 1) == 3);
    REQUIRE(B(1, 0) == 1);
    REQUIRE(B(1, 1) == 4);
    REQUIRE(B(2, 0) == 2);
    REQUIRE(B(2, 1) == 5);
  }

  // Distinct name from the matrix-vector section above so both can be told
  // apart in reports and selected individually with `-c`.
  SECTION("product matrix matrix")
  {
    DenseMatrix<int> A{2, 3};
    A(0, 0) = 1;
    A(0, 1) = 2;
    A(0, 2) = 3;
    A(1, 0) = 4;
    A(1, 1) = 5;
    A(1, 2) = 6;

    DenseMatrix<int> B{3, 2};
    B(0, 0) = 2;
    B(0, 1) = 8;
    B(1, 0) = 4;
    B(1, 1) = 9;
    B(2, 0) = 6;
    B(2, 1) = 10;

    // (2x3) * (3x2) -> 2x2
    DenseMatrix C = A * B;
    REQUIRE(C.nbRows() == 2);
    REQUIRE(C.nbColumns() == 2);

    REQUIRE(C(0, 0) == 28);
    REQUIRE(C(0, 1) == 56);
    REQUIRE(C(1, 0) == 64);
    REQUIRE(C(1, 1) == 137);
  }
}