Skip to content
Snippets Groups Projects
Commit 643c7c2b authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add DenseMatrix class

This class consists actually in a renaming and cleaning of the class
LocalRectangularMatrix written by Julie.

Matrices of this type are dense, rectangular and local (cannot
represent a matrix spread on various CPUs).
parent 2d05f6b5
No related branches found
No related tags found
1 merge request!93Do not initializa Kokkos Arrays anymore
#ifndef DENSE_MATRIX_HPP
#define DENSE_MATRIX_HPP
#include <algebra/TinyMatrix.hpp>
#include <algebra/Vector.hpp>
#include <utils/Array.hpp>
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>
#include <utils/PugsUtils.hpp>
#include <utils/Types.hpp>
template <typename DataType>
class DenseMatrix // LCOV_EXCL_LINE
{
public:
using data_type = DataType;
using index_type = size_t;
private:
size_t m_nb_rows;
size_t m_nb_columns;
Array<DataType> m_values;
static_assert(std::is_same_v<typename decltype(m_values)::index_type, index_type>);
static_assert(std::is_arithmetic_v<DataType>, "Dense matrices expect arithmetic data");
// Allows const version to access our data
friend DenseMatrix<std::add_const_t<DataType>>;
public:
PUGS_INLINE
bool
isSquare() const noexcept
{
return m_nb_rows == m_nb_columns;
}
friend DenseMatrix<std::remove_const_t<DataType>>
copy(const DenseMatrix& A) noexcept
{
return {A.m_nb_rows, A.m_nb_columns, copy(A.m_values)};
}
friend DenseMatrix<std::remove_const_t<DataType>>
transpose(const DenseMatrix& A)
{
DenseMatrix<std::remove_const_t<DataType>> A_transpose{A.m_nb_columns, A.m_nb_rows};
for (size_t i = 0; i < A.m_nb_rows; ++i) {
for (size_t j = 0; j < A.m_nb_columns; ++j) {
A_transpose(j, i) = A(i, j);
}
}
return A_transpose;
}
friend DenseMatrix
operator*(const DataType& a, const DenseMatrix& A)
{
DenseMatrix<std::remove_const_t<DataType>> aA = copy(A);
return aA *= a;
}
template <typename DataType2>
PUGS_INLINE Vector<std::remove_const_t<DataType2>>
operator*(const Vector<DataType2>& x) const
{
Assert(m_nb_columns == x.size());
const DenseMatrix& A = *this;
Vector<std::remove_const_t<DataType2>> Ax{m_nb_rows};
for (size_t i = 0; i < m_nb_rows; ++i) {
DataType2 Axi = A(i, 0) * x[0];
for (size_t j = 1; j < m_nb_columns; ++j) {
Axi += A(i, j) * x[j];
}
Ax[i] = Axi;
}
return Ax;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType2>>
operator*(const DenseMatrix<DataType2>& B) const
{
Assert(m_nb_columns == B.nbRows());
const DenseMatrix& A = *this;
DenseMatrix<std::remove_const_t<DataType>> AB{m_nb_rows, B.nbColumns()};
for (size_t i = 0; i < m_nb_rows; ++i) {
for (size_t j = 0; j < B.nbColumns(); ++j) {
DataType2 ABij = 0;
for (size_t k = 0; k < m_nb_columns; ++k) {
ABij += A(i, k) * B(k, j);
}
AB(i, j) = ABij;
}
}
return AB;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix&
operator/=(const DataType2& a)
{
const auto inv_a = 1. / a;
return (*this) *= inv_a;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix&
operator*=(const DataType2& a)
{
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] *= a; });
return *this;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix&
operator-=(const DenseMatrix<DataType2>& B)
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] -= B.m_values[i]; });
return *this;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix&
operator+=(const DenseMatrix<DataType2>& B)
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] += B.m_values[i]; });
return *this;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
operator+(const DenseMatrix<DataType2>& B) const
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
DenseMatrix<std::remove_const_t<DataType>> sum{B.nbRows(), B.nbColumns()};
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { sum.m_values[i] = m_values[i] + B.m_values[i]; });
return sum;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
operator-(const DenseMatrix<DataType2>& B) const
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
DenseMatrix<std::remove_const_t<DataType>> difference{B.nbRows(), B.nbColumns()};
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { difference.m_values[i] = m_values[i] - B.m_values[i]; });
return difference;
}
PUGS_INLINE
DataType&
operator()(index_type i, index_type j) const noexcept(NO_ASSERT)
{
Assert((i < m_nb_rows and j < m_nb_columns), "invalid indices");
return m_values[i * m_nb_columns + j];
}
PUGS_INLINE
size_t
nbRows() const noexcept
{
return m_nb_rows;
}
PUGS_INLINE
size_t
nbColumns() const noexcept
{
return m_nb_columns;
}
PUGS_INLINE void
fill(const DataType& value) noexcept
{
m_values.fill(value);
}
PUGS_INLINE DenseMatrix& operator=(ZeroType) noexcept
{
m_values.fill(0);
return *this;
}
PUGS_INLINE DenseMatrix& operator=(IdentityType) noexcept(NO_ASSERT)
{
Assert(m_nb_rows == m_nb_columns, "Identity must be a square matrix");
m_values.fill(0);
parallel_for(
m_nb_rows, PUGS_LAMBDA(const index_type i) { m_values[i * m_nb_rows + i] = 1; });
return *this;
}
template <typename DataType2>
PUGS_INLINE DenseMatrix&
operator=(const DenseMatrix<DataType2>& A) noexcept
{
// ensures that DataType is the same as source DataType2
static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(),
"Cannot assign DenseMatrix of different type");
// ensures that const is not lost through copy
static_assert(((std::is_const<DataType2>() and std::is_const<DataType>()) or not std::is_const<DataType2>()),
"Cannot assign DenseMatrix of const to DenseMatrix of non-const");
m_nb_rows = A.m_nb_rows;
m_nb_columns = A.m_nb_columns;
m_values = A.m_values;
return *this;
}
PUGS_INLINE
DenseMatrix& operator=(const DenseMatrix&) = default;
PUGS_INLINE
DenseMatrix& operator=(DenseMatrix&&) = default;
template <typename DataType2>
DenseMatrix(const DenseMatrix<DataType2>& A)
{
// ensures that DataType is the same as source DataType2
static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(),
"Cannot assign DenseMatrix of different type");
// ensures that const is not lost through copy
static_assert(((std::is_const<DataType2>() and std::is_const<DataType>()) or not std::is_const<DataType2>()),
"Cannot assign DenseMatrix of const to DenseMatrix of non-const");
this->operator=(A);
}
DenseMatrix(const DenseMatrix&) = default;
DenseMatrix(DenseMatrix&&) = default;
explicit DenseMatrix(size_t nb_rows, size_t nb_columns) noexcept
: m_nb_rows{nb_rows}, m_nb_columns{nb_columns}, m_values{nb_rows * nb_columns}
{}
explicit DenseMatrix(size_t nb_rows) noexcept : m_nb_rows{nb_rows}, m_nb_columns{nb_rows}, m_values{nb_rows * nb_rows}
{}
template <size_t N>
explicit DenseMatrix(const TinyMatrix<N, DataType>& M) noexcept : m_nb_rows{N}, m_nb_columns{N}, m_values{N * N}
{
parallel_for(
N, PUGS_LAMBDA(const index_type i) {
for (size_t j = 0; j < N; ++j) {
m_values[i * N + j] = M(i, j);
}
});
}
private:
DenseMatrix(size_t nb_rows, size_t nb_columns, const Array<DataType> values) noexcept(NO_ASSERT)
: m_nb_rows{nb_rows}, m_nb_columns{nb_columns}, m_values{values}
{
Assert(m_values.size() == m_nb_rows * m_nb_columns, "incompatible sizes");
}
public:
~DenseMatrix() = default;
};
#endif // DENSE_MATRIX_HPP
......@@ -57,6 +57,7 @@ add_executable (unit_tests
test_CRSGraph.cpp
test_CRSMatrix.cpp
test_DataVariant.cpp
test_DenseMatrix.cpp
test_Demangle.cpp
test_DiscreteFunctionDescriptorP0.cpp
test_DiscreteFunctionDescriptorP0Vector.cpp
......
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>
#include <utils/PugsAssert.hpp>
#include <algebra/DenseMatrix.hpp>
#include <algebra/Vector.hpp>
// Instantiate to ensure full coverage is performed
template class DenseMatrix<int>;
// clazy:excludeall=non-pod-global-static
TEST_CASE("DenseMatrix", "[algebra]")
{
SECTION("size")
{
DenseMatrix<int> A{2, 3};
REQUIRE(A.nbRows() == 2);
REQUIRE(A.nbColumns() == 3);
}
SECTION("write access")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
REQUIRE(A(0, 0) == 0);
REQUIRE(A(0, 1) == 1);
REQUIRE(A(0, 2) == 2);
REQUIRE(A(1, 0) == 3);
REQUIRE(A(1, 1) == 4);
REQUIRE(A(1, 2) == 5);
DenseMatrix<const int> const_A = A;
REQUIRE(const_A(0, 0) == 0);
REQUIRE(const_A(0, 1) == 1);
REQUIRE(const_A(0, 2) == 2);
REQUIRE(const_A(1, 0) == 3);
REQUIRE(const_A(1, 1) == 4);
REQUIRE(const_A(1, 2) == 5);
}
SECTION("fill")
{
DenseMatrix<int> A{2, 3};
A.fill(2);
REQUIRE(A(0, 0) == 2);
REQUIRE(A(0, 1) == 2);
REQUIRE(A(0, 2) == 2);
REQUIRE(A(1, 0) == 2);
REQUIRE(A(1, 1) == 2);
REQUIRE(A(1, 2) == 2);
}
SECTION("copy constructor (shallow)")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
const DenseMatrix<int> B = A;
REQUIRE(B(0, 0) == 0);
REQUIRE(B(0, 1) == 1);
REQUIRE(B(0, 2) == 2);
REQUIRE(B(1, 0) == 3);
REQUIRE(B(1, 1) == 4);
REQUIRE(B(1, 2) == 5);
}
SECTION("copy (deep)")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
const DenseMatrix<int> B = copy(A);
A(0, 0) = 10;
A(0, 1) = 11;
A(0, 2) = 12;
A(1, 0) = 13;
A(1, 1) = 14;
A(1, 2) = 15;
REQUIRE(B(0, 0) == 0);
REQUIRE(B(0, 1) == 1);
REQUIRE(B(0, 2) == 2);
REQUIRE(B(1, 0) == 3);
REQUIRE(B(1, 1) == 4);
REQUIRE(B(1, 2) == 5);
REQUIRE(A(0, 0) == 10);
REQUIRE(A(0, 1) == 11);
REQUIRE(A(0, 2) == 12);
REQUIRE(A(1, 0) == 13);
REQUIRE(A(1, 1) == 14);
REQUIRE(A(1, 2) == 15);
}
SECTION("self scalar multiplication")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
A *= 2;
REQUIRE(A(0, 0) == 0);
REQUIRE(A(0, 1) == 2);
REQUIRE(A(0, 2) == 4);
REQUIRE(A(1, 0) == 6);
REQUIRE(A(1, 1) == 8);
REQUIRE(A(1, 2) == 10);
}
SECTION("left scalar multiplication")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
const DenseMatrix<int> B = 2 * A;
REQUIRE(B(0, 0) == 0);
REQUIRE(B(0, 1) == 2);
REQUIRE(B(0, 2) == 4);
REQUIRE(B(1, 0) == 6);
REQUIRE(B(1, 1) == 8);
REQUIRE(B(1, 2) == 10);
}
SECTION("product matrix vector")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 6;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
Vector<int> x{3};
x[0] = 7;
x[1] = 3;
x[2] = 4;
Vector y = A * x;
REQUIRE(y[0] == 53);
REQUIRE(y[1] == 53);
}
SECTION("self scalar division")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 2;
A(0, 2) = 4;
A(1, 0) = 6;
A(1, 1) = 8;
A(1, 2) = 10;
A /= 2;
REQUIRE(A(0, 0) == 0);
REQUIRE(A(0, 1) == 1);
REQUIRE(A(0, 2) == 2);
REQUIRE(A(1, 0) == 3);
REQUIRE(A(1, 1) == 4);
REQUIRE(A(1, 2) == 5);
}
SECTION("self minus")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
DenseMatrix<int> B{2, 3};
B(0, 0) = 5;
B(0, 1) = 6;
B(0, 2) = 4;
B(1, 0) = 2;
B(1, 1) = 1;
B(1, 2) = 3;
A -= B;
REQUIRE(A(0, 0) == -5);
REQUIRE(A(0, 1) == -5);
REQUIRE(A(0, 2) == -2);
REQUIRE(A(1, 0) == 1);
REQUIRE(A(1, 1) == 3);
REQUIRE(A(1, 2) == 2);
}
SECTION("self sum")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
DenseMatrix<int> B{2, 3};
B(0, 0) = 5;
B(0, 1) = 6;
B(0, 2) = 4;
B(1, 0) = 2;
B(1, 1) = 1;
B(1, 2) = 3;
A += B;
REQUIRE(A(0, 0) == 5);
REQUIRE(A(0, 1) == 7);
REQUIRE(A(0, 2) == 6);
REQUIRE(A(1, 0) == 5);
REQUIRE(A(1, 1) == 5);
REQUIRE(A(1, 2) == 8);
}
SECTION("sum")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 6;
A(0, 1) = 5;
A(0, 2) = 4;
A(1, 0) = 3;
A(1, 1) = 2;
A(1, 2) = 1;
DenseMatrix<int> B{2, 3};
B(0, 0) = 0;
B(0, 1) = 1;
B(0, 2) = 2;
B(1, 0) = 3;
B(1, 1) = 4;
B(1, 2) = 5;
DenseMatrix C = A + B;
REQUIRE(C(0, 0) == 6);
REQUIRE(C(0, 1) == 6);
REQUIRE(C(0, 2) == 6);
REQUIRE(C(1, 0) == 6);
REQUIRE(C(1, 1) == 6);
REQUIRE(C(1, 2) == 6);
}
SECTION("difference")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 6;
A(0, 1) = 5;
A(0, 2) = 4;
A(1, 0) = 3;
A(1, 1) = 2;
A(1, 2) = 1;
DenseMatrix<int> B{2, 3};
B(0, 0) = 0;
B(0, 1) = 1;
B(0, 2) = 2;
B(1, 0) = 3;
B(1, 1) = 4;
B(1, 2) = 5;
DenseMatrix C = A - B;
REQUIRE(C(0, 0) == 6);
REQUIRE(C(0, 1) == 4);
REQUIRE(C(0, 2) == 2);
REQUIRE(C(1, 0) == 0);
REQUIRE(C(1, 1) == -2);
REQUIRE(C(1, 2) == -4);
}
SECTION("transpose")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 0;
A(0, 1) = 1;
A(0, 2) = 2;
A(1, 0) = 3;
A(1, 1) = 4;
A(1, 2) = 5;
DenseMatrix B = transpose(A);
REQUIRE(B(0, 0) == 0);
REQUIRE(B(0, 1) == 3);
REQUIRE(B(1, 0) == 1);
REQUIRE(B(1, 1) == 4);
REQUIRE(B(2, 0) == 2);
REQUIRE(B(2, 1) == 5);
}
SECTION("product matrix vector")
{
DenseMatrix<int> A{2, 3};
A(0, 0) = 1;
A(0, 1) = 2;
A(0, 2) = 3;
A(1, 0) = 4;
A(1, 1) = 5;
A(1, 2) = 6;
DenseMatrix<int> B{3, 2};
B(0, 0) = 2;
B(0, 1) = 8;
B(1, 0) = 4;
B(1, 1) = 9;
B(2, 0) = 6;
B(2, 1) = 10;
DenseMatrix C = A * B;
REQUIRE(C(0, 0) == 28);
REQUIRE(C(0, 1) == 56);
REQUIRE(C(1, 0) == 64);
REQUIRE(C(1, 1) == 137);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment