diff --git a/src/language/modules/CMakeLists.txt b/src/language/modules/CMakeLists.txt index c21c8cb962798579c93ee52f143c6d5de679102a..4d8f15775bc356f8ce7d85ce8f3a27bd35c45e22 100644 --- a/src/language/modules/CMakeLists.txt +++ b/src/language/modules/CMakeLists.txt @@ -9,6 +9,7 @@ add_library(PugsLanguageModules MeshModule.cpp ModuleRepository.cpp SchemeModule.cpp + UnaryOperatorRegisterForVh.cpp UtilsModule.cpp WriterModule.cpp ) diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp index 7b175e0893716318665b4d4b83aa6e011133947a..db1988abac81bb0613e473bfa7429022d0f8f90d 100644 --- a/src/language/modules/SchemeModule.cpp +++ b/src/language/modules/SchemeModule.cpp @@ -1,6 +1,7 @@ #include <language/modules/SchemeModule.hpp> #include <language/modules/BinaryOperatorRegisterForVh.hpp> +#include <language/modules/UnaryOperatorRegisterForVh.hpp> #include <language/utils/BinaryOperatorProcessorBuilder.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> #include <language/utils/TypeDescriptor.hpp> @@ -301,4 +302,5 @@ void SchemeModule::registerOperators() const { BinaryOperatorRegisterForVh{}; + UnaryOperatorRegisterForVh{}; } diff --git a/src/language/modules/UnaryOperatorRegisterForVh.cpp b/src/language/modules/UnaryOperatorRegisterForVh.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b1dc85aa9dd11be2f5a6d159d8f1a70254d68c7b --- /dev/null +++ b/src/language/modules/UnaryOperatorRegisterForVh.cpp @@ -0,0 +1,24 @@ +#include <language/modules/UnaryOperatorRegisterForVh.hpp> + +#include <language/modules/SchemeModule.hpp> +#include <language/utils/DataHandler.hpp> +#include <language/utils/DataVariant.hpp> +#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> +#include <language/utils/OperatorRepository.hpp> +#include <language/utils/UnaryOperatorProcessorBuilder.hpp> +#include <scheme/IDiscreteFunction.hpp> + +void +UnaryOperatorRegisterForVh::_register_unary_minus() +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addUnaryOperator<language::unary_minus>( + std::make_shared<UnaryOperatorProcessorBuilder<language::unary_minus, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); +} + +UnaryOperatorRegisterForVh::UnaryOperatorRegisterForVh() +{ + this->_register_unary_minus(); +} diff --git a/src/language/modules/UnaryOperatorRegisterForVh.hpp b/src/language/modules/UnaryOperatorRegisterForVh.hpp new file mode 100644 index 0000000000000000000000000000000000000000..65ea35bb6d661e6257845141b4301ca8921bb169 --- /dev/null +++ b/src/language/modules/UnaryOperatorRegisterForVh.hpp @@ -0,0 +1,13 @@ +#ifndef UNARY_OPERATOR_REGISTER_FOR_VH_HPP +#define UNARY_OPERATOR_REGISTER_FOR_VH_HPP + +class UnaryOperatorRegisterForVh +{ + private: + void _register_unary_minus(); + + public: + UnaryOperatorRegisterForVh(); +}; + +#endif // UNARY_OPERATOR_REGISTER_FOR_VH_HPP diff --git a/src/language/node_processor/UnaryExpressionProcessor.hpp b/src/language/node_processor/UnaryExpressionProcessor.hpp index 055e3a028be043c8a30189b7e4cccddd99aeef4a..cdc4fc54111fd3a52854160347b5753e3787bcd5 100644 --- a/src/language/node_processor/UnaryExpressionProcessor.hpp +++ b/src/language/node_processor/UnaryExpressionProcessor.hpp @@ -5,6 +5,9 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +template <typename DataType> +class DataHandler; + template <typename Op> struct UnaryOp; @@ -52,4 +55,30 @@ class UnaryExpressionProcessor final : public INodeProcessor UnaryExpressionProcessor(ASTNode& node) : m_node{node} {} }; +template <typename UnaryOpT, typename ValueT, typename DataT> +class UnaryExpressionProcessor<UnaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<DataT>> final : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a) + { + const auto& embedded_a = std::get<EmbeddedData>(a); + + std::shared_ptr a_ptr = dynamic_cast<const DataHandler<DataT>&>(embedded_a.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(UnaryOp<UnaryOpT>().eval(a_ptr))); + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + return this->_eval(m_node.children[0]->execute(exec_policy)); + } + + UnaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + #endif // UNARY_EXPRESSION_PROCESSOR_HPP diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp index ecbc5e15c2ef7e14f606ceb9309642095b439172..a9691615ba86788e978bdc7696ba705afd20d236 100644 --- a/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp @@ -1,6 +1,7 @@ #include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> #include <language/node_processor/BinaryExpressionProcessor.hpp> +#include <language/node_processor/UnaryExpressionProcessor.hpp> #include <scheme/DiscreteFunctionP0.hpp> #include <scheme/DiscreteFunctionP0Vector.hpp> #include <scheme/IDiscreteFunction.hpp> @@ -24,7 +25,25 @@ PUGS_INLINE bool isSameDiscretization(const IDiscreteFunction& f, const IDiscreteFunction& g) { - return (f.dataType() == g.dataType()) and (f.descriptor().type() == g.descriptor().type()); + if ((f.dataType() == g.dataType()) and (f.descriptor().type() == g.descriptor().type())) { + switch (f.dataType()) { + case ASTNodeDataType::double_t: { + return true; + } + case ASTNodeDataType::vector_t: { + return f.dataType().dimension() == g.dataType().dimension(); + } + case ASTNodeDataType::matrix_t: { + return (f.dataType().nbRows() == g.dataType().nbRows()) and + (f.dataType().nbColumns() == g.dataType().nbColumns()); + } + default: { + throw UnexpectedError("invalid data type " + operand_type_name(f)); + } + } + } else { + return false; + } } PUGS_INLINE @@ -32,7 +51,7 @@ bool isSameDiscretization(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) { - return (f->dataType() == g->dataType()) and (f->descriptor().type() == g->descriptor().type()); + return isSameDiscretization(*f, *g); } template <typename LHS_T, typename RHS_T> @@ -45,6 +64,116 @@ invalid_operands(const LHS_T& f, const RHS_T& g) return os.str(); } +// unary operators +template <typename UnaryOperatorT, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyUnaryOperation(const DiscreteFunctionT& f) +{ + return std::make_shared<decltype(UnaryOp<UnaryOperatorT>{}.eval(f))>(UnaryOp<UnaryOperatorT>{}.eval(f)); +} + +template <typename UnaryOperatorT, size_t Dimension> +std::shared_ptr<const IDiscreteFunction> +applyUnaryOperation(const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->descriptor().type()) { + case DiscreteFunctionType::P0: { + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + case ASTNodeDataType::vector_t: { + switch (f->dataType().dimension()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + default: { + throw UnexpectedError("invalid operand type " + operand_type_name(f)); + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + default: { + throw UnexpectedError("invalid operand type " + operand_type_name(f)); + } + } + } + default: { + throw UnexpectedError("invalid operand type " + operand_type_name(f)); + } + } + break; + } + case DiscreteFunctionType::P0Vector: { + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0Vector<Dimension, double>&>(*f); + return applyUnaryOperation<UnaryOperatorT>(fh); + } + default: { + throw UnexpectedError("invalid operand type " + operand_type_name(f)); + } + } + break; + } + default: { + throw UnexpectedError("invalid operand type " + operand_type_name(f)); + } + } +} + +template <typename UnaryOperatorT> +std::shared_ptr<const IDiscreteFunction> +applyUnaryOperation(const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->mesh()->dimension()) { + case 1: { + return applyUnaryOperation<UnaryOperatorT, 1>(f); + } + case 2: { + return applyUnaryOperation<UnaryOperatorT, 2>(f); + } + case 3: { + return applyUnaryOperation<UnaryOperatorT, 3>(f); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +std::shared_ptr<const IDiscreteFunction> +operator-(const std::shared_ptr<const IDiscreteFunction>& f) +{ + return applyUnaryOperation<language::unary_minus>(f); +} + +// binary operators + template <typename BinOperatorT, typename DiscreteFunctionT> std::shared_ptr<const IDiscreteFunction> innerCompositionLaw(const DiscreteFunctionT& lhs, const DiscreteFunctionT& rhs) @@ -142,12 +271,12 @@ innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f, return innerCompositionLaw<BinOperatorT>(fh, gh); } default: { - throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + throw UnexpectedError("invalid data type " + operand_type_name(f)); } } } default: { - throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + throw UnexpectedError("invalid data type " + operand_type_name(f)); } } } @@ -256,7 +385,7 @@ applyBinaryOperation(const DiscreteFunctionT& fh, const std::shared_ptr<const ID } } default: { - throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + throw UnexpectedError("invalid rhs data type " + operand_type_name(g)); } } } else { @@ -283,7 +412,7 @@ applyBinaryOperation(const DiscreteFunctionT& fh, const std::shared_ptr<const ID return applyBinaryOperation<BinOperatorT>(fh, gh); } default: { - throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + throw UnexpectedError("invalid rhs data type " + operand_type_name(g)); } } } else { @@ -291,7 +420,7 @@ applyBinaryOperation(const DiscreteFunctionT& fh, const std::shared_ptr<const ID } } default: { - throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + throw UnexpectedError("invalid rhs data type " + operand_type_name(g)); } } } @@ -329,7 +458,7 @@ applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); } default: { - throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + throw UnexpectedError("invalid lhs data type " + operand_type_name(f)); } } } @@ -369,13 +498,21 @@ applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, std::shared_ptr<const IDiscreteFunction> operator+(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) { - return innerCompositionLaw<language::plus_op>(f, g); + if (isSameDiscretization(f, g)) { + return innerCompositionLaw<language::plus_op>(f, g); + } else { + throw NormalError(invalid_operands(f, g)); + } } std::shared_ptr<const IDiscreteFunction> operator-(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) { - return innerCompositionLaw<language::minus_op>(f, g); + if (isSameDiscretization(f, g)) { + return innerCompositionLaw<language::minus_op>(f, g); + } else { + throw NormalError(invalid_operands(f, g)); + } } std::shared_ptr<const IDiscreteFunction> @@ -502,7 +639,7 @@ applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<co } } default: { - throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + throw UnexpectedError("invalid lhs data type " + operand_type_name(f)); } } } else { @@ -520,7 +657,7 @@ applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<co return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); } default: { - throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + throw UnexpectedError("invalid lhs data type " + operand_type_name(f)); } } } @@ -554,7 +691,7 @@ applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<co } } default: { - throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + throw UnexpectedError("invalid lhs data type " + operand_type_name(f)); } } } else { @@ -572,7 +709,7 @@ applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<co return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); } default: { - throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + throw UnexpectedError("invalid lhs data type " + operand_type_name(f)); } } } @@ -671,7 +808,7 @@ applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunct } } default: { - throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + throw UnexpectedError("invalid lhs data type " + operand_type_name(f)); } } } else { @@ -689,7 +826,7 @@ applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunct return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); } default: { - throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + throw UnexpectedError("invalid lhs data type " + operand_type_name(f)); } } } diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp index 128746175ccd894957382d13d5a15ce8e8ff7a14..f797a95d03e19a796d9af965d81bbd3970c6cddd 100644 --- a/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp @@ -8,6 +8,9 @@ class IDiscreteFunction; +// unary minus +std::shared_ptr<const IDiscreteFunction> operator-(const std::shared_ptr<const IDiscreteFunction>&); + // sum std::shared_ptr<const IDiscreteFunction> operator+(const std::shared_ptr<const IDiscreteFunction>&, const std::shared_ptr<const IDiscreteFunction>&);