Skip to content
Snippets Groups Projects
Commit 86b7fed3 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add binary operators for Vh variables

This allows standard operations that are already available for basic
scalar types B,N,Z,R and R^d, R^dxd

For instance if a:R, v_d:R^d, M_d: R^dxd and ah:Vh(R), vh_d:Vh(R^d),
Mh_d: Vh(R^dxd), one can write :
a*v_1, M_2*v2, M2*vh_2, a*vh_3, M_2*Mh_2, Mh_2*M_2, a*vh_2, a*Mh_3,...

Invalid constructions are for instance
v_1*a, v_1*M_1, vh_2*M_2, vh_2*Mh_2, a_h * a,...
parent f5fa77f1
Branches
Tags
1 merge request!81Feature/discrete function algebra
#include <language/modules/SchemeModule.hpp> #include <language/modules/SchemeModule.hpp>
#include <language/utils/BinaryOperatorProcessorBuilder.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp>
#include <language/utils/OperatorRepository.hpp>
#include <language/utils/TypeDescriptor.hpp> #include <language/utils/TypeDescriptor.hpp>
#include <mesh/Mesh.hpp> #include <mesh/Mesh.hpp>
#include <scheme/AcousticSolver.hpp> #include <scheme/AcousticSolver.hpp>
...@@ -284,4 +287,77 @@ SchemeModule::SchemeModule() ...@@ -284,4 +287,77 @@ SchemeModule::SchemeModule()
void void
SchemeModule::registerOperators() const SchemeModule::registerOperators() const
{ {
OperatorRepository& repository = OperatorRepository::instance();
repository.addBinaryOperator<language::plus_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::plus_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::minus_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::divide_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::divide_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
bool, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
int64_t, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
uint64_t, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
double, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
TinyMatrix<1>, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
TinyMatrix<2>, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
TinyMatrix<3>, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyVector<1>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyVector<2>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyVector<3>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyMatrix<1>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyMatrix<2>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyMatrix<3>>>());
} }
...@@ -6,9 +6,13 @@ ...@@ -6,9 +6,13 @@
#include <language/node_processor/BinaryExpressionProcessor.hpp> #include <language/node_processor/BinaryExpressionProcessor.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/IBinaryOperatorProcessorBuilder.hpp> #include <language/utils/IBinaryOperatorProcessorBuilder.hpp>
#include <language/utils/ParseError.hpp>
#include <type_traits> #include <type_traits>
template <typename DataType>
class DataHandler;
template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT> template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT>
class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder
{ {
...@@ -40,4 +44,113 @@ class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuil ...@@ -40,4 +44,113 @@ class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuil
} }
}; };
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, std::shared_ptr<B_DataT>>
final : public INodeProcessor
{
private:
ASTNode& m_node;
PUGS_INLINE DataVariant
_eval(const DataVariant& a, const DataVariant& b)
{
const auto& embedded_a = std::get<EmbeddedData>(a);
const auto& embedded_b = std::get<EmbeddedData>(b);
std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr();
std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr();
return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_ptr)));
}
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
try {
return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy));
}
catch (const NormalError& error) {
throw ParseError(error.what(), m_node.begin());
}
}
BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, A_DataT, std::shared_ptr<B_DataT>> final
: public INodeProcessor
{
private:
ASTNode& m_node;
PUGS_INLINE DataVariant
_eval(const DataVariant& a, const DataVariant& b)
{
if constexpr ((std::is_arithmetic_v<A_DataT>) or (is_tiny_vector_v<A_DataT>) or (is_tiny_matrix_v<A_DataT>)) {
const auto& a_value = std::get<A_DataT>(a);
const auto& embedded_b = std::get<EmbeddedData>(b);
std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr();
return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_value, b_ptr)));
} else {
static_assert(std::is_arithmetic_v<A_DataT>, "invalid left hand side type");
}
}
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
try {
return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy));
}
catch (const NormalError& error) {
throw ParseError(error.what(), m_node.begin());
}
}
BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, B_DataT> final
: public INodeProcessor
{
private:
ASTNode& m_node;
PUGS_INLINE DataVariant
_eval(const DataVariant& a, const DataVariant& b)
{
if constexpr ((std::is_arithmetic_v<B_DataT>) or (is_tiny_matrix_v<B_DataT>) or (is_tiny_vector_v<B_DataT>)) {
const auto& embedded_a = std::get<EmbeddedData>(a);
const auto& b_value = std::get<B_DataT>(b);
std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr();
return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_value)));
} else {
static_assert(std::is_arithmetic_v<B_DataT>, "invalid right hand side type");
}
}
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
try {
return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy));
}
catch (const NormalError& error) {
throw ParseError(error.what(), m_node.begin());
}
}
BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};
#endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP #endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP
...@@ -22,6 +22,7 @@ add_library(PugsLanguageUtils ...@@ -22,6 +22,7 @@ add_library(PugsLanguageUtils
BinaryOperatorRegisterForZ.cpp BinaryOperatorRegisterForZ.cpp
DataVariant.cpp DataVariant.cpp
EmbeddedData.cpp EmbeddedData.cpp
EmbeddedIDiscreteFunctionOperators.cpp
FunctionSymbolId.cpp FunctionSymbolId.cpp
IncDecOperatorRegisterForN.cpp IncDecOperatorRegisterForN.cpp
IncDecOperatorRegisterForR.cpp IncDecOperatorRegisterForR.cpp
......
This diff is collapsed.
#ifndef EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP
#define EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP
#include <algebra/TinyMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <memory>
class IDiscreteFunction;
std::shared_ptr<const IDiscreteFunction> operator+(const std::shared_ptr<const IDiscreteFunction>&,
const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator-(const std::shared_ptr<const IDiscreteFunction>&,
const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&,
const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator/(const std::shared_ptr<const IDiscreteFunction>&,
const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator*(const double&, const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<1>&,
const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<2>&,
const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<3>&,
const std::shared_ptr<const IDiscreteFunction>&);
std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&,
const TinyVector<1>&);
std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&,
const TinyVector<2>&);
std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&,
const TinyVector<3>&);
std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&,
const TinyMatrix<1>&);
std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&,
const TinyMatrix<2>&);
std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&,
const TinyMatrix<3>&);
#endif // EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
template <size_t Dimension, typename DataType> template <size_t Dimension, typename DataType>
class DiscreteFunctionP0 : public IDiscreteFunction class DiscreteFunctionP0 : public IDiscreteFunction
{ {
private: public:
using data_type = DataType;
using MeshType = Mesh<Connectivity<Dimension>>; using MeshType = Mesh<Connectivity<Dimension>>;
private:
std::shared_ptr<const MeshType> m_mesh; std::shared_ptr<const MeshType> m_mesh;
CellValue<DataType> m_cell_values; CellValue<DataType> m_cell_values;
...@@ -53,6 +55,72 @@ class DiscreteFunctionP0 : public IDiscreteFunction ...@@ -53,6 +55,72 @@ class DiscreteFunctionP0 : public IDiscreteFunction
return m_cell_values[cell_id]; return m_cell_values[cell_id];
} }
friend DiscreteFunctionP0
operator+(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
{
Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0 sum(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { sum[cell_id] = f[cell_id] + g[cell_id]; });
return sum;
}
friend DiscreteFunctionP0
operator-(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
{
Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0 difference(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { difference[cell_id] = f[cell_id] - g[cell_id]; });
return difference;
}
friend DiscreteFunctionP0
operator*(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
{
Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0 product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; });
return product;
}
template <typename DataType2T>
friend DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})>
operator*(const DiscreteFunctionP0<Dimension, DataType2T>& f, const DiscreteFunctionP0& g)
{
Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; });
return product;
}
friend DiscreteFunctionP0
operator*(const double& a, const DiscreteFunctionP0& f)
{
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0 product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = a * f[cell_id]; });
return product;
}
friend DiscreteFunctionP0
operator/(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
{
Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0 ratio(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { ratio[cell_id] = f[cell_id] / g[cell_id]; });
return ratio;
}
DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const FunctionSymbolId& function_id) : m_mesh(mesh) DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const FunctionSymbolId& function_id) : m_mesh(mesh)
{ {
using MeshDataType = MeshData<Dimension>; using MeshDataType = MeshData<Dimension>;
...@@ -80,4 +148,64 @@ class DiscreteFunctionP0 : public IDiscreteFunction ...@@ -80,4 +148,64 @@ class DiscreteFunctionP0 : public IDiscreteFunction
~DiscreteFunctionP0() = default; ~DiscreteFunctionP0() = default;
}; };
template <size_t Dimension, size_t ValueDimension>
DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>
operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>& f)
{
using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType;
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; });
return product;
}
template <size_t Dimension, size_t ValueDimension>
DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>
operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f)
{
using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType;
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; });
return product;
}
template <size_t Dimension, size_t ValueDimension>
DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>
operator*(const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f, const TinyMatrix<ValueDimension>& A)
{
using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType;
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; });
return product;
}
template <size_t Dimension, size_t ValueDimension>
DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>
operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyMatrix<ValueDimension>& A)
{
using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType;
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; });
return product;
}
template <size_t Dimension, size_t ValueDimension>
DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>
operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyVector<ValueDimension>& A)
{
using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType;
std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh);
parallel_for(
mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; });
return product;
}
#endif // DISCRETE_FUNCTION_P0_HPP #endif // DISCRETE_FUNCTION_P0_HPP
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment