Skip to content
Snippets Groups Projects
Commit ea4e4479 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Fix a bunch of type checkings

parent 7f0c7264
Branches
Tags
1 merge request!93Do not initializa Kokkos Arrays anymore
...@@ -35,13 +35,13 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -35,13 +35,13 @@ class DenseMatrix // LCOV_EXCL_LINE
return m_nb_rows == m_nb_columns; return m_nb_rows == m_nb_columns;
} }
friend DenseMatrix<std::remove_const_t<DataType>> friend PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
copy(const DenseMatrix& A) noexcept copy(const DenseMatrix& A) noexcept
{ {
return {A.m_nb_rows, A.m_nb_columns, copy(A.m_values)}; return {A.m_nb_rows, A.m_nb_columns, copy(A.m_values)};
} }
friend DenseMatrix<std::remove_const_t<DataType>> friend PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
transpose(const DenseMatrix& A) transpose(const DenseMatrix& A)
{ {
DenseMatrix<std::remove_const_t<DataType>> A_transpose{A.m_nb_columns, A.m_nb_rows}; DenseMatrix<std::remove_const_t<DataType>> A_transpose{A.m_nb_columns, A.m_nb_rows};
...@@ -53,7 +53,7 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -53,7 +53,7 @@ class DenseMatrix // LCOV_EXCL_LINE
return A_transpose; return A_transpose;
} }
friend DenseMatrix friend PUGS_INLINE DenseMatrix
operator*(const DataType& a, const DenseMatrix& A) operator*(const DataType& a, const DenseMatrix& A)
{ {
DenseMatrix<std::remove_const_t<DataType>> aA = copy(A); DenseMatrix<std::remove_const_t<DataType>> aA = copy(A);
...@@ -61,14 +61,16 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -61,14 +61,16 @@ class DenseMatrix // LCOV_EXCL_LINE
} }
template <typename DataType2> template <typename DataType2>
PUGS_INLINE Vector<std::remove_const_t<DataType2>> PUGS_INLINE Vector<std::remove_const_t<DataType>>
operator*(const Vector<DataType2>& x) const operator*(const Vector<DataType2>& x) const
{ {
static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
"incompatible data types");
Assert(m_nb_columns == x.size()); Assert(m_nb_columns == x.size());
const DenseMatrix& A = *this; const DenseMatrix& A = *this;
Vector<std::remove_const_t<DataType2>> Ax{m_nb_rows}; Vector<std::remove_const_t<DataType>> Ax{m_nb_rows};
for (size_t i = 0; i < m_nb_rows; ++i) { for (size_t i = 0; i < m_nb_rows; ++i) {
DataType2 Axi = A(i, 0) * x[0]; std::remove_const_t<DataType> Axi = A(i, 0) * x[0];
for (size_t j = 1; j < m_nb_columns; ++j) { for (size_t j = 1; j < m_nb_columns; ++j) {
Axi += A(i, j) * x[j]; Axi += A(i, j) * x[j];
} }
...@@ -78,16 +80,18 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -78,16 +80,18 @@ class DenseMatrix // LCOV_EXCL_LINE
} }
template <typename DataType2> template <typename DataType2>
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType2>> PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
operator*(const DenseMatrix<DataType2>& B) const operator*(const DenseMatrix<DataType2>& B) const
{ {
static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
"incompatible data types");
Assert(m_nb_columns == B.numberOfRows()); Assert(m_nb_columns == B.numberOfRows());
const DenseMatrix& A = *this; const DenseMatrix& A = *this;
DenseMatrix<std::remove_const_t<DataType>> AB{m_nb_rows, B.numberOfColumns()}; DenseMatrix<std::remove_const_t<DataType>> AB{m_nb_rows, B.numberOfColumns()};
for (size_t i = 0; i < m_nb_rows; ++i) { for (size_t i = 0; i < m_nb_rows; ++i) {
for (size_t j = 0; j < B.numberOfColumns(); ++j) { for (size_t j = 0; j < B.numberOfColumns(); ++j) {
DataType2 ABij = 0; std::remove_const_t<DataType> ABij = 0;
for (size_t k = 0; k < m_nb_columns; ++k) { for (size_t k = 0; k < m_nb_columns; ++k) {
ABij += A(i, k) * B(k, j); ABij += A(i, k) * B(k, j);
} }
...@@ -118,6 +122,8 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -118,6 +122,8 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix& PUGS_INLINE DenseMatrix&
operator-=(const DenseMatrix<DataType2>& B) operator-=(const DenseMatrix<DataType2>& B)
{ {
static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
"incompatible data types");
Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns()); Assert(m_nb_columns == B.numberOfColumns());
...@@ -130,6 +136,8 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -130,6 +136,8 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix& PUGS_INLINE DenseMatrix&
operator+=(const DenseMatrix<DataType2>& B) operator+=(const DenseMatrix<DataType2>& B)
{ {
static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
"incompatible data types");
Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns()); Assert(m_nb_columns == B.numberOfColumns());
...@@ -142,6 +150,8 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -142,6 +150,8 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>> PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
operator+(const DenseMatrix<DataType2>& B) const operator+(const DenseMatrix<DataType2>& B) const
{ {
static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
"incompatible data types");
Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns()); Assert(m_nb_columns == B.numberOfColumns());
DenseMatrix<std::remove_const_t<DataType>> sum{B.numberOfRows(), B.numberOfColumns()}; DenseMatrix<std::remove_const_t<DataType>> sum{B.numberOfRows(), B.numberOfColumns()};
...@@ -156,6 +166,8 @@ class DenseMatrix // LCOV_EXCL_LINE ...@@ -156,6 +166,8 @@ class DenseMatrix // LCOV_EXCL_LINE
PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>> PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>>
operator-(const DenseMatrix<DataType2>& B) const operator-(const DenseMatrix<DataType2>& B) const
{ {
static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>,
"incompatible data types");
Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_rows == B.numberOfRows());
Assert(m_nb_columns == B.numberOfColumns()); Assert(m_nb_columns == B.numberOfColumns());
DenseMatrix<std::remove_const_t<DataType>> difference{B.numberOfRows(), B.numberOfColumns()}; DenseMatrix<std::remove_const_t<DataType>> difference{B.numberOfRows(), B.numberOfColumns()};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment