From 86b7fed30e41e755f2fcb7297cf04824e1b84bde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com> Date: Sun, 21 Mar 2021 19:14:16 +0100 Subject: [PATCH] Add binary operators for Vh variables This allows standard operations that are already available for basic scalar types B,N,Z,R and R^d, R^dxd For instance if a:R, v_d:R^d, M_d: R^dxd and ah:Vh(R), vh_d:Vh(R^d), Mh_d: Vh(R^dxd), one can write : a*v_1, M_2*v2, M2*vh_2, a*vh_3, M_2*Mh_2, Mh_2*M_2, a*vh_2, a*Mh_3,... Invalid constructions are for instance v_1*a, v_1*M_1, vh_2*M_2, vh_2*Mh_2, a_h * a,... --- src/language/modules/SchemeModule.cpp | 76 ++ .../utils/BinaryOperatorProcessorBuilder.hpp | 113 +++ src/language/utils/CMakeLists.txt | 1 + .../EmbeddedIDiscreteFunctionOperators.cpp | 696 ++++++++++++++++++ .../EmbeddedIDiscreteFunctionOperators.hpp | 52 ++ src/scheme/DiscreteFunctionP0.hpp | 132 +++- 6 files changed, 1068 insertions(+), 2 deletions(-) create mode 100644 src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp create mode 100644 src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp index f0748f270..4d1a86632 100644 --- a/src/language/modules/SchemeModule.cpp +++ b/src/language/modules/SchemeModule.cpp @@ -1,6 +1,9 @@ #include <language/modules/SchemeModule.hpp> +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/TypeDescriptor.hpp> #include <mesh/Mesh.hpp> #include <scheme/AcousticSolver.hpp> @@ -284,4 +287,77 @@ SchemeModule::SchemeModule() void SchemeModule::registerOperators() const { + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::plus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::plus_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::minus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::divide_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::divide_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + bool, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + int64_t, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + uint64_t, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + double, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<1>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<2>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<3>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<1>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<2>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<3>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<1>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<2>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<3>>>()); } diff --git a/src/language/utils/BinaryOperatorProcessorBuilder.hpp b/src/language/utils/BinaryOperatorProcessorBuilder.hpp index b8cb572d1..eb07b35d0 100644 --- a/src/language/utils/BinaryOperatorProcessorBuilder.hpp +++ b/src/language/utils/BinaryOperatorProcessorBuilder.hpp @@ -6,9 +6,13 @@ #include <language/node_processor/BinaryExpressionProcessor.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/IBinaryOperatorProcessorBuilder.hpp> +#include <language/utils/ParseError.hpp> #include <type_traits> +template <typename DataType> +class DataHandler; + template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT> class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder { @@ -40,4 +44,113 @@ class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuil } }; +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, std::shared_ptr<B_DataT>> + final : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + const auto& embedded_a = std::get<EmbeddedData>(a); + const auto& embedded_b = std::get<EmbeddedData>(b); + + std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); + + std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_ptr))); + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, A_DataT, std::shared_ptr<B_DataT>> final + : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + if constexpr ((std::is_arithmetic_v<A_DataT>) or (is_tiny_vector_v<A_DataT>) or (is_tiny_matrix_v<A_DataT>)) { + const auto& a_value = std::get<A_DataT>(a); + const auto& embedded_b = std::get<EmbeddedData>(b); + + std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_value, b_ptr))); + } else { + static_assert(std::is_arithmetic_v<A_DataT>, "invalid left hand side type"); + } + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, B_DataT> final + : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + if constexpr ((std::is_arithmetic_v<B_DataT>) or (is_tiny_matrix_v<B_DataT>) or (is_tiny_vector_v<B_DataT>)) { + const auto& embedded_a = std::get<EmbeddedData>(a); + const auto& b_value = std::get<B_DataT>(b); + + std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_value))); + } else { + static_assert(std::is_arithmetic_v<B_DataT>, "invalid right hand side type"); + } + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + #endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 32a75c66f..80640af6e 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(PugsLanguageUtils BinaryOperatorRegisterForZ.cpp DataVariant.cpp EmbeddedData.cpp + EmbeddedIDiscreteFunctionOperators.cpp FunctionSymbolId.cpp IncDecOperatorRegisterForN.cpp IncDecOperatorRegisterForR.cpp diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp new file mode 100644 index 000000000..db7a609b9 --- /dev/null +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp @@ -0,0 +1,696 @@ +#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> + +#include <language/node_processor/BinaryExpressionProcessor.hpp> +#include <scheme/DiscreteFunctionP0.hpp> +#include <scheme/IDiscreteFunction.hpp> +#include <utils/Exceptions.hpp> + +template <typename T> +PUGS_INLINE std::string +name(const T&) +{ + return dataTypeName(ast_node_data_type_from<T>); +} + +template <> +PUGS_INLINE std::string +name(const IDiscreteFunction& f) +{ + return "Vh(" + dataTypeName(f.dataType()) + ")"; +} + +template <> +PUGS_INLINE std::string +name(const std::shared_ptr<const IDiscreteFunction>& f) +{ + return "Vh(" + dataTypeName(f->dataType()) + ")"; +} + +template <typename LHS_T, typename RHS_T> +PUGS_INLINE std::string +invalid_operands(const LHS_T& f, const RHS_T& g) +{ + std::ostringstream os; + os << "undefined binary operator\n"; + os << "note: incompatible operand types " << name(f) << " and " << name(g); + return os.str(); +} + +template <typename BinOperatorT, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const DiscreteFunctionT& lhs, const DiscreteFunctionT& rhs) +{ + Assert(lhs.mesh() == rhs.mesh()); + using data_type = typename DiscreteFunctionT::data_type; + if constexpr ((std::is_same_v<language::multiply_op, BinOperatorT> and is_tiny_vector_v<data_type>) or + (std::is_same_v<language::divide_op, BinOperatorT> and not std::is_arithmetic_v<data_type>)) { + throw NormalError(invalid_operands(lhs, rhs)); + } else { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs)); + } +} + +template <typename BinOperatorT, size_t Dimension> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(f->mesh() == g->mesh()); + Assert(f->dataType() == g->dataType()); + + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case ASTNodeDataType::vector_t: { + switch (f->dataType().dimension()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + default: { + throw NormalError(invalid_operands(f, g)); + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + default: { + throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } + default: { + throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } +} + +template <typename BinOperatorT> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->mesh() != g->mesh()) { + throw NormalError("discrete functions defined on different meshes"); + } + if (f->dataType() != g->dataType()) { + throw NormalError(invalid_operands(f, g)); + } + + switch (f->mesh()->dimension()) { + case 1: { + return innerCompositionLaw<BinOperatorT, 1>(f, g); + } + case 2: { + return innerCompositionLaw<BinOperatorT, 2>(f, g); + } + case 3: { + return innerCompositionLaw<BinOperatorT, 3>(f, g); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +template <typename BinOperatorT, typename LeftDiscreteFunctionT, typename RightDiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const LeftDiscreteFunctionT& lhs, const RightDiscreteFunctionT& rhs) +{ + Assert(lhs.mesh() == rhs.mesh()); + using lhs_data_type = typename LeftDiscreteFunctionT::data_type; + using rhs_data_type = typename RightDiscreteFunctionT::data_type; + + static_assert(not std::is_same_v<rhs_data_type, lhs_data_type>, + "use innerCompositionLaw when data types are the same"); + + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs)); +} + +template <typename BinOperatorT, size_t Dimension, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const DiscreteFunctionT& fh, const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(fh.mesh() == g->mesh()); + Assert(fh.dataType() != g->dataType()); + using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + + switch (g->dataType()) { + case ASTNodeDataType::double_t: { + if constexpr (not std::is_same_v<lhs_data_type, double>) { + if constexpr (not is_tiny_matrix_v<lhs_data_type>) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } else { + throw UnexpectedError("should have called innerCompositionLaw"); + } + } + case ASTNodeDataType::vector_t: { + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + switch (g->dataType().dimension()) { + case 1: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<1>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case 2: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<2>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case 3: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<3>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case ASTNodeDataType::matrix_t: { + Assert(g->dataType().nbRows() == g->dataType().nbColumns()); + if constexpr (std::is_same_v<lhs_data_type, double> and std::is_same_v<language::multiply_op, BinOperatorT>) { + switch (g->dataType().nbRows()) { + case 1: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + case 2: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + case 3: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } +} + +template <typename BinOperatorT, size_t Dimension> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(f->mesh() == g->mesh()); + Assert(f->dataType() != g->dataType()); + + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + default: { + throw NormalError(invalid_operands(f, g)); + } + } +} + +template <typename BinOperatorT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->mesh() != g->mesh()) { + throw NormalError("functions defined on different meshes"); + } + + Assert(f->dataType() != g->dataType(), "should call inner composition instead"); + + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperation<BinOperatorT, 1>(f, g); + } + case 2: { + return applyBinaryOperation<BinOperatorT, 2>(f, g); + } + case 3: { + return applyBinaryOperation<BinOperatorT, 3>(f, g); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +std::shared_ptr<const IDiscreteFunction> +operator+(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + return innerCompositionLaw<language::plus_op>(f, g); +} + +std::shared_ptr<const IDiscreteFunction> +operator-(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + return innerCompositionLaw<language::minus_op>(f, g); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->dataType() == g->dataType()) { + return innerCompositionLaw<language::multiply_op>(f, g); + } else { + return applyBinaryOperation<language::multiply_op>(f, g); + } +} + +std::shared_ptr<const IDiscreteFunction> +operator/(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->dataType() == g->dataType()) { + return innerCompositionLaw<language::divide_op>(f, g); + } else { + return applyBinaryOperation<language::divide_op>(f, g); + } +} + +template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const DiscreteFunctionT& f) +{ + using lhs_data_type = std::decay_t<DataType>; + using rhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + if constexpr (std::is_same_v<lhs_data_type, double>) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f)); + } else if constexpr (is_tiny_matrix_v<lhs_data_type> and + (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f)); + } else { + throw NormalError(invalid_operands(a, f)); + } + } else { + throw NormalError(invalid_operands(a, f)); + } +} + +template <typename BinOperatorT, size_t Dimension, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->dataType()) { + case ASTNodeDataType::bool_t: + case ASTNodeDataType::unsigned_int_t: + case ASTNodeDataType::int_t: + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case ASTNodeDataType::vector_t: { + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().dimension()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().dimension()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().nbRows()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + default: { + throw NormalError(invalid_operands(a, f)); + } + } +} + +template <typename BinOperatorT, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 1>(a, f); + } + case 2: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 2>(a, f); + } + case 3: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 3>(a, f); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const DiscreteFunctionT& f, const DataType& a) +{ + using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + using rhs_data_type = std::decay_t<DataType>; + + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + if constexpr (is_tiny_matrix_v<lhs_data_type> and is_tiny_matrix_v<rhs_data_type>) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a)); + } else if constexpr (std::is_same_v<lhs_data_type, double> and + (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a)); + } else { + throw NormalError(invalid_operands(f, a)); + } + } else { + throw NormalError(invalid_operands(f, a)); + } +} + +template <typename BinOperatorT, size_t Dimension, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a) +{ + switch (f->dataType()) { + case ASTNodeDataType::bool_t: + case ASTNodeDataType::unsigned_int_t: + case ASTNodeDataType::int_t: + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().nbRows()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + default: { + throw NormalError(invalid_operands(f, a)); + } + } +} + +template <typename BinOperatorT, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a) +{ + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 1>(f, a); + } + case 2: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 2>(f, a); + } + case 3: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 3>(f, a); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const double& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(a, f); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<1>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<2>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<3>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<1>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<2>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<3>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<1>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<2>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<3>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp new file mode 100644 index 000000000..796982800 --- /dev/null +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp @@ -0,0 +1,52 @@ +#ifndef EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP +#define EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP + +#include <algebra/TinyMatrix.hpp> +#include <algebra/TinyVector.hpp> + +#include <memory> + +class IDiscreteFunction; + +std::shared_ptr<const IDiscreteFunction> operator+(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator-(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator/(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const double&, const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<1>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<2>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<3>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<1>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<2>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<3>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<1>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<2>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<3>&); + +#endif // EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP diff --git a/src/scheme/DiscreteFunctionP0.hpp b/src/scheme/DiscreteFunctionP0.hpp index c51e8e09d..53037ea72 100644 --- a/src/scheme/DiscreteFunctionP0.hpp +++ b/src/scheme/DiscreteFunctionP0.hpp @@ -13,9 +13,11 @@ template <size_t Dimension, typename DataType> class DiscreteFunctionP0 : public IDiscreteFunction { - private: - using MeshType = Mesh<Connectivity<Dimension>>; + public: + using data_type = DataType; + using MeshType = Mesh<Connectivity<Dimension>>; + private: std::shared_ptr<const MeshType> m_mesh; CellValue<DataType> m_cell_values; @@ -53,6 +55,72 @@ class DiscreteFunctionP0 : public IDiscreteFunction return m_cell_values[cell_id]; } + friend DiscreteFunctionP0 + operator+(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 sum(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { sum[cell_id] = f[cell_id] + g[cell_id]; }); + return sum; + } + + friend DiscreteFunctionP0 + operator-(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 difference(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { difference[cell_id] = f[cell_id] - g[cell_id]; }); + return difference; + } + + friend DiscreteFunctionP0 + operator*(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; }); + return product; + } + + template <typename DataType2T> + friend DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> + operator*(const DiscreteFunctionP0<Dimension, DataType2T>& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; }); + return product; + } + + friend DiscreteFunctionP0 + operator*(const double& a, const DiscreteFunctionP0& f) + { + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = a * f[cell_id]; }); + return product; + } + + friend DiscreteFunctionP0 + operator/(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 ratio(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { ratio[cell_id] = f[cell_id] / g[cell_id]; }); + return ratio; + } + DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const FunctionSymbolId& function_id) : m_mesh(mesh) { using MeshDataType = MeshData<Dimension>; @@ -80,4 +148,64 @@ class DiscreteFunctionP0 : public IDiscreteFunction ~DiscreteFunctionP0() = default; }; +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> +operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>& f) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f, const TinyMatrix<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyMatrix<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyVector<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + #endif // DISCRETE_FUNCTION_P0_HPP -- GitLab