From 6b2407fceaea1b3bd36ff6e6b8681e8a90bc5657 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Mon, 27 Jan 2020 19:06:59 +0100 Subject: [PATCH] Add vector treatment in functions (arguments and returned types) --- src/language/ASTNodeDataType.cpp | 17 +++++ src/language/ASTNodeDataType.hpp | 2 + src/language/ASTNodeDataTypeBuilder.cpp | 26 +++---- src/language/ASTNodeDataTypeFlattener.cpp | 2 + .../ASTNodeFunctionExpressionBuilder.cpp | 74 +++++++++++++++++++ ...STNodeListAffectationExpressionBuilder.cpp | 30 ++++++++ 6 files changed, 137 insertions(+), 14 deletions(-) diff --git a/src/language/ASTNodeDataType.cpp b/src/language/ASTNodeDataType.cpp index 46490c580..c0d118711 100644 --- a/src/language/ASTNodeDataType.cpp +++ b/src/language/ASTNodeDataType.cpp @@ -1,5 +1,22 @@ +#include <ASTNode.hpp> #include <ASTNodeDataType.hpp> +#include <PEGGrammar.hpp> + +ASTNodeDataType +getVectorDataType(const ASTNode& type_node) +{ + if (not(type_node.is_type<language::vector_type>() and (type_node.children.size() == 2))) { + throw parse_error("unexpected node type", type_node.begin()); + } + ASTNode& dimension_node = *type_node.children[1]; + if (not dimension_node.is_type<language::integer>()) { + throw parse_error("unexpected non integer constant dimension", dimension_node.begin()); + } + const size_t dimension = std::stol(dimension_node.string()); + return ASTNodeDataType{ASTNodeDataType::vector_t, dimension}; +} + std::string dataTypeName(const ASTNodeDataType& data_type) { diff --git a/src/language/ASTNodeDataType.hpp b/src/language/ASTNodeDataType.hpp index 79ae74453..2d220a4d9 100644 --- a/src/language/ASTNodeDataType.hpp +++ b/src/language/ASTNodeDataType.hpp @@ -52,6 +52,8 @@ class ASTNodeDataType ~ASTNodeDataType() = default; }; +ASTNodeDataType getVectorDataType(const ASTNode& type_node); + std::string dataTypeName(const ASTNodeDataType& data_type); ASTNodeDataType dataTypePromotion(const ASTNodeDataType& data_type_1, const ASTNodeDataType& data_type_2); diff --git a/src/language/ASTNodeDataTypeBuilder.cpp b/src/language/ASTNodeDataTypeBuilder.cpp index b7e349374..c1ad69879 100644 --- a/src/language/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ASTNodeDataTypeBuilder.cpp @@ -38,12 +38,7 @@ ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNo } else if (type_node.is_type<language::R_set>()) { data_type = ASTNodeDataType::double_t; } else if (type_node.is_type<language::vector_type>()) { - ASTNode& dimension_node = *type_node.children[1]; - if (not dimension_node.is_type<language::integer>()) { - throw parse_error("unexpected non integer constant dimension", dimension_node.begin()); - } - const size_t dimension = std::stol(dimension_node.string()); - data_type = ASTNodeDataType{ASTNodeDataType::vector_t, dimension}; + data_type = getVectorDataType(type_node); } else if (type_node.is_type<language::string_type>()) { data_type = ASTNodeDataType::string_t; } @@ -98,7 +93,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const } else if (n.is_type<language::integer>()) { n.m_data_type = ASTNodeDataType::int_t; } else if (n.is_type<language::vector_type>()) { - n.m_data_type = ASTNodeDataType::vector_t; + n.m_data_type = getVectorDataType(n); } else if (n.is_type<language::literal>()) { n.m_data_type = ASTNodeDataType::string_t; } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) { @@ -134,7 +129,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const } const size_t nb_parameter_domains = - (parameters_domain_node.children.size() > 0) ? parameters_domain_node.children.size() : 1; + (parameters_domain_node.is_type<language::type_expression>()) ? parameters_domain_node.children.size() : 1; const size_t nb_parameter_names = (parameters_name_node.children.size() > 0) ? parameters_name_node.children.size() : 1; @@ -159,7 +154,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const } else if (type_node.is_type<language::R_set>()) { data_type = ASTNodeDataType::double_t; } else if (type_node.is_type<language::vector_type>()) { - data_type = ASTNodeDataType::vector_t; + data_type = getVectorDataType(type_node); } else if (type_node.is_type<language::string_type>()) { data_type = ASTNodeDataType::string_t; } @@ -180,10 +175,10 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const i_symbol->attributes().setDataType(data_type); }; - if (parameters_domain_node.children.size() == 0) { + if (nb_parameter_domains == 1) { simple_type_allocator(parameters_domain_node, parameters_name_node); } else { - for (size_t i = 0; i < parameters_domain_node.children.size(); ++i) { + for (size_t i = 0; i < nb_parameter_domains; ++i) { simple_type_allocator(*parameters_domain_node.children[i], *parameters_name_node.children[i]); } } @@ -229,7 +224,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const } else if (image_node.is_type<language::R_set>()) { value_type = ASTNodeDataType::double_t; } else if (image_node.is_type<language::vector_type>()) { - value_type = ASTNodeDataType::vector_t; + value_type = getVectorDataType(image_node); } else if (image_node.is_type<language::string_type>()) { value_type = ASTNodeDataType::string_t; } @@ -351,7 +346,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const Assert(image_domain_node.m_data_type == ASTNodeDataType::typename_t); ASTNodeDataType data_type{ASTNodeDataType::undefined_t}; - if (image_domain_node.children.size() > 0) { + if (image_domain_node.is_type<language::type_expression>()) { data_type = image_domain_node.m_data_type; } else { if (image_domain_node.is_type<language::B_set>()) { @@ -362,6 +357,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const data_type = ASTNodeDataType::unsigned_int_t; } else if (image_domain_node.is_type<language::R_set>()) { data_type = ASTNodeDataType::double_t; + } else if (image_domain_node.is_type<language::vector_type>()) { + data_type = getVectorDataType(image_domain_node); } else if (image_domain_node.is_type<language::string_type>()) { data_type = ASTNodeDataType::string_t; } @@ -390,7 +387,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const throw parse_error(message.str(), n.begin()); } } else if (n.is_type<language::B_set>() or n.is_type<language::Z_set>() or n.is_type<language::N_set>() or - n.is_type<language::R_set>() or n.is_type<language::string_type>()) { + n.is_type<language::R_set>() or n.is_type<language::string_type>() or + n.is_type<language::vector_type>()) { n.m_data_type = ASTNodeDataType::typename_t; } else if (n.is_type<language::name_list>() or n.is_type<language::function_argument_list>()) { n.m_data_type = ASTNodeDataType::void_t; diff --git a/src/language/ASTNodeDataTypeFlattener.cpp b/src/language/ASTNodeDataTypeFlattener.cpp index 0545aa25c..1f1002f80 100644 --- a/src/language/ASTNodeDataTypeFlattener.cpp +++ b/src/language/ASTNodeDataTypeFlattener.cpp @@ -39,6 +39,8 @@ ASTNodeDataTypeFlattener::ASTNodeDataTypeFlattener(ASTNode& node, FlattenedDataT data_type = ASTNodeDataType::unsigned_int_t; } else if (image_sub_domain->is_type<language::R_set>()) { data_type = ASTNodeDataType::double_t; + } else if (image_sub_domain->is_type<language::vector_type>()) { + data_type = getVectorDataType(*image_sub_domain); } else if (image_sub_domain->is_type<language::string_type>()) { data_type = ASTNodeDataType::string_t; } diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp index 8b3787fd5..34a1b8571 100644 --- a/src/language/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp @@ -43,6 +43,27 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy } }; + auto get_function_argument_converter_for_vector = + [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { + using ParameterT = std::decay_t<decltype(parameter_v)>; + switch (node_sub_data_type.m_data_type) { + case ASTNodeDataType::vector_t: { + if (node_sub_data_type.m_data_type.dimension() == parameter_v.dimension()) { + return std::make_unique<FunctionArgumentConverter<ParameterT, ParameterT>>(parameter_id); + } else { + throw parse_error("invalid argument dimension (expected " + std::to_string(parameter_v.dimension()) + + ", provided " + std::to_string(node_sub_data_type.m_data_type.dimension()) + ")", + std::vector{node_sub_data_type.m_parent_node.begin()}); + } + } + // LCOV_EXCL_START + default: { + throw parse_error("invalid argument type", std::vector{node_sub_data_type.m_parent_node.begin()}); + } + // LCOV_EXCL_STOP + } + }; + auto get_function_argument_converter_for_string = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { switch (node_sub_data_type.m_data_type) { case ASTNodeDataType::bool_t: { @@ -85,6 +106,22 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy case ASTNodeDataType::string_t: { return get_function_argument_converter_for_string(); } + case ASTNodeDataType::vector_t: { + switch (parameter_symbol.attributes().dataType().dimension()) { + case 1: { + return get_function_argument_converter_for_vector(TinyVector<1>{}); + } + case 2: { + return get_function_argument_converter_for_vector(TinyVector<2>{}); + } + case 3: { + return get_function_argument_converter_for_vector(TinyVector<3>{}); + } + default: { + throw parse_error("unexpected error: invalid parameter dimension", std::vector{m_node.begin()}); + } + } + } // LCOV_EXCL_START default: { @@ -194,6 +231,25 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType ex } }; + auto get_function_processor_for_expression_vector = [&](const auto& return_v) -> std::unique_ptr<INodeProcessor> { + using ReturnT = std::decay_t<decltype(return_v)>; + switch (expression_value_type) { + case ASTNodeDataType::vector_t: { + if (expression_value_type.dimension() == return_v.dimension()) { + return std::make_unique<FunctionExpressionProcessor<ReturnT, ReturnT>>(function_component_expression); + } else { + throw parse_error("invalid dimension for returned vector", std::vector{function_component_expression.begin()}); + } + } + // LCOV_EXCL_START + default: { + throw parse_error("unexpected error: undefined expression value type for function", + std::vector{node.children[1]->begin()}); + } + // LCOV_EXCL_STOP + } + }; + auto get_function_processor_for_value = [&]() { switch (return_value_type) { case ASTNodeDataType::bool_t: { @@ -208,6 +264,22 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType ex case ASTNodeDataType::double_t: { return get_function_processor_for_expression_value(double{}); } + case ASTNodeDataType::vector_t: { + switch (return_value_type.dimension()) { + case 1: { + return get_function_processor_for_expression_vector(TinyVector<1>{}); + } + case 2: { + return get_function_processor_for_expression_vector(TinyVector<2>{}); + } + case 3: { + return get_function_processor_for_expression_vector(TinyVector<3>{}); + } + default: { + throw parse_error("unexpected error: invalid dimension in returned type", std::vector{node.begin()}); + } + } + } case ASTNodeDataType::string_t: { return get_function_processor_for_expression_value(std::string{}); } @@ -248,6 +320,8 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node return_value_type = ASTNodeDataType::unsigned_int_t; } else if (image_domain_node.is_type<language::R_set>()) { return_value_type = ASTNodeDataType::double_t; + } else if (image_domain_node.is_type<language::vector_type>()) { + return_value_type = getVectorDataType(image_domain_node); } else if (image_domain_node.is_type<language::string_type>()) { return_value_type = ASTNodeDataType::string_t; } diff --git a/src/language/ASTNodeListAffectationExpressionBuilder.cpp b/src/language/ASTNodeListAffectationExpressionBuilder.cpp index a0bbeb88a..a39f53c6a 100644 --- a/src/language/ASTNodeListAffectationExpressionBuilder.cpp +++ b/src/language/ASTNodeListAffectationExpressionBuilder.cpp @@ -41,6 +41,16 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( } }; + auto add_affectation_processor_for_vector_data = [&](const auto& value, + const ASTNodeSubDataType& node_sub_data_type) { + using ValueT = std::decay_t<decltype(value)>; + if (node_sub_data_type.m_data_type.dimension() == value.dimension()) { + list_affectation_processor->template add<ValueT, ValueT>(value_node); + } else { + throw parse_error("invalid dimension", std::vector{node_sub_data_type.m_parent_node.begin()}); + } + }; + auto add_affectation_processor_for_string_data = [&](const ASTNodeSubDataType& node_sub_data_type) { if constexpr (std::is_same_v<OperatorT, language::eq_op> or std::is_same_v<OperatorT, language::pluseq_op>) { switch (node_sub_data_type.m_data_type) { @@ -98,6 +108,26 @@ ASTNodeListAffectationExpressionBuilder::_buildAffectationProcessor( add_affectation_processor_for_data(double{}, node_sub_data_type); break; } + case ASTNodeDataType::vector_t: { + switch (value_type.dimension()) { + case 1: { + add_affectation_processor_for_vector_data(TinyVector<1>{}, node_sub_data_type); + break; + } + case 2: { + add_affectation_processor_for_vector_data(TinyVector<2>{}, node_sub_data_type); + break; + } + case 3: { + add_affectation_processor_for_vector_data(TinyVector<3>{}, node_sub_data_type); + break; + } + default: { + throw parse_error("invalid dimension", std::vector{value_node.begin()}); + } + } + break; + } case ASTNodeDataType::string_t: { add_affectation_processor_for_string_data(node_sub_data_type); break; -- GitLab