#ifndef BINARY_OPERATOR_PROCESSOR_BUILDER_HPP #define BINARY_OPERATOR_PROCESSOR_BUILDER_HPP #include <algebra/TinyVector.hpp> #include <language/PEGGrammar.hpp> #include <language/node_processor/BinaryExpressionProcessor.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/IBinaryOperatorProcessorBuilder.hpp> #include <language/utils/ParseError.hpp> #include <type_traits> template <typename DataType> class DataHandler; template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT> class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder { public: BinaryOperatorProcessorBuilder() = default; ASTNodeDataType getDataTypeOfA() const { return ast_node_data_type_from<A_DataT>; } ASTNodeDataType getDataTypeOfB() const { return ast_node_data_type_from<B_DataT>; } ASTNodeDataType getReturnValueType() const { return ast_node_data_type_from<ValueT>; } std::unique_ptr<INodeProcessor> getNodeProcessor(ASTNode& node) const { return std::make_unique<BinaryExpressionProcessor<OperatorT, ValueT, A_DataT, B_DataT>>(node); } }; template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, std::shared_ptr<B_DataT>> final : public INodeProcessor { private: ASTNode& m_node; PUGS_INLINE DataVariant _eval(const DataVariant& a, const DataVariant& b) { const auto& embedded_a = std::get<EmbeddedData>(a); const auto& embedded_b = std::get<EmbeddedData>(b); std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_ptr))); } public: DataVariant execute(ExecutionPolicy& exec_policy) { try { return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); } catch (const NormalError& error) { throw ParseError(error.what(), m_node.begin()); } } BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} }; template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, A_DataT, std::shared_ptr<B_DataT>> final : public INodeProcessor { private: ASTNode& m_node; PUGS_INLINE DataVariant _eval(const DataVariant& a, const DataVariant& b) { if constexpr ((std::is_arithmetic_v<A_DataT>) or (is_tiny_vector_v<A_DataT>) or (is_tiny_matrix_v<A_DataT>)) { const auto& a_value = std::get<A_DataT>(a); const auto& embedded_b = std::get<EmbeddedData>(b); std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_value, b_ptr))); } else { static_assert(std::is_arithmetic_v<A_DataT>, "invalid left hand side type"); } } public: DataVariant execute(ExecutionPolicy& exec_policy) { try { return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); } catch (const NormalError& error) { throw ParseError(error.what(), m_node.begin()); } } BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} }; template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, B_DataT> final : public INodeProcessor { private: ASTNode& m_node; PUGS_INLINE DataVariant _eval(const DataVariant& a, const DataVariant& b) { if constexpr ((std::is_arithmetic_v<B_DataT>) or (is_tiny_matrix_v<B_DataT>) or (is_tiny_vector_v<B_DataT>)) { const auto& embedded_a = std::get<EmbeddedData>(a); const auto& b_value = std::get<B_DataT>(b); std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_value))); } else { static_assert(std::is_arithmetic_v<B_DataT>, "invalid right hand side type"); } } public: DataVariant execute(ExecutionPolicy& exec_policy) { try { return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); } catch (const NormalError& error) { throw ParseError(error.what(), m_node.begin()); } } BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} }; #endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP