#ifndef CRS_MATRIX_HPP #define CRS_MATRIX_HPP #include <Array.hpp> #include <Kokkos_StaticCrsGraph.hpp> #include <PugsAssert.hpp> #include <SparseMatrixDescriptor.hpp> #include <Vector.hpp> #include <iostream> template <typename DataType = double, typename IndexType = size_t> class CRSMatrix { using MutableDataType = std::remove_const_t<DataType>; private: using HostMatrix = Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>; HostMatrix m_host_matrix; Array<const DataType> m_values; public: PUGS_INLINE size_t numberOfRows() const { return m_host_matrix.numRows(); } template <typename DataType2> Vector<MutableDataType> operator*(const Vector<DataType2>& x) const { static_assert(std::is_same<std::remove_const_t<DataType>, std::remove_const_t<DataType2>>(), "Cannot multiply matrix and vector of different type"); Assert(x.size() == m_host_matrix.numRows()); Vector<MutableDataType> Ax{m_host_matrix.numRows()}; auto host_row_map = m_host_matrix.row_map; parallel_for( m_host_matrix.numRows(), PUGS_LAMBDA(const IndexType& i_row) { const auto row_begin = host_row_map(i_row); const auto row_end = host_row_map(i_row + 1); MutableDataType sum{0}; for (IndexType j = row_begin; j < row_end; ++j) { sum += m_values[j] * x[m_host_matrix.entries(j)]; } Ax[i_row] = sum; }); return Ax; } CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M) { m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector()); m_values = M.valueArray(); } ~CRSMatrix() = default; }; #endif // CRS_MATRIX_HPP