#ifndef CRS_MATRIX_HPP #define CRS_MATRIX_HPP #include <algebra/Vector.hpp> #include <utils/Array.hpp> #include <utils/NaNHelper.hpp> #include <utils/PugsAssert.hpp> #include <iostream> template <typename DataType, typename IndexType> class CRSMatrixDescriptor; template <typename DataType = double, typename IndexType = int> class CRSMatrix { public: using MutableDataType = std::remove_const_t<DataType>; using index_type = IndexType; using data_type = DataType; private: const IndexType m_nb_rows; const IndexType m_nb_columns; Array<const IndexType> m_row_map; Array<const DataType> m_values; Array<const IndexType> m_column_indices; public: PUGS_INLINE bool isSquare() const noexcept { return m_nb_rows == m_nb_columns; } PUGS_INLINE IndexType numberOfRows() const { return m_nb_rows; } PUGS_INLINE IndexType numberOfColumns() const { return m_nb_columns; } auto values() const { return m_values; } auto rowMap() const { return m_row_map; } auto columnIndices() const { return m_column_indices; } template <typename DataType2> Vector<MutableDataType> operator*(const Vector<DataType2>& x) const { static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(), "Cannot multiply matrix and vector of different type"); Assert(x.size() - m_nb_columns == 0); Vector<MutableDataType> Ax(m_nb_rows); parallel_for( m_nb_rows, PUGS_LAMBDA(const IndexType& i_row) { const auto row_begin = m_row_map[i_row]; const auto row_end = m_row_map[i_row + 1]; MutableDataType sum{0}; for (IndexType j = row_begin; j < row_end; ++j) { sum += m_values[j] * x[m_column_indices[j]]; } Ax[i_row] = sum; }); return Ax; } friend PUGS_INLINE std::ostream& operator<<(std::ostream& os, const CRSMatrix& A) { for (IndexType i = 0; i < A.m_nb_rows; ++i) { const auto row_begin = A.m_row_map[i]; const auto row_end = A.m_row_map[i + 1]; os << i << "|"; for (IndexType j = row_begin; j < row_end; ++j) { os << ' ' << A.m_column_indices[j] << ':' << NaNHelper(A.m_values[j]); } os << '\n'; } return os; } private: friend class CRSMatrixDescriptor<DataType, IndexType>; CRSMatrix(const IndexType nb_rows, const IndexType nb_columns, const Array<const IndexType>& row_map, const Array<const DataType>& values, const Array<const IndexType>& column_indices) : m_nb_rows{nb_rows}, m_nb_columns{nb_columns}, m_row_map{row_map}, m_values{values}, m_column_indices{column_indices} {} public: ~CRSMatrix() = default; }; #endif // CRS_MATRIX_HPP