From c841b1d89c42a114d40e878b877a369d6dd151ef Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Mon, 28 Oct 2019 14:45:42 +0100 Subject: [PATCH] Check return datatype at function declaration Also add a bunch of tests for function declaration and evaluation --- src/language/ASTNodeDataTypeBuilder.cpp | 30 ++++- src/language/ASTNodeValueBuilder.cpp | 2 +- tests/test_ASTNodeDataTypeBuilder.cpp | 158 ++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 5 deletions(-) diff --git a/src/language/ASTNodeDataTypeBuilder.cpp b/src/language/ASTNodeDataTypeBuilder.cpp index a96d31461..1857b73ab 100644 --- a/src/language/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ASTNodeDataTypeBuilder.cpp @@ -146,6 +146,31 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) this->_buildNodeDataTypes(*child); } + auto check_image_type = [&](const ASTNode& image_node) { + ASTNodeDataType value_type{ASTNodeDataType::undefined_t}; + if (image_node.is_type<language::B_set>()) { + value_type = ASTNodeDataType::bool_t; + } else if (image_node.is_type<language::Z_set>()) { + value_type = ASTNodeDataType::int_t; + } else if (image_node.is_type<language::N_set>()) { + value_type = ASTNodeDataType::unsigned_int_t; + } else if (image_node.is_type<language::R_set>()) { + value_type = ASTNodeDataType::double_t; + } + + if (value_type == ASTNodeDataType::undefined_t) { + throw parse_error("invalid value type", image_node.begin()); + } + }; + + if (image_domain_node.children.size() == 0) { + check_image_type(image_domain_node); + } else { + for (size_t i = 0; i < image_domain_node.children.size(); ++i) { + check_image_type(*image_domain_node.children[i]); + } + } + n.m_data_type = ASTNodeDataType::void_t; } else if (n.is_type<language::name>()) { std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table; @@ -245,7 +270,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) ASTNodeDataType data_type{ASTNodeDataType::undefined_t}; if (image_domain_node.children.size() > 0) { - std::ostringstream message; throw parse_error("compound data type is not implemented yet", image_domain_node.begin()); } else { if (image_domain_node.is_type<language::B_set>()) { @@ -259,9 +283,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) } } - if (data_type == ASTNodeDataType::undefined_t) { - throw parse_error("invalid return type", image_domain_node.begin()); - } + Assert(data_type != ASTNodeDataType::undefined_t); // LCOV_EXCL_LINE n.m_data_type = data_type; } else if (n.children[0]->m_data_type == ASTNodeDataType::c_function_t) { diff --git a/src/language/ASTNodeValueBuilder.cpp b/src/language/ASTNodeValueBuilder.cpp index 65d014960..d0320c4ea 100644 --- a/src/language/ASTNodeValueBuilder.cpp +++ b/src/language/ASTNodeValueBuilder.cpp @@ -63,5 +63,5 @@ ASTNodeValueBuilder::ASTNodeValueBuilder(ASTNode& node) this->_buildNodeValue(function_expression); } - std::cout << " - build node data types\n"; + std::cout << " - build node values\n"; } diff --git a/tests/test_ASTNodeDataTypeBuilder.cpp b/tests/test_ASTNodeDataTypeBuilder.cpp index 60214915d..45595e0b2 100644 --- a/tests/test_ASTNodeDataTypeBuilder.cpp +++ b/tests/test_ASTNodeDataTypeBuilder.cpp @@ -304,6 +304,164 @@ let f : N*string -> N, (i,s) -> i; REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error); } + + SECTION("invalid return type") + { + std::string_view data = R"( +let f : R -> string, x -> "foo"; +)"; + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error); + } + + SECTION("invalid return type 2") + { + std::string_view data = R"( +let f : R -> N*string, x -> (2,"foo"); +)"; + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error); + } + } + } + + SECTION("function evaluation") + { + SECTION("R-functions") + { + SECTION("single argument") + { + std::string_view data = R"( +let incr : R -> R, x -> x+1; +R x = incr(3); +)"; + + std::string_view result = R"( +(root:void) + +-(language::let_declaration:void) + | `-(language::name:incr:function) + `-(language::declaration:R) + +-(language::R_set:typename) + +-(language::name:x:R) + `-(language::function_evaluation:R) + +-(language::name:incr:function) + `-(language::integer:3:Z) +)"; + + CHECK_AST(data, result); + } + + SECTION("multiple variable") + { + std::string_view data = R"( +let substract : R*R -> R, (x,y) -> x-y; +R diff = substract(3,2); +)"; + + std::string_view result = R"( +(root:void) + +-(language::let_declaration:void) + | `-(language::name:substract:function) + `-(language::declaration:R) + +-(language::R_set:typename) + +-(language::name:diff:R) + `-(language::function_evaluation:R) + +-(language::name:substract:function) + `-(language::function_argument_list:void) + +-(language::integer:3:Z) + `-(language::integer:2:Z) +)"; + + CHECK_AST(data, result); + } + } + + SECTION("Z-functions") + { + std::string_view data = R"( +let incr : Z -> Z, z -> z+1; +Z z = incr(3); +)"; + + std::string_view result = R"( +(root:void) + +-(language::let_declaration:void) + | `-(language::name:incr:function) + `-(language::declaration:Z) + +-(language::Z_set:typename) + +-(language::name:z:Z) + `-(language::function_evaluation:Z) + +-(language::name:incr:function) + `-(language::integer:3:Z) +)"; + + CHECK_AST(data, result); + } + + SECTION("N-functions") + { + std::string_view data = R"( +let double : N -> N, n -> 2*n; +N n = double(3); +)"; + + std::string_view result = R"( +(root:void) + +-(language::let_declaration:void) + | `-(language::name:double:function) + `-(language::declaration:N) + +-(language::N_set:typename) + +-(language::name:n:N) + `-(language::function_evaluation:N) + +-(language::name:double:function) + `-(language::integer:3:Z) +)"; + + CHECK_AST(data, result); + } + + SECTION("B-functions") + { + std::string_view data = R"( +let greater_than_2 : R -> B, x -> x>2; +B b = greater_than_2(3); +)"; + + std::string_view result = R"( +(root:void) + +-(language::let_declaration:void) + | `-(language::name:greater_than_2:function) + `-(language::declaration:B) + +-(language::B_set:typename) + +-(language::name:b:B) + `-(language::function_evaluation:B) + +-(language::name:greater_than_2:function) + `-(language::integer:3:Z) +)"; + + CHECK_AST(data, result); + } + + SECTION("errors") + { + SECTION("not a function") + { + std::string_view data = R"( +R not_a_function = 3; +not_a_function(2,3); +)"; + string_input input{data, "test.pgs"}; + auto ast = ASTBuilder::build(input); + ASTSymbolTableBuilder{*ast}; + + REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error); + } } } -- GitLab