From c841b1d89c42a114d40e878b877a369d6dd151ef Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Mon, 28 Oct 2019 14:45:42 +0100
Subject: [PATCH] Check return datatype at function declaration

Also add a bunch of tests for function declaration and evaluation
---
 src/language/ASTNodeDataTypeBuilder.cpp |  30 ++++-
 src/language/ASTNodeValueBuilder.cpp    |   2 +-
 tests/test_ASTNodeDataTypeBuilder.cpp   | 158 ++++++++++++++++++++++++
 3 files changed, 185 insertions(+), 5 deletions(-)

diff --git a/src/language/ASTNodeDataTypeBuilder.cpp b/src/language/ASTNodeDataTypeBuilder.cpp
index a96d31461..1857b73ab 100644
--- a/src/language/ASTNodeDataTypeBuilder.cpp
+++ b/src/language/ASTNodeDataTypeBuilder.cpp
@@ -146,6 +146,31 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n)
           this->_buildNodeDataTypes(*child);
         }
 
+        auto check_image_type = [&](const ASTNode& image_node) {
+          ASTNodeDataType value_type{ASTNodeDataType::undefined_t};
+          if (image_node.is_type<language::B_set>()) {
+            value_type = ASTNodeDataType::bool_t;
+          } else if (image_node.is_type<language::Z_set>()) {
+            value_type = ASTNodeDataType::int_t;
+          } else if (image_node.is_type<language::N_set>()) {
+            value_type = ASTNodeDataType::unsigned_int_t;
+          } else if (image_node.is_type<language::R_set>()) {
+            value_type = ASTNodeDataType::double_t;
+          }
+
+          if (value_type == ASTNodeDataType::undefined_t) {
+            throw parse_error("invalid value type", image_node.begin());
+          }
+        };
+
+        if (image_domain_node.children.size() == 0) {
+          check_image_type(image_domain_node);
+        } else {
+          for (size_t i = 0; i < image_domain_node.children.size(); ++i) {
+            check_image_type(*image_domain_node.children[i]);
+          }
+        }
+
         n.m_data_type = ASTNodeDataType::void_t;
       } else if (n.is_type<language::name>()) {
         std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;
@@ -245,7 +270,6 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n)
 
         ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
         if (image_domain_node.children.size() > 0) {
-          std::ostringstream message;
           throw parse_error("compound data type is not implemented yet", image_domain_node.begin());
         } else {
           if (image_domain_node.is_type<language::B_set>()) {
@@ -259,9 +283,7 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n)
           }
         }
 
-        if (data_type == ASTNodeDataType::undefined_t) {
-          throw parse_error("invalid return type", image_domain_node.begin());
-        }
+        Assert(data_type != ASTNodeDataType::undefined_t);   // LCOV_EXCL_LINE
 
         n.m_data_type = data_type;
       } else if (n.children[0]->m_data_type == ASTNodeDataType::c_function_t) {
diff --git a/src/language/ASTNodeValueBuilder.cpp b/src/language/ASTNodeValueBuilder.cpp
index 65d014960..d0320c4ea 100644
--- a/src/language/ASTNodeValueBuilder.cpp
+++ b/src/language/ASTNodeValueBuilder.cpp
@@ -63,5 +63,5 @@ ASTNodeValueBuilder::ASTNodeValueBuilder(ASTNode& node)
     this->_buildNodeValue(function_expression);
   }
 
-  std::cout << " - build node data types\n";
+  std::cout << " - build node values\n";
 }
diff --git a/tests/test_ASTNodeDataTypeBuilder.cpp b/tests/test_ASTNodeDataTypeBuilder.cpp
index 60214915d..45595e0b2 100644
--- a/tests/test_ASTNodeDataTypeBuilder.cpp
+++ b/tests/test_ASTNodeDataTypeBuilder.cpp
@@ -304,6 +304,164 @@ let f : N*string -> N, (i,s) -> i;
 
         REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error);
       }
+
+      SECTION("invalid return type")
+      {
+        std::string_view data = R"(
+let f : R -> string, x -> "foo";
+)";
+        string_input input{data, "test.pgs"};
+        auto ast = ASTBuilder::build(input);
+        ASTSymbolTableBuilder{*ast};
+
+        REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error);
+      }
+
+      SECTION("invalid return type 2")
+      {
+        std::string_view data = R"(
+let f : R -> N*string, x -> (2,"foo");
+)";
+        string_input input{data, "test.pgs"};
+        auto ast = ASTBuilder::build(input);
+        ASTSymbolTableBuilder{*ast};
+
+        REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error);
+      }
+    }
+  }
+
+  SECTION("function evaluation")
+  {
+    SECTION("R-functions")
+    {
+      SECTION("single argument")
+      {
+        std::string_view data = R"(
+let incr : R -> R, x -> x+1;
+R x = incr(3);
+)";
+
+        std::string_view result = R"(
+(root:void)
+ +-(language::let_declaration:void)
+ |   `-(language::name:incr:function)
+ `-(language::declaration:R)
+     +-(language::R_set:typename)
+     +-(language::name:x:R)
+     `-(language::function_evaluation:R)
+         +-(language::name:incr:function)
+         `-(language::integer:3:Z)
+)";
+
+        CHECK_AST(data, result);
+      }
+
+      SECTION("multiple variable")
+      {
+        std::string_view data = R"(
+let substract : R*R -> R, (x,y) -> x-y;
+R diff = substract(3,2);
+)";
+
+        std::string_view result = R"(
+(root:void)
+ +-(language::let_declaration:void)
+ |   `-(language::name:substract:function)
+ `-(language::declaration:R)
+     +-(language::R_set:typename)
+     +-(language::name:diff:R)
+     `-(language::function_evaluation:R)
+         +-(language::name:substract:function)
+         `-(language::function_argument_list:void)
+             +-(language::integer:3:Z)
+             `-(language::integer:2:Z)
+)";
+
+        CHECK_AST(data, result);
+      }
+    }
+
+    SECTION("Z-functions")
+    {
+      std::string_view data = R"(
+let incr : Z -> Z, z -> z+1;
+Z z = incr(3);
+)";
+
+      std::string_view result = R"(
+(root:void)
+ +-(language::let_declaration:void)
+ |   `-(language::name:incr:function)
+ `-(language::declaration:Z)
+     +-(language::Z_set:typename)
+     +-(language::name:z:Z)
+     `-(language::function_evaluation:Z)
+         +-(language::name:incr:function)
+         `-(language::integer:3:Z)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("N-functions")
+    {
+      std::string_view data = R"(
+let double : N -> N, n -> 2*n;
+N n = double(3);
+)";
+
+      std::string_view result = R"(
+(root:void)
+ +-(language::let_declaration:void)
+ |   `-(language::name:double:function)
+ `-(language::declaration:N)
+     +-(language::N_set:typename)
+     +-(language::name:n:N)
+     `-(language::function_evaluation:N)
+         +-(language::name:double:function)
+         `-(language::integer:3:Z)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("B-functions")
+    {
+      std::string_view data = R"(
+let greater_than_2 : R -> B, x -> x>2;
+B b = greater_than_2(3);
+)";
+
+      std::string_view result = R"(
+(root:void)
+ +-(language::let_declaration:void)
+ |   `-(language::name:greater_than_2:function)
+ `-(language::declaration:B)
+     +-(language::B_set:typename)
+     +-(language::name:b:B)
+     `-(language::function_evaluation:B)
+         +-(language::name:greater_than_2:function)
+         `-(language::integer:3:Z)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("errors")
+    {
+      SECTION("not a function")
+      {
+        std::string_view data = R"(
+R not_a_function = 3;
+not_a_function(2,3);
+)";
+        string_input input{data, "test.pgs"};
+        auto ast = ASTBuilder::build(input);
+        ASTSymbolTableBuilder{*ast};
+
+        REQUIRE_THROWS_AS(ASTNodeDataTypeBuilder{*ast}, parse_error);
+      }
     }
   }
 
-- 
GitLab