Skip to content
Snippets Groups Projects
Commit 5c31fcfb authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add missing dot product for discrete P0 vector functions

parent cd25d595
Branches
Tags
1 merge request!108Add missing dot product for discrete P0 vector functions
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <language/utils/EmbeddedIDiscreteFunctionUtils.hpp> #include <language/utils/EmbeddedIDiscreteFunctionUtils.hpp>
#include <mesh/IMesh.hpp> #include <mesh/IMesh.hpp>
#include <scheme/DiscreteFunctionP0.hpp> #include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionP0Vector.hpp>
#include <scheme/DiscreteFunctionUtils.hpp> #include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/IDiscreteFunction.hpp> #include <scheme/IDiscreteFunction.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp> #include <scheme/IDiscreteFunctionDescriptor.hpp>
...@@ -318,10 +319,27 @@ template <size_t Dimension> ...@@ -318,10 +319,27 @@ template <size_t Dimension>
std::shared_ptr<const IDiscreteFunction> std::shared_ptr<const IDiscreteFunction>
dot(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) dot(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g)
{ {
Assert((f->dataType() == ASTNodeDataType::vector_t and f->descriptor().type() == DiscreteFunctionType::P0) and Assert(((f->descriptor().type() == DiscreteFunctionType::P0Vector) and
(g->descriptor().type() == DiscreteFunctionType::P0Vector)) or
((f->dataType() == ASTNodeDataType::vector_t and f->descriptor().type() == DiscreteFunctionType::P0) and
(g->dataType() == ASTNodeDataType::vector_t and g->descriptor().type() == DiscreteFunctionType::P0) and (g->dataType() == ASTNodeDataType::vector_t and g->descriptor().type() == DiscreteFunctionType::P0) and
(f->dataType().dimension() == g->dataType().dimension())); (f->dataType().dimension() == g->dataType().dimension())));
if ((f->descriptor().type() == DiscreteFunctionType::P0Vector) and
(g->descriptor().type() == DiscreteFunctionType::P0Vector)) {
using DiscreteFunctionResultType = DiscreteFunctionP0<Dimension, double>;
using DiscreteFunctionType = DiscreteFunctionP0Vector<Dimension, double>;
const DiscreteFunctionType& f_vector = dynamic_cast<const DiscreteFunctionType&>(*f);
const DiscreteFunctionType& g_vector = dynamic_cast<const DiscreteFunctionType&>(*g);
if (f_vector.size() != g_vector.size()) {
throw NormalError("operands have different dimension");
} else {
return std::make_shared<const DiscreteFunctionResultType>(dot(f_vector, g_vector));
}
} else {
using DiscreteFunctionResultType = DiscreteFunctionP0<Dimension, double>; using DiscreteFunctionResultType = DiscreteFunctionP0<Dimension, double>;
switch (f->dataType().dimension()) { switch (f->dataType().dimension()) {
...@@ -348,13 +366,16 @@ dot(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<con ...@@ -348,13 +366,16 @@ dot(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<con
} }
} }
} }
}
std::shared_ptr<const IDiscreteFunction> std::shared_ptr<const IDiscreteFunction>
dot(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) dot(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g)
{ {
if ((f->dataType() == ASTNodeDataType::vector_t and f->descriptor().type() == DiscreteFunctionType::P0) and if (((f->descriptor().type() == DiscreteFunctionType::P0Vector) and
(g->descriptor().type() == DiscreteFunctionType::P0Vector)) or
((f->dataType() == ASTNodeDataType::vector_t and f->descriptor().type() == DiscreteFunctionType::P0) and
(g->dataType() == ASTNodeDataType::vector_t and g->descriptor().type() == DiscreteFunctionType::P0) and (g->dataType() == ASTNodeDataType::vector_t and g->descriptor().type() == DiscreteFunctionType::P0) and
(f->dataType().dimension() == g->dataType().dimension())) { (f->dataType().dimension() == g->dataType().dimension()))) {
std::shared_ptr mesh = getCommonMesh({f, g}); std::shared_ptr mesh = getCommonMesh({f, g});
if (mesh.use_count() == 0) { if (mesh.use_count() == 0) {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <scheme/IDiscreteFunction.hpp> #include <scheme/IDiscreteFunction.hpp>
#include <algebra/Vector.hpp>
#include <mesh/Connectivity.hpp> #include <mesh/Connectivity.hpp>
#include <mesh/ItemArray.hpp> #include <mesh/ItemArray.hpp>
#include <mesh/Mesh.hpp> #include <mesh/Mesh.hpp>
...@@ -194,6 +195,19 @@ class DiscreteFunctionP0Vector : public IDiscreteFunction ...@@ -194,6 +195,19 @@ class DiscreteFunctionP0Vector : public IDiscreteFunction
return product; return product;
} }
PUGS_INLINE friend DiscreteFunctionP0<Dimension, double>
dot(const DiscreteFunctionP0Vector& f, const DiscreteFunctionP0Vector& g)
{
Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
Assert(f.size() == g.size());
DiscreteFunctionP0<Dimension, double> result{f.m_mesh};
parallel_for(
f.m_mesh->numberOfCells(),
PUGS_LAMBDA(CellId cell_id) { result[cell_id] = dot(Vector{f[cell_id]}, Vector{g[cell_id]}); });
return result;
}
DiscreteFunctionP0Vector(const std::shared_ptr<const MeshType>& mesh, size_t size) DiscreteFunctionP0Vector(const std::shared_ptr<const MeshType>& mesh, size_t size)
: m_mesh{mesh}, m_cell_arrays{mesh->connectivity(), size} : m_mesh{mesh}, m_cell_arrays{mesh->connectivity(), size}
{} {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment