From d69de46e618011e9fe1ad81c031aebc9df7355ab Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 29 Jan 2020 18:47:38 +0100 Subject: [PATCH] Fix return type handling for vector functions Now, one can write for instance `` let f : R->R^3, x -> (x, 2.3, -x+2); R^3 x = f(0.3); R^3 y = 2*f(1.2) + x + 3*f(-1.4); `` --- src/language/ASTNodeExpressionBuilder.cpp | 6 +-- .../ASTNodeFunctionExpressionBuilder.cpp | 50 ++++++++++++++----- .../node_processor/TupleToVectorProcessor.hpp | 14 +++--- 3 files changed, 49 insertions(+), 21 deletions(-) diff --git a/src/language/ASTNodeExpressionBuilder.cpp b/src/language/ASTNodeExpressionBuilder.cpp index eac81ff86..fdb9c76a9 100644 --- a/src/language/ASTNodeExpressionBuilder.cpp +++ b/src/language/ASTNodeExpressionBuilder.cpp @@ -42,15 +42,15 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n) } else if (n.is_type<language::tuple_expression>()) { switch (n.children.size()) { case 1: { - n.m_node_processor = std::make_unique<TupleToVectorProcessor<1>>(n); + n.m_node_processor = std::make_unique<TupleToVectorProcessor<ASTNodeExpressionListProcessor, 1>>(n); break; } case 2: { - n.m_node_processor = std::make_unique<TupleToVectorProcessor<2>>(n); + n.m_node_processor = std::make_unique<TupleToVectorProcessor<ASTNodeExpressionListProcessor, 2>>(n); break; } case 3: { - n.m_node_processor = std::make_unique<TupleToVectorProcessor<3>>(n); + n.m_node_processor = std::make_unique<TupleToVectorProcessor<ASTNodeExpressionListProcessor, 3>>(n); break; } default: { diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp index c6d7e3330..69cdad7f1 100644 --- a/src/language/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp @@ -8,6 +8,7 @@ #include <ASTNodeNaturalConversionChecker.hpp> #include <node_processor/FunctionProcessor.hpp> +#include <node_processor/TupleToVectorProcessor.hpp> template <typename SymbolType> std::unique_ptr<IFunctionArgumentConverter> @@ -335,23 +336,48 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node ASTNode& function_image_domain = *function_descriptor.domainMappingNode().children[1]; ASTNode& function_expression = *function_descriptor.definitionNode().children[1]; - if (function_expression.is_type<language::expression_list>()) { - if (function_image_domain.is_type<language::vector_type>()) { - for (size_t i = 0; i < function_expression.children.size(); ++i) { - function_processor->addFunctionExpressionProcessor( - this->_getFunctionProcessor(function_expression.children[i]->m_data_type, ASTNodeDataType::double_t, node, - *function_expression.children[i])); - } - } else { + if (function_image_domain.is_type<language::vector_type>()) { + ASTNodeDataType vector_type = getVectorDataType(function_image_domain); + + Assert(vector_type.dimension() == function_expression.children.size()); + + for (size_t i = 0; i < vector_type.dimension(); ++i) { + function_processor->addFunctionExpressionProcessor( + this->_getFunctionProcessor(function_expression.children[i]->m_data_type, ASTNodeDataType::double_t, node, + *function_expression.children[i])); + } + + switch (vector_type.dimension()) { + case 1: { + node.m_node_processor = + std::make_unique<TupleToVectorProcessor<FunctionProcessor, 1>>(node, std::move(function_processor)); + break; + } + case 2: { + node.m_node_processor = + std::make_unique<TupleToVectorProcessor<FunctionProcessor, 2>>(node, std::move(function_processor)); + break; + } + case 3: { + node.m_node_processor = + std::make_unique<TupleToVectorProcessor<FunctionProcessor, 3>>(node, std::move(function_processor)); + break; + } + default: { + throw parse_error("unexpected error: invalid vector_t dimension", std::vector{node.begin()}); + } + } + } else { + if (function_expression.is_type<language::expression_list>()) { ASTNode& image_domain_node = function_image_domain; for (size_t i = 0; i < function_expression.children.size(); ++i) { add_component_expression(*function_expression.children[i], *image_domain_node.children[i]); } + } else { + add_component_expression(function_expression, function_image_domain); } - } else { - add_component_expression(function_expression, function_image_domain); - } - node.m_node_processor = std::move(function_processor); + node.m_node_processor = std::move(function_processor); + } } diff --git a/src/language/node_processor/TupleToVectorProcessor.hpp b/src/language/node_processor/TupleToVectorProcessor.hpp index 33f6bc305..b15d88146 100644 --- a/src/language/node_processor/TupleToVectorProcessor.hpp +++ b/src/language/node_processor/TupleToVectorProcessor.hpp @@ -3,21 +3,19 @@ #include <node_processor/INodeProcessor.hpp> -#include <node_processor/ASTNodeExpressionListProcessor.hpp> - -template <size_t N> +template <typename TupleProcessorT, size_t N> class TupleToVectorProcessor final : public INodeProcessor { private: ASTNode& m_node; - ASTNodeExpressionListProcessor m_list_processor; + std::unique_ptr<TupleProcessorT> m_tuple_processor; public: DataVariant execute(ExecutionPolicy& exec_policy) { - AggregateDataVariant v = std::get<AggregateDataVariant>(m_list_processor.execute(exec_policy)); + AggregateDataVariant v = std::get<AggregateDataVariant>(m_tuple_processor->execute(exec_policy)); Assert(v.size() == N); @@ -39,7 +37,11 @@ class TupleToVectorProcessor final : public INodeProcessor return DataVariant{std::move(x)}; } - TupleToVectorProcessor(ASTNode& node) : m_node{node}, m_list_processor{node} {} + TupleToVectorProcessor(ASTNode& node) : m_node{node}, m_tuple_processor{std::make_unique<TupleProcessorT>(node)} {} + + TupleToVectorProcessor(ASTNode& node, std::unique_ptr<TupleProcessorT>&& tuple_processor) + : m_node{node}, m_tuple_processor{std::move(tuple_processor)} + {} }; #endif // TUPLE_TO_VECTOR_PROCESSOR_HPP -- GitLab