diff --git a/src/algebra/DenseMatrix.hpp b/src/algebra/DenseMatrix.hpp index a48dc18710d87f870e1101f74b1b7591de9996ca..8f15f125861d0cc9cce0aef716e63959f736d36b 100644 --- a/src/algebra/DenseMatrix.hpp +++ b/src/algebra/DenseMatrix.hpp @@ -35,13 +35,13 @@ class DenseMatrix // LCOV_EXCL_LINE return m_nb_rows == m_nb_columns; } - friend DenseMatrix<std::remove_const_t<DataType>> + friend PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>> copy(const DenseMatrix& A) noexcept { return {A.m_nb_rows, A.m_nb_columns, copy(A.m_values)}; } - friend DenseMatrix<std::remove_const_t<DataType>> + friend PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>> transpose(const DenseMatrix& A) { DenseMatrix<std::remove_const_t<DataType>> A_transpose{A.m_nb_columns, A.m_nb_rows}; @@ -53,7 +53,7 @@ class DenseMatrix // LCOV_EXCL_LINE return A_transpose; } - friend DenseMatrix + friend PUGS_INLINE DenseMatrix operator*(const DataType& a, const DenseMatrix& A) { DenseMatrix<std::remove_const_t<DataType>> aA = copy(A); @@ -61,14 +61,16 @@ class DenseMatrix // LCOV_EXCL_LINE } template <typename DataType2> - PUGS_INLINE Vector<std::remove_const_t<DataType2>> + PUGS_INLINE Vector<std::remove_const_t<DataType>> operator*(const Vector<DataType2>& x) const { + static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>, + "incompatible data types"); Assert(m_nb_columns == x.size()); const DenseMatrix& A = *this; - Vector<std::remove_const_t<DataType2>> Ax{m_nb_rows}; + Vector<std::remove_const_t<DataType>> Ax{m_nb_rows}; for (size_t i = 0; i < m_nb_rows; ++i) { - DataType2 Axi = A(i, 0) * x[0]; + std::remove_const_t<DataType> Axi = A(i, 0) * x[0]; for (size_t j = 1; j < m_nb_columns; ++j) { Axi += A(i, j) * x[j]; } @@ -78,16 +80,18 @@ class DenseMatrix // LCOV_EXCL_LINE } template <typename DataType2> - PUGS_INLINE DenseMatrix<std::remove_const_t<DataType2>> + PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>> operator*(const DenseMatrix<DataType2>& B) const { + static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>, + "incompatible data types"); Assert(m_nb_columns == B.numberOfRows()); const DenseMatrix& A = *this; DenseMatrix<std::remove_const_t<DataType>> AB{m_nb_rows, B.numberOfColumns()}; for (size_t i = 0; i < m_nb_rows; ++i) { for (size_t j = 0; j < B.numberOfColumns(); ++j) { - DataType2 ABij = 0; + std::remove_const_t<DataType> ABij = 0; for (size_t k = 0; k < m_nb_columns; ++k) { ABij += A(i, k) * B(k, j); } @@ -118,6 +122,8 @@ class DenseMatrix // LCOV_EXCL_LINE PUGS_INLINE DenseMatrix& operator-=(const DenseMatrix<DataType2>& B) { + static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>, + "incompatible data types"); Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_columns == B.numberOfColumns()); @@ -130,6 +136,8 @@ class DenseMatrix // LCOV_EXCL_LINE PUGS_INLINE DenseMatrix& operator+=(const DenseMatrix<DataType2>& B) { + static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>, + "incompatible data types"); Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_columns == B.numberOfColumns()); @@ -142,6 +150,8 @@ class DenseMatrix // LCOV_EXCL_LINE PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>> operator+(const DenseMatrix<DataType2>& B) const { + static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>, + "incompatible data types"); Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_columns == B.numberOfColumns()); DenseMatrix<std::remove_const_t<DataType>> sum{B.numberOfRows(), B.numberOfColumns()}; @@ -156,6 +166,8 @@ class DenseMatrix // LCOV_EXCL_LINE PUGS_INLINE DenseMatrix<std::remove_const_t<DataType>> operator-(const DenseMatrix<DataType2>& B) const { + static_assert(std::is_same_v<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>, + "incompatible data types"); Assert(m_nb_rows == B.numberOfRows()); Assert(m_nb_columns == B.numberOfColumns()); DenseMatrix<std::remove_const_t<DataType>> difference{B.numberOfRows(), B.numberOfColumns()};