Skip to content
Snippets Groups Projects
Commit aa284365 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Fix DenseMatrix API

parent e89f73e8
No related branches found
No related tags found
1 merge request!93Do not initializa Kokkos Arrays anymore
......@@ -81,12 +81,12 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType2>>
operator*(const DenseMatrix<DataType2>& B) const
{
Assert(m_nb_columns == B.nbRows());
Assert(m_nb_columns == B.numberOfRows());
const DenseMatrix& A = *this;
DenseMatrix<std::remove_const_t<DataType>> AB{m_nb_rows, B.nbColumns()};
DenseMatrix<std::remove_const_t<DataType>> AB{m_nb_rows, B.numberOfColumns()};
for (size_t i = 0; i < m_nb_rows; ++i) {
for (size_t j = 0; j < B.nbColumns(); ++j) {
for (size_t j = 0; j < B.numberOfColumns(); ++j) {
DataType2 ABij = 0;
for (size_t k = 0; k < m_nb_columns; ++k) {
ABij += A(i, k) * B(k, j);
......@@ -118,8 +118,8 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix&
operator-=(const DenseMatrix<DataType2>& B)
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns());
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] -= B.m_values[i]; });
......@@ -130,8 +130,8 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix&
operator+=(const DenseMatrix<DataType2>& B)
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns());
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { m_values[i] += B.m_values[i]; });
......@@ -142,9 +142,9 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
operator+(const DenseMatrix<DataType2>& B) const
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
DenseMatrix<std::remove_const_t<DataType>> sum{B.nbRows(), B.nbColumns()};
Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns());
DenseMatrix<std::remove_const_t<DataType>> sum{B.numberOfRows(), B.numberOfColumns()};
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { sum.m_values[i] = m_values[i] + B.m_values[i]; });
......@@ -156,9 +156,9 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
operator-(const DenseMatrix<DataType2>& B) const
{
Assert(m_nb_rows == B.nbRows());
Assert(m_nb_columns == B.nbColumns());
DenseMatrix<std::remove_const_t<DataType>> difference{B.nbRows(), B.nbColumns()};
Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns());
DenseMatrix<std::remove_const_t<DataType>> difference{B.numberOfRows(), B.numberOfColumns()};
parallel_for(
m_values.size(), PUGS_LAMBDA(index_type i) { difference.m_values[i] = m_values[i] - B.m_values[i]; });
......@@ -176,14 +176,14 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE
size_t
nbRows() const noexcept
numberOfRows() const noexcept
{
return m_nb_rows;
}
PUGS_INLINE
size_t
nbColumns() const noexcept
numberOfColumns() const noexcept
{
return m_nb_columns;
}
......@@ -268,6 +268,8 @@ class DenseMatrix // LCOV_EXCL_LINE
});
}
DenseMatrix() noexcept : m_nb_rows{0}, m_nb_columns{0} {}
private:
DenseMatrix(size_t nb_rows, size_t nb_columns, const Array<DataType> values) noexcept(NO_ASSERT)
: m_nb_rows{nb_rows}, m_nb_columns{nb_columns}, m_values{values}
......
......@@ -53,7 +53,9 @@ class PETScAijMatrixEmbedder
PETScAijMatrixEmbedder(const TinyMatrix<N>& A) : PETScAijMatrixEmbedder{N, N, &A(0, 0)}
{}
PETScAijMatrixEmbedder(const DenseMatrix<double>& A) : PETScAijMatrixEmbedder{A.nbRows(), A.nbColumns(), &A(0, 0)} {}
PETScAijMatrixEmbedder(const DenseMatrix<double>& A)
: PETScAijMatrixEmbedder{A.numberOfRows(), A.numberOfColumns(), &A(0, 0)}
{}
PETScAijMatrixEmbedder(const CRSMatrix<double, size_t>& A);
......
......@@ -16,8 +16,8 @@ TEST_CASE("DenseMatrix", "[algebra]")
SECTION("size")
{
DenseMatrix<int> A{2, 3};
REQUIRE(A.nbRows() == 2);
REQUIRE(A.nbColumns() == 3);
REQUIRE(A.numberOfRows() == 2);
REQUIRE(A.numberOfColumns() == 3);
}
SECTION("write access")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment