diff --git a/src/language/ASTNodeExpressionBuilder.cpp b/src/language/ASTNodeExpressionBuilder.cpp index eac81ff864f9793f039f98ed507c29226d678dcb..fdb9c76a931a9ae7c17df36432f2011efda9ba9d 100644 --- a/src/language/ASTNodeExpressionBuilder.cpp +++ b/src/language/ASTNodeExpressionBuilder.cpp @@ -42,15 +42,15 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n) } else if (n.is_type<language::tuple_expression>()) { switch (n.children.size()) { case 1: { - n.m_node_processor = std::make_unique<TupleToVectorProcessor<1>>(n); + n.m_node_processor = std::make_unique<TupleToVectorProcessor<ASTNodeExpressionListProcessor, 1>>(n); break; } case 2: { - n.m_node_processor = std::make_unique<TupleToVectorProcessor<2>>(n); + n.m_node_processor = std::make_unique<TupleToVectorProcessor<ASTNodeExpressionListProcessor, 2>>(n); break; } case 3: { - n.m_node_processor = std::make_unique<TupleToVectorProcessor<3>>(n); + n.m_node_processor = std::make_unique<TupleToVectorProcessor<ASTNodeExpressionListProcessor, 3>>(n); break; } default: { diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp index c6d7e33304612c63a587eed82181bd3b4f62af15..69cdad7f1cb58ea1b49bf69eec64966ea228ddcd 100644 --- a/src/language/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp @@ -8,6 +8,7 @@ #include <ASTNodeNaturalConversionChecker.hpp> #include <node_processor/FunctionProcessor.hpp> +#include <node_processor/TupleToVectorProcessor.hpp> template <typename SymbolType> std::unique_ptr<IFunctionArgumentConverter> @@ -335,23 +336,48 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node ASTNode& function_image_domain = *function_descriptor.domainMappingNode().children[1]; ASTNode& function_expression = *function_descriptor.definitionNode().children[1]; - if (function_expression.is_type<language::expression_list>()) { - if (function_image_domain.is_type<language::vector_type>()) { - for (size_t i = 0; i < function_expression.children.size(); ++i) { - function_processor->addFunctionExpressionProcessor( - this->_getFunctionProcessor(function_expression.children[i]->m_data_type, ASTNodeDataType::double_t, node, - *function_expression.children[i])); - } - } else { + if (function_image_domain.is_type<language::vector_type>()) { + ASTNodeDataType vector_type = getVectorDataType(function_image_domain); + + Assert(vector_type.dimension() == function_expression.children.size()); + + for (size_t i = 0; i < vector_type.dimension(); ++i) { + function_processor->addFunctionExpressionProcessor( + this->_getFunctionProcessor(function_expression.children[i]->m_data_type, ASTNodeDataType::double_t, node, + *function_expression.children[i])); + } + + switch (vector_type.dimension()) { + case 1: { + node.m_node_processor = + std::make_unique<TupleToVectorProcessor<FunctionProcessor, 1>>(node, std::move(function_processor)); + break; + } + case 2: { + node.m_node_processor = + std::make_unique<TupleToVectorProcessor<FunctionProcessor, 2>>(node, std::move(function_processor)); + break; + } + case 3: { + node.m_node_processor = + std::make_unique<TupleToVectorProcessor<FunctionProcessor, 3>>(node, std::move(function_processor)); + break; + } + default: { + throw parse_error("unexpected error: invalid vector_t dimension", std::vector{node.begin()}); + } + } + } else { + if (function_expression.is_type<language::expression_list>()) { ASTNode& image_domain_node = function_image_domain; for (size_t i = 0; i < function_expression.children.size(); ++i) { add_component_expression(*function_expression.children[i], *image_domain_node.children[i]); } + } else { + add_component_expression(function_expression, function_image_domain); } - } else { - add_component_expression(function_expression, function_image_domain); - } - node.m_node_processor = std::move(function_processor); + node.m_node_processor = std::move(function_processor); + } } diff --git a/src/language/node_processor/TupleToVectorProcessor.hpp b/src/language/node_processor/TupleToVectorProcessor.hpp index 33f6bc30538c3dc70209dab71100af358defe87e..b15d881469e1ce20128249f45536df7b91db8ee2 100644 --- a/src/language/node_processor/TupleToVectorProcessor.hpp +++ b/src/language/node_processor/TupleToVectorProcessor.hpp @@ -3,21 +3,19 @@ #include <node_processor/INodeProcessor.hpp> -#include <node_processor/ASTNodeExpressionListProcessor.hpp> - -template <size_t N> +template <typename TupleProcessorT, size_t N> class TupleToVectorProcessor final : public INodeProcessor { private: ASTNode& m_node; - ASTNodeExpressionListProcessor m_list_processor; + std::unique_ptr<TupleProcessorT> m_tuple_processor; public: DataVariant execute(ExecutionPolicy& exec_policy) { - AggregateDataVariant v = std::get<AggregateDataVariant>(m_list_processor.execute(exec_policy)); + AggregateDataVariant v = std::get<AggregateDataVariant>(m_tuple_processor->execute(exec_policy)); Assert(v.size() == N); @@ -39,7 +37,11 @@ class TupleToVectorProcessor final : public INodeProcessor return DataVariant{std::move(x)}; } - TupleToVectorProcessor(ASTNode& node) : m_node{node}, m_list_processor{node} {} + TupleToVectorProcessor(ASTNode& node) : m_node{node}, m_tuple_processor{std::make_unique<TupleProcessorT>(node)} {} + + TupleToVectorProcessor(ASTNode& node, std::unique_ptr<TupleProcessorT>&& tuple_processor) + : m_node{node}, m_tuple_processor{std::move(tuple_processor)} + {} }; #endif // TUPLE_TO_VECTOR_PROCESSOR_HPP