Skip to content
Snippets Groups Projects
Commit 90ac3e2c authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add tests for CRSMatrix and little clean-up

parent 3c154f08
No related branches found
No related tags found
1 merge request!26Feature/linear systems
...@@ -21,6 +21,13 @@ class CRSMatrix ...@@ -21,6 +21,13 @@ class CRSMatrix
Array<DataType> m_values; Array<DataType> m_values;
public: public:
PUGS_INLINE
size_t
numberOfRows() const
{
return m_host_matrix.numRows();
}
Vector<DataType> operator*(const Vector<DataType>& x) const Vector<DataType> operator*(const Vector<DataType>& x) const
{ {
Vector<DataType> Ax{m_host_matrix.numRows()}; Vector<DataType> Ax{m_host_matrix.numRows()};
...@@ -40,22 +47,6 @@ class CRSMatrix ...@@ -40,22 +47,6 @@ class CRSMatrix
return Ax; return Ax;
} }
friend std::ostream&
operator<<(std::ostream& os, const CRSMatrix& M)
{
auto host_row_map = M.m_host_matrix.row_map;
for (IndexType i_row = 0; i_row < M.m_host_matrix.numRows(); ++i_row) {
const auto& row_begin = host_row_map(i_row);
const auto& row_end = host_row_map(i_row + 1);
os << i_row << " #";
for (IndexType j = row_begin; j < row_end; ++j) {
os << ' ' << M.m_host_matrix.entries(j) << ':' << M.m_values[j];
}
os << '\n';
}
return os;
}
CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M) CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M)
{ {
m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector()); m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector());
......
...@@ -5,6 +5,7 @@ add_executable (unit_tests ...@@ -5,6 +5,7 @@ add_executable (unit_tests
test_main.cpp test_main.cpp
test_Array.cpp test_Array.cpp
test_ArrayUtils.cpp test_ArrayUtils.cpp
test_CRSMatrix.cpp
test_ItemType.cpp test_ItemType.cpp
test_PugsAssert.cpp test_PugsAssert.cpp
test_RevisionInfo.cpp test_RevisionInfo.cpp
......
#include <catch2/catch.hpp>
#include <CRSMatrix.hpp>
// Instantiate to ensure full coverage is performed
template class SparseMatrixDescriptor<int, uint8_t>;
TEST_CASE("CRSMatrix", "[algebra]")
{
SECTION("CRSMatrix")
{
SECTION("matrix size")
{
SparseMatrixDescriptor<int, uint8_t> S{5};
S(0, 2) = 3;
S(1, 2) = 11;
S(1, 1) = 1;
S(3, 0) = 5;
S(3, 1) = -3;
S(4, 1) = 1;
S(2, 2) = 4;
S(4, 3) = 2;
S(4, 4) = -2;
CRSMatrix<int, uint8_t> A{S};
REQUIRE(A.numberOfRows() == S.numberOfRows());
}
SECTION("matrix vector product (simple)")
{
SparseMatrixDescriptor<int, uint8_t> S{5};
S(0, 2) = 3;
S(1, 2) = 11;
S(1, 1) = 1;
S(3, 0) = 5;
S(3, 1) = -3;
S(4, 1) = 1;
S(2, 2) = 4;
S(4, 3) = 2;
S(4, 4) = -2;
CRSMatrix<int, uint8_t> A{S};
Vector<int> x{A.numberOfRows()};
x = 0;
x[0] = 1;
Vector<int> y = A * x;
REQUIRE(y[0] == 0);
REQUIRE(y[1] == 0);
REQUIRE(y[2] == 0);
REQUIRE(y[3] == 5);
REQUIRE(y[4] == 0);
x = 0;
x[1] = 2;
y = A * x;
REQUIRE(y[0] == 0);
REQUIRE(y[1] == 2);
REQUIRE(y[2] == 0);
REQUIRE(y[3] == -6);
REQUIRE(y[4] == 2);
x = 0;
x[2] = -1;
y = A * x;
REQUIRE(y[0] == -3);
REQUIRE(y[1] == -11);
REQUIRE(y[2] == -4);
REQUIRE(y[3] == 0);
REQUIRE(y[4] == 0);
x = 0;
x[3] = 3;
y = A * x;
REQUIRE(y[0] == 0);
REQUIRE(y[1] == 0);
REQUIRE(y[2] == 0);
REQUIRE(y[3] == 0);
REQUIRE(y[4] == 6);
x = 0;
x[4] = 1;
y = A * x;
REQUIRE(y[0] == 0);
REQUIRE(y[1] == 0);
REQUIRE(y[2] == 0);
REQUIRE(y[3] == 0);
REQUIRE(y[4] == -2);
}
SECTION("matrix vector product (complet)")
{
SparseMatrixDescriptor<int, uint8_t> S{4};
S(0, 0) = 1;
S(0, 1) = 2;
S(0, 2) = 3;
S(0, 3) = 4;
S(1, 0) = 5;
S(1, 1) = 6;
S(1, 2) = 7;
S(1, 3) = 8;
S(2, 0) = 9;
S(2, 1) = 10;
S(2, 2) = 11;
S(2, 3) = 12;
S(3, 0) = 13;
S(3, 1) = 14;
S(3, 2) = 15;
S(3, 3) = 16;
CRSMatrix<int, uint8_t> A{S};
Vector<int> x{A.numberOfRows()};
x[0] = 1;
x[1] = 2;
x[2] = 3;
x[3] = 4;
Vector<int> y = A * x;
REQUIRE(y[0] == 30);
REQUIRE(y[1] == 70);
REQUIRE(y[2] == 110);
REQUIRE(y[3] == 150);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment