#include <catch2/catch.hpp>

#include <utils/PugsAssert.hpp>

#include <algebra/SparseMatrixDescriptor.hpp>

#include <cstdint>
#include <sstream>
#include <string>

// Explicitly instantiate SparseMatrixDescriptor<int, uint8_t> so that every
// (non-template) member of the class template is compiled in this translation
// unit. Members the tests below never call are still emitted, which lets the
// coverage report account for them ("ensure full coverage is performed").
template class SparseMatrixDescriptor<int, uint8_t>;

// clazy:excludeall=non-pod-global-static

// Unit tests for SparseMatrixDescriptor<int, uint8_t>: the nested
// SparseRowDescriptor, row/element accessors (mutable and const), the
// graph/value-array export used to build a compressed matrix, and the
// stream output operator.
TEST_CASE("SparseMatrixDescriptor", "[algebra]")
{
  SECTION("SparseRowDescriptor subclass")
  {
    SECTION("number of values")
    {
      SparseMatrixDescriptor<int, uint8_t>::SparseRowDescriptor r;
      REQUIRE(r.numberOfValues() == 0);

      // Writing through operator[] creates the entry; updating it again must
      // not create a second one.
      r[0] = 3;
      r[0] += 2;

      REQUIRE(r.numberOfValues() == 1);

      r[1] = -2;

      REQUIRE(r.numberOfValues() == 2);
    }

    SECTION("values access")
    {
      SparseMatrixDescriptor<int, uint8_t>::SparseRowDescriptor r;

      r[0] = 3;
      r[0] += 2;

      r[3] = 8;

      REQUIRE(r[0] == 5);
      REQUIRE(r[3] == 8);

      const auto& const_r = r;
      REQUIRE(const_r[0] == 5);
      REQUIRE(const_r[3] == 8);

#ifndef NDEBUG
      // The const accessor must not create entries: reading a column that was
      // never set is a checked error (only asserted in debug builds).
      REQUIRE_THROWS_AS(const_r[2], AssertError);
#endif   // NDEBUG
    }
  }

  // The constructor argument is the number of rows of the descriptor.
  SECTION("number of rows")
  {
    SparseMatrixDescriptor<int, uint8_t> S{5};
    REQUIRE(S.numberOfRows() == 5);
  }

  SECTION("location operators")
  {
    SparseMatrixDescriptor<int, uint8_t> S{5};
    S(0, 2) = 3;
    S(1, 1) = 1;
    S(1, 2) = 11;
    S(0, 2) += 2;
    S(2, 2) = 4;
    S(3, 1) = 5;
    S(4, 1) = 1;
    S(4, 4) = -2;

    REQUIRE(S.row(0).numberOfValues() == 1);
    REQUIRE(S.row(1).numberOfValues() == 2);
    REQUIRE(S.row(2).numberOfValues() == 1);
    REQUIRE(S.row(3).numberOfValues() == 1);
    REQUIRE(S.row(4).numberOfValues() == 2);

    const auto& const_S = S;

    REQUIRE(const_S.row(0).numberOfValues() == 1);
    REQUIRE(const_S.row(1).numberOfValues() == 2);
    REQUIRE(const_S.row(2).numberOfValues() == 1);
    REQUIRE(const_S.row(3).numberOfValues() == 1);
    REQUIRE(const_S.row(4).numberOfValues() == 2);

#ifndef NDEBUG
    // Row index 5 is out of range for a 5-row descriptor.
    REQUIRE_THROWS_AS(S.row(5), AssertError);
    REQUIRE_THROWS_AS(const_S.row(5), AssertError);
#endif   // NDEBUG

    REQUIRE(S(0, 2) == 5);
    REQUIRE(S(1, 1) == 1);
    REQUIRE(S(1, 2) == 11);
    REQUIRE(S(2, 2) == 4);
    REQUIRE(S(3, 1) == 5);
    REQUIRE(S(4, 1) == 1);
    REQUIRE(S(4, 4) == -2);

    REQUIRE(const_S(0, 2) == 5);
    REQUIRE(const_S(1, 1) == 1);
    REQUIRE(const_S(1, 2) == 11);
    REQUIRE(const_S(2, 2) == 4);
    REQUIRE(const_S(3, 1) == 5);
    REQUIRE(const_S(4, 1) == 1);
    REQUIRE(const_S(4, 4) == -2);

#ifndef NDEBUG
    // Out-of-range rows fail for both accessors; on the const accessor,
    // reading a valid row at a column that was never set (0, 1) also fails.
    REQUIRE_THROWS_AS(S(5, 0), AssertError);
    REQUIRE_THROWS_AS(const_S(6, 1), AssertError);
    REQUIRE_THROWS_AS(const_S(0, 1), AssertError);
#endif   // NDEBUG
  }

  SECTION("vector-graph")
  {
    // Entries are deliberately inserted out of column order within a row
    // (e.g. (1,2) before (1,1), (3,3) before (3,1)).
    SparseMatrixDescriptor<int, uint8_t> S{5};
    S(0, 2) = 3;
    S(1, 2) = 11;
    S(1, 1) = 1;
    S(0, 2) += 2;
    S(3, 3) = 5;
    S(3, 1) = 5;
    S(4, 1) = 1;
    S(2, 2) = 4;
    S(4, 4) = -2;

    const auto graph = S.graphVector();

    REQUIRE(graph.size() == S.numberOfRows());
    REQUIRE(graph[0].size() == 1);
    REQUIRE(graph[1].size() == 2);
    REQUIRE(graph[2].size() == 1);
    REQUIRE(graph[3].size() == 2);
    REQUIRE(graph[4].size() == 2);

    // Column indices of each row are expected in increasing order,
    // independent of insertion order.
    REQUIRE(graph[0][0] == 2);
    REQUIRE(graph[1][0] == 1);
    REQUIRE(graph[1][1] == 2);
    REQUIRE(graph[2][0] == 2);
    REQUIRE(graph[3][0] == 1);
    REQUIRE(graph[3][1] == 3);
    REQUIRE(graph[4][0] == 1);
    REQUIRE(graph[4][1] == 4);
  }

  SECTION("value array")
  {
    SparseMatrixDescriptor<int, uint8_t> S{5};
    S(0, 2) = 3;
    S(1, 2) = 11;
    S(1, 1) = 1;
    S(0, 2) += 2;
    S(3, 3) = 5;
    S(3, 1) = -3;
    S(4, 1) = 1;
    S(2, 2) = 4;
    S(4, 4) = -2;

    const auto value_array = S.valueArray();

    REQUIRE(value_array.size() == 8);

    // Values are expected row by row, and within a row by increasing column:
    // (0,2) (1,1) (1,2) (2,2) (3,1) (3,3) (4,1) (4,4).
    REQUIRE(value_array[0] == 5);
    REQUIRE(value_array[1] == 1);
    REQUIRE(value_array[2] == 11);
    REQUIRE(value_array[3] == 4);
    REQUIRE(value_array[4] == -3);
    REQUIRE(value_array[5] == 5);
    REQUIRE(value_array[6] == 1);
    REQUIRE(value_array[7] == -2);
  }

  SECTION("output")
  {
    SparseMatrixDescriptor<int, uint8_t> S{5};
    S(0, 2) = 3;
    S(1, 2) = 11;
    S(1, 1) = 1;
    S(0, 2) += 2;
    S(3, 3) = 5;
    S(3, 1) = -3;
    S(4, 1) = 1;
    S(2, 2) = 4;
    S(4, 4) = -2;

    // A leading newline is streamed so the expected text can start on its
    // own line inside the raw string literal below.
    std::ostringstream output;
    output << '\n' << S;

    // Expected format: "<row> | <col>:<value> ..." per line, columns sorted.
    std::string expected_output = R"(
0 | 2:5
1 | 1:1 2:11
2 | 2:4
3 | 1:-3 3:5
4 | 1:1 4:-2
)";
    REQUIRE(output.str() == expected_output);
  }
}