diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp index f0748f27040df92ed0237c510db4a6ed6015de28..4d1a866326a46c63f71b9b45b5188d7aa11746ca 100644 --- a/src/language/modules/SchemeModule.cpp +++ b/src/language/modules/SchemeModule.cpp @@ -1,6 +1,9 @@ #include <language/modules/SchemeModule.hpp> +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/TypeDescriptor.hpp> #include <mesh/Mesh.hpp> #include <scheme/AcousticSolver.hpp> @@ -284,4 +287,77 @@ SchemeModule::SchemeModule() void SchemeModule::registerOperators() const { + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::plus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::plus_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::minus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::divide_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::divide_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + bool, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + int64_t, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + uint64_t, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + double, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<1>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<2>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<3>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<1>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<2>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<3>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<1>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<2>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<3>>>()); } diff --git a/src/language/utils/BinaryOperatorProcessorBuilder.hpp b/src/language/utils/BinaryOperatorProcessorBuilder.hpp index b8cb572d1e79c1882ed0288d79a4b73c6c130e5c..eb07b35d028f3eaea3f59537753dfd936a185837 100644 --- a/src/language/utils/BinaryOperatorProcessorBuilder.hpp +++ b/src/language/utils/BinaryOperatorProcessorBuilder.hpp @@ -6,9 +6,13 @@ #include <language/node_processor/BinaryExpressionProcessor.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/IBinaryOperatorProcessorBuilder.hpp> +#include <language/utils/ParseError.hpp> #include <type_traits> +template <typename DataType> +class DataHandler; + template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT> class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder { @@ -40,4 +44,113 @@ class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuil } }; +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, std::shared_ptr<B_DataT>> + final : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + const auto& embedded_a = std::get<EmbeddedData>(a); + const auto& embedded_b = std::get<EmbeddedData>(b); + + std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); + + std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_ptr))); + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, A_DataT, std::shared_ptr<B_DataT>> final + : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + if constexpr ((std::is_arithmetic_v<A_DataT>) or (is_tiny_vector_v<A_DataT>) or (is_tiny_matrix_v<A_DataT>)) { + const auto& a_value = std::get<A_DataT>(a); + const auto& embedded_b = std::get<EmbeddedData>(b); + + std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_value, b_ptr))); + } else { + static_assert(std::is_arithmetic_v<A_DataT>, "invalid left hand side type"); + } + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, B_DataT> final + : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + if constexpr ((std::is_arithmetic_v<B_DataT>) or (is_tiny_matrix_v<B_DataT>) or (is_tiny_vector_v<B_DataT>)) { + const auto& embedded_a = std::get<EmbeddedData>(a); + const auto& b_value = std::get<B_DataT>(b); + + std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_value))); + } else { + static_assert(std::is_arithmetic_v<B_DataT>, "invalid right hand side type"); + } + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + #endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 32a75c66f0dd73cc10c09ca2c6aecf640962c45c..80640af6e39ee7febd53562c4c89bc1d2bd9b8c3 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(PugsLanguageUtils BinaryOperatorRegisterForZ.cpp DataVariant.cpp EmbeddedData.cpp + EmbeddedIDiscreteFunctionOperators.cpp FunctionSymbolId.cpp IncDecOperatorRegisterForN.cpp IncDecOperatorRegisterForR.cpp diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db7a609b9449d27380c2030c75e938554e6a095b --- /dev/null +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp @@ -0,0 +1,696 @@ +#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> + +#include <language/node_processor/BinaryExpressionProcessor.hpp> +#include <scheme/DiscreteFunctionP0.hpp> +#include <scheme/IDiscreteFunction.hpp> +#include <utils/Exceptions.hpp> + +template <typename T> +PUGS_INLINE std::string +name(const T&) +{ + return dataTypeName(ast_node_data_type_from<T>); +} + +template <> +PUGS_INLINE std::string +name(const IDiscreteFunction& f) +{ + return "Vh(" + dataTypeName(f.dataType()) + ")"; +} + +template <> +PUGS_INLINE std::string +name(const std::shared_ptr<const IDiscreteFunction>& f) +{ + return "Vh(" + dataTypeName(f->dataType()) + ")"; +} + +template <typename LHS_T, typename RHS_T> +PUGS_INLINE std::string +invalid_operands(const LHS_T& f, const RHS_T& g) +{ + std::ostringstream os; + os << "undefined binary operator\n"; + os << "note: incompatible operand types " << name(f) << " and " << name(g); + return os.str(); +} + +template <typename BinOperatorT, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const DiscreteFunctionT& lhs, const DiscreteFunctionT& rhs) +{ + Assert(lhs.mesh() == rhs.mesh()); + using data_type = typename DiscreteFunctionT::data_type; + if constexpr ((std::is_same_v<language::multiply_op, BinOperatorT> and is_tiny_vector_v<data_type>) or + (std::is_same_v<language::divide_op, BinOperatorT> and not std::is_arithmetic_v<data_type>)) { + throw NormalError(invalid_operands(lhs, rhs)); + } else { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs)); + } +} + +template <typename BinOperatorT, size_t Dimension> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(f->mesh() == g->mesh()); + Assert(f->dataType() == g->dataType()); + + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case ASTNodeDataType::vector_t: { + switch (f->dataType().dimension()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + default: { + throw NormalError(invalid_operands(f, g)); + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + default: { + throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } + default: { + throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } +} + +template <typename BinOperatorT> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->mesh() != g->mesh()) { + throw NormalError("discrete functions defined on different meshes"); + } + if (f->dataType() != g->dataType()) { + throw NormalError(invalid_operands(f, g)); + } + + switch (f->mesh()->dimension()) { + case 1: { + return innerCompositionLaw<BinOperatorT, 1>(f, g); + } + case 2: { + return innerCompositionLaw<BinOperatorT, 2>(f, g); + } + case 3: { + return innerCompositionLaw<BinOperatorT, 3>(f, g); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +template <typename BinOperatorT, typename LeftDiscreteFunctionT, typename RightDiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const LeftDiscreteFunctionT& lhs, const RightDiscreteFunctionT& rhs) +{ + Assert(lhs.mesh() == rhs.mesh()); + using lhs_data_type = typename LeftDiscreteFunctionT::data_type; + using rhs_data_type = typename RightDiscreteFunctionT::data_type; + + static_assert(not std::is_same_v<rhs_data_type, lhs_data_type>, + "use innerCompositionLaw when data types are the same"); + + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs)); +} + +template <typename BinOperatorT, size_t Dimension, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const DiscreteFunctionT& fh, const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(fh.mesh() == g->mesh()); + Assert(fh.dataType() != g->dataType()); + using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + + switch (g->dataType()) { + case ASTNodeDataType::double_t: { + if constexpr (not std::is_same_v<lhs_data_type, double>) { + if constexpr (not is_tiny_matrix_v<lhs_data_type>) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } else { + throw UnexpectedError("should have called innerCompositionLaw"); + } + } + case ASTNodeDataType::vector_t: { + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + switch (g->dataType().dimension()) { + case 1: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<1>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case 2: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<2>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case 3: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<3>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case ASTNodeDataType::matrix_t: { + Assert(g->dataType().nbRows() == g->dataType().nbColumns()); + if constexpr (std::is_same_v<lhs_data_type, double> and std::is_same_v<language::multiply_op, BinOperatorT>) { + switch (g->dataType().nbRows()) { + case 1: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + case 2: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + case 3: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } +} + +template <typename BinOperatorT, size_t Dimension> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(f->mesh() == g->mesh()); + Assert(f->dataType() != g->dataType()); + + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + default: { + throw NormalError(invalid_operands(f, g)); + } + } +} + +template <typename BinOperatorT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->mesh() != g->mesh()) { + throw NormalError("functions defined on different meshes"); + } + + Assert(f->dataType() != g->dataType(), "should call inner composition instead"); + + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperation<BinOperatorT, 1>(f, g); + } + case 2: { + return applyBinaryOperation<BinOperatorT, 2>(f, g); + } + case 3: { + return applyBinaryOperation<BinOperatorT, 3>(f, g); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +std::shared_ptr<const IDiscreteFunction> +operator+(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + return innerCompositionLaw<language::plus_op>(f, g); +} + +std::shared_ptr<const IDiscreteFunction> +operator-(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + return innerCompositionLaw<language::minus_op>(f, g); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->dataType() == g->dataType()) { + return innerCompositionLaw<language::multiply_op>(f, g); + } else { + return applyBinaryOperation<language::multiply_op>(f, g); + } +} + +std::shared_ptr<const IDiscreteFunction> +operator/(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->dataType() == g->dataType()) { + return innerCompositionLaw<language::divide_op>(f, g); + } else { + return applyBinaryOperation<language::divide_op>(f, g); + } +} + +template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const DiscreteFunctionT& f) +{ + using lhs_data_type = std::decay_t<DataType>; + using rhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + if constexpr (std::is_same_v<lhs_data_type, double>) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f)); + } else if constexpr (is_tiny_matrix_v<lhs_data_type> and + (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f)); + } else { + throw NormalError(invalid_operands(a, f)); + } + } else { + throw NormalError(invalid_operands(a, f)); + } +} + +template <typename BinOperatorT, size_t Dimension, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->dataType()) { + case ASTNodeDataType::bool_t: + case ASTNodeDataType::unsigned_int_t: + case ASTNodeDataType::int_t: + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case ASTNodeDataType::vector_t: { + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().dimension()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().dimension()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().nbRows()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + default: { + throw NormalError(invalid_operands(a, f)); + } + } +} + +template <typename BinOperatorT, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 1>(a, f); + } + case 2: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 2>(a, f); + } + case 3: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 3>(a, f); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const DiscreteFunctionT& f, const DataType& a) +{ + using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + using rhs_data_type = std::decay_t<DataType>; + + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + if constexpr (is_tiny_matrix_v<lhs_data_type> and is_tiny_matrix_v<rhs_data_type>) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a)); + } else if constexpr (std::is_same_v<lhs_data_type, double> and + (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a)); + } else { + throw NormalError(invalid_operands(f, a)); + } + } else { + throw NormalError(invalid_operands(f, a)); + } +} + +template <typename BinOperatorT, size_t Dimension, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a) +{ + switch (f->dataType()) { + case ASTNodeDataType::bool_t: + case ASTNodeDataType::unsigned_int_t: + case ASTNodeDataType::int_t: + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().nbRows()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + default: { + throw NormalError(invalid_operands(f, a)); + } + } +} + +template <typename BinOperatorT, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a) +{ + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 1>(f, a); + } + case 2: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 2>(f, a); + } + case 3: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 3>(f, a); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const double& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(a, f); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<1>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<2>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<3>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<1>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<2>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<3>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<1>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<2>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<3>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7969828000c4440ea9aa9d36014ee18d55d643c1 --- /dev/null +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp @@ -0,0 +1,52 @@ +#ifndef EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP +#define EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP + +#include <algebra/TinyMatrix.hpp> +#include <algebra/TinyVector.hpp> + +#include <memory> + +class IDiscreteFunction; + +std::shared_ptr<const IDiscreteFunction> operator+(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator-(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator/(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const double&, const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<1>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<2>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<3>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<1>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<2>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<3>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<1>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<2>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<3>&); + +#endif // EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP diff --git a/src/scheme/DiscreteFunctionP0.hpp b/src/scheme/DiscreteFunctionP0.hpp index c51e8e09dbb78dc6f40f4fbb18eb737ddb7e1a30..53037ea72b28813cd2b2e619574f8faefa7484f0 100644 --- a/src/scheme/DiscreteFunctionP0.hpp +++ b/src/scheme/DiscreteFunctionP0.hpp @@ -13,9 +13,11 @@ template <size_t Dimension, typename DataType> class DiscreteFunctionP0 : public IDiscreteFunction { - private: - using MeshType = Mesh<Connectivity<Dimension>>; + public: + using data_type = DataType; + using MeshType = Mesh<Connectivity<Dimension>>; + private: std::shared_ptr<const MeshType> m_mesh; CellValue<DataType> m_cell_values; @@ -53,6 +55,72 @@ class DiscreteFunctionP0 : public IDiscreteFunction return m_cell_values[cell_id]; } + friend DiscreteFunctionP0 + operator+(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 sum(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { sum[cell_id] = f[cell_id] + g[cell_id]; }); + return sum; + } + + friend DiscreteFunctionP0 + operator-(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 difference(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { difference[cell_id] = f[cell_id] - g[cell_id]; }); + return difference; + } + + friend DiscreteFunctionP0 + operator*(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; }); + return product; + } + + template <typename DataType2T> + friend DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> + operator*(const DiscreteFunctionP0<Dimension, DataType2T>& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; }); + return product; + } + + friend DiscreteFunctionP0 + operator*(const double& a, const DiscreteFunctionP0& f) + { + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = a * f[cell_id]; }); + return product; + } + + friend DiscreteFunctionP0 + operator/(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 ratio(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { ratio[cell_id] = f[cell_id] / g[cell_id]; }); + return ratio; + } + DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const FunctionSymbolId& function_id) : m_mesh(mesh) { using MeshDataType = MeshData<Dimension>; @@ -80,4 +148,64 @@ class DiscreteFunctionP0 : public IDiscreteFunction ~DiscreteFunctionP0() = default; }; +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> +operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>& f) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f, const TinyMatrix<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyMatrix<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyVector<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + #endif // DISCRETE_FUNCTION_P0_HPP