Select Git revision
ob-pugs-error.el
-
Stéphane Del Pino authored
CRSMatrix.hpp 1.68 KiB
#ifndef CRS_MATRIX_HPP
#define CRS_MATRIX_HPP
#include <Array.hpp>
#include <Kokkos_StaticCrsGraph.hpp>
#include <PugsAssert.hpp>
#include <SparseMatrixDescriptor.hpp>
#include <Vector.hpp>
#include <iostream>
/**
 * Sparse matrix stored in compressed-row form.
 *
 * The sparsity pattern is kept in a Kokkos::StaticCrsGraph instantiated on
 * Kokkos::HostSpace: `row_map(i)` .. `row_map(i + 1)` delimits the stored
 * entries of row `i`, and `entries(j)` is the index used to address the
 * vector operand for stored entry `j`. The coefficients themselves are kept
 * in `m_values`, addressed with the same entry index `j`.
 *
 * @tparam DataType  type of the stored coefficients
 * @tparam IndexType integer type used for row and entry indices
 */
template <typename DataType = double, typename IndexType = size_t>
class CRSMatrix
{
  // Coefficient type with any top-level const removed; used for results.
  using MutableDataType = std::remove_const_t<DataType>;

 private:
  using HostMatrix = Kokkos::StaticCrsGraph<IndexType, Kokkos::HostSpace>;

  HostMatrix m_host_matrix;         // sparsity pattern (row_map + entries)
  Array<const DataType> m_values;   // coefficients, one per graph entry

 public:
  /// @return the number of rows of the matrix.
  PUGS_INLINE
  size_t
  numberOfRows() const
  {
    return m_host_matrix.numRows();
  }

  /**
   * Matrix-vector product.
   *
   * @tparam DataType2 coefficient type of @p x; must match DataType up to
   *                   const qualification (checked at compile time).
   * @param x vector to multiply; its size is asserted to equal the number
   *          of rows of the matrix (so, in effect, the matrix is treated as
   *          square here).
   * @return a new vector `Ax` with one component per matrix row.
   */
  template <typename DataType2>
  Vector<MutableDataType>
  operator*(const Vector<DataType2>& x) const
  {
    static_assert(std::is_same<MutableDataType, std::remove_const_t<DataType2>>(),
                  "Cannot multiply matrix and vector of different type");
    Assert(x.size() == m_host_matrix.numRows());

    Vector<MutableDataType> Ax{m_host_matrix.numRows()};

    // Take local handles on everything the kernel reads, so that the lambda
    // captures them by value instead of reaching through an implicitly
    // captured `this` (previously only row_map was hoisted).
    // NOTE(review): these copies are expected to be cheap handle copies, as
    // for Kokkos views -- confirm for Array.
    auto host_row_map = m_host_matrix.row_map;
    auto host_entries = m_host_matrix.entries;
    auto values       = m_values;

    parallel_for(
      m_host_matrix.numRows(), PUGS_LAMBDA(const IndexType& i_row) {
        const auto row_begin = host_row_map(i_row);
        const auto row_end   = host_row_map(i_row + 1);

        // Accumulate the dot product of row i_row with x.
        MutableDataType sum{0};
        for (IndexType j = row_begin; j < row_end; ++j) {
          sum += values[j] * x[host_entries(j)];
        }
        Ax[i_row] = sum;
      });

    return Ax;
  }

  /**
   * Builds the compressed-row matrix from a sparse matrix descriptor.
   *
   * The graph is created from `M.graphVector()` and the coefficients are
   * taken from `M.valueArray()`.
   *
   * @param M descriptor providing the sparsity pattern and the values.
   */
  CRSMatrix(const SparseMatrixDescriptor<DataType, IndexType>& M)
  {
    m_host_matrix = Kokkos::create_staticcrsgraph<HostMatrix>("connectivity_matrix", M.graphVector());
    m_values      = M.valueArray();
  }

  ~CRSMatrix() = default;
};
#endif // CRS_MATRIX_HPP