From bc6a2d3c03a98c890ecd32b4d411dd17b58f0d12 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Tue, 28 Jan 2020 19:13:39 +0100 Subject: [PATCH] Add tuple expressions for vector initialization This allows to write `` let f : R -> R^3, x -> (x, 2*x+1, -2); R^3 x = f(2.3); `` or even `` let f : R -> R^3*R, x -> ((x, 2*x+1, -2), x-1); R^3*R (x,t) = f(2.3); `` Note that it is not completely functional since `` let f : R -> R^3, x -> (x, 2*x+1, -2); R^3 x = 2*f(2.3); `` does not work! The return type of `f(2.3)` is incorrect (should be an R^3 but remains a `typename` which is improper for these calculations). --- src/language/ASTBuilder.cpp | 1 + src/language/ASTNodeDataTypeBuilder.cpp | 49 ++++++++++++++----- src/language/ASTNodeExpressionBuilder.cpp | 19 +++++++ .../ASTNodeFunctionExpressionBuilder.cpp | 15 ++++-- src/language/PEGGrammar.hpp | 4 +- .../node_processor/TupleToVectorProcessor.hpp | 45 +++++++++++++++++ 6 files changed, 115 insertions(+), 18 deletions(-) create mode 100644 src/language/node_processor/TupleToVectorProcessor.hpp diff --git a/src/language/ASTBuilder.cpp b/src/language/ASTBuilder.cpp index 6bacf4c07..eb853444e 100644 --- a/src/language/ASTBuilder.cpp +++ b/src/language/ASTBuilder.cpp @@ -223,6 +223,7 @@ using selector = parse_tree::selector< N_set, Z_set, R_set, + tuple_expression, vector_type, string_type, cout_kw, diff --git a/src/language/ASTNodeDataTypeBuilder.cpp b/src/language/ASTNodeDataTypeBuilder.cpp index c1ad69879..0b4de8331 100644 --- a/src/language/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ASTNodeDataTypeBuilder.cpp @@ -94,6 +94,21 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const n.m_data_type = ASTNodeDataType::int_t; } else if (n.is_type<language::vector_type>()) { n.m_data_type = getVectorDataType(n); + + } else if (n.is_type<language::tuple_expression>()) { + for (auto&& child : n.children) { + this->_buildNodeDataTypes(*child); + } + for (auto&& child : n.children) { + ASTNodeNaturalConversionChecker{*child, child->m_data_type, ASTNodeDataType::double_t}; + } + + if (n.children.size() <= 3) { + n.m_data_type = ASTNodeDataType{ASTNodeDataType::vector_t, n.children.size()}; + } else { + throw parse_error("invalid vector dimension (must be lesser than 3)", n.begin()); + } + } else if (n.is_type<language::literal>()) { n.m_data_type = ASTNodeDataType::string_t; } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) { @@ -194,23 +209,33 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1]; ASTNode& image_expression_node = *function_descriptor.definitionNode().children[1]; + this->_buildNodeDataTypes(image_domain_node); + for (auto& child : image_domain_node.children) { + this->_buildNodeDataTypes(*child); + } + const size_t nb_image_domains = (image_domain_node.is_type<language::type_expression>()) ? image_domain_node.children.size() : 1; const size_t nb_image_expressions = (image_expression_node.is_type<language::expression_list>()) ? image_expression_node.children.size() : 1; if (nb_image_domains != nb_image_expressions) { - std::ostringstream message; - message << "note: number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow - << image_domain_node.string() << rang::style::reset << rang::style::bold - << " differs from number of expressions (" << nb_image_expressions << ") " << rang::fgB::yellow - << image_expression_node.string() << rang::style::reset << std::ends; - throw parse_error(message.str(), image_domain_node.begin()); - } - - this->_buildNodeDataTypes(image_domain_node); - for (auto& child : image_domain_node.children) { - this->_buildNodeDataTypes(*child); + if (image_domain_node.is_type<language::vector_type>()) { + ASTNodeDataType image_type = getVectorDataType(image_domain_node); + if (image_type.dimension() != nb_image_expressions) { + std::ostringstream message; + message << "note: expecting " << image_domain_node.m_data_type.dimension() << " expressions found " + << nb_image_expressions << std::ends; + throw parse_error(message.str(), image_domain_node.begin()); + } + } else { + std::ostringstream message; + message << "note: number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow + << image_domain_node.string() << rang::style::reset << rang::style::bold + << " differs from number of expressions (" << nb_image_expressions << ") " << rang::fgB::yellow + << image_expression_node.string() << rang::style::reset << std::ends; + throw parse_error(message.str(), image_domain_node.begin()); + } } auto check_image_type = [&](const ASTNode& image_node) { @@ -343,8 +368,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1]; - Assert(image_domain_node.m_data_type == ASTNodeDataType::typename_t); - ASTNodeDataType data_type{ASTNodeDataType::undefined_t}; if (image_domain_node.is_type<language::type_expression>()) { data_type = image_domain_node.m_data_type; diff --git a/src/language/ASTNodeExpressionBuilder.cpp b/src/language/ASTNodeExpressionBuilder.cpp index 238d61bd0..eac81ff86 100644 --- a/src/language/ASTNodeExpressionBuilder.cpp +++ b/src/language/ASTNodeExpressionBuilder.cpp @@ -19,6 +19,7 @@ #include <node_processor/LocalNameProcessor.hpp> #include <node_processor/NameProcessor.hpp> #include <node_processor/OStreamProcessor.hpp> +#include <node_processor/TupleToVectorProcessor.hpp> #include <node_processor/ValueProcessor.hpp> #include <node_processor/WhileProcessor.hpp> @@ -38,6 +39,24 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n) ASTNodeAffectationExpressionBuilder{n}; } + } else if (n.is_type<language::tuple_expression>()) { + switch (n.children.size()) { + case 1: { + n.m_node_processor = std::make_unique<TupleToVectorProcessor<1>>(n); + break; + } + case 2: { + n.m_node_processor = std::make_unique<TupleToVectorProcessor<2>>(n); + break; + } + case 3: { + n.m_node_processor = std::make_unique<TupleToVectorProcessor<3>>(n); + break; + } + default: { + throw parse_error("unexpected error: invalid tuple size", n.begin()); + } + } } else if (n.is_type<language::function_definition>()) { n.m_node_processor = std::make_unique<FakeProcessor>(); diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp index 34a1b8571..c6d7e3330 100644 --- a/src/language/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp @@ -336,11 +336,18 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node ASTNode& function_expression = *function_descriptor.definitionNode().children[1]; if (function_expression.is_type<language::expression_list>()) { - Assert(function_image_domain.is_type<language::type_expression>()); - ASTNode& image_domain_node = function_image_domain; + if (function_image_domain.is_type<language::vector_type>()) { + for (size_t i = 0; i < function_expression.children.size(); ++i) { + function_processor->addFunctionExpressionProcessor( + this->_getFunctionProcessor(function_expression.children[i]->m_data_type, ASTNodeDataType::double_t, node, + *function_expression.children[i])); + } + } else { + ASTNode& image_domain_node = function_image_domain; - for (size_t i = 0; i < function_expression.children.size(); ++i) { - add_component_expression(*function_expression.children[i], *image_domain_node.children[i]); + for (size_t i = 0; i < function_expression.children.size(); ++i) { + add_component_expression(*function_expression.children[i], *image_domain_node.children[i]); + } } } else { add_component_expression(function_expression, function_image_domain); diff --git a/src/language/PEGGrammar.hpp b/src/language/PEGGrammar.hpp index 0f216d9aa..d7506b5be 100644 --- a/src/language/PEGGrammar.hpp +++ b/src/language/PEGGrammar.hpp @@ -213,7 +213,9 @@ struct logical_or : list_must< logical_and, or_op >{}; struct expression : logical_or {}; -struct expression_list : seq< open_parent, expression, plus< if_must< COMMA, expression > >, close_parent >{}; +struct tuple_expression : seq< open_parent, expression, plus< if_must< COMMA, expression > >, close_parent >{}; + +struct expression_list : seq< open_parent, sor< tuple_expression, expression >, plus< if_must< COMMA, sor< tuple_expression, expression > > >, close_parent >{}; struct affect_op : sor< eq_op, multiplyeq_op, divideeq_op, pluseq_op, minuseq_op > {}; diff --git a/src/language/node_processor/TupleToVectorProcessor.hpp b/src/language/node_processor/TupleToVectorProcessor.hpp new file mode 100644 index 000000000..33f6bc305 --- /dev/null +++ b/src/language/node_processor/TupleToVectorProcessor.hpp @@ -0,0 +1,45 @@ +#ifndef TUPLE_TO_VECTOR_PROCESSOR_HPP +#define TUPLE_TO_VECTOR_PROCESSOR_HPP + +#include <node_processor/INodeProcessor.hpp> + +#include <node_processor/ASTNodeExpressionListProcessor.hpp> + +template <size_t N> +class TupleToVectorProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + + ASTNodeExpressionListProcessor m_list_processor; + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + AggregateDataVariant v = std::get<AggregateDataVariant>(m_list_processor.execute(exec_policy)); + + Assert(v.size() == N); + + TinyVector<N> x; + + for (size_t i = 0; i < N; ++i) { + std::visit( + [&](auto&& v) { + using ValueT = std::decay_t<decltype(v)>; + if constexpr (std::is_arithmetic_v<ValueT>) { + x[i] = v; + } else { + Assert(false, "unexpected value type"); + } + }, + v[i]); + } + + return DataVariant{std::move(x)}; + } + + TupleToVectorProcessor(ASTNode& node) : m_node{node}, m_list_processor{node} {} +}; + +#endif // TUPLE_TO_VECTOR_PROCESSOR_HPP -- GitLab