From 1f9ac713198c035a7c158940cb14a6bd5878601a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Fri, 27 Nov 2020 10:59:55 +0100
Subject: [PATCH] Use natural conversion checkers for R^dxd functions

- deal with arguments as well as return value
- add few missing tests
---
 .../ast/ASTNodeFunctionExpressionBuilder.cpp  |  23 +-
 .../test_ASTNodeFunctionExpressionBuilder.cpp | 298 +++++++++++++++++-
 2 files changed, 302 insertions(+), 19 deletions(-)

diff --git a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp
index 0d00f1c82..6484b85b0 100644
--- a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp
+++ b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp
@@ -467,11 +467,7 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
 
     const ASTNodeDataType return_value_type = image_domain_node.m_data_type.contentType();
 
-    if ((return_value_type == ASTNodeDataType::vector_t) and (return_value_type.dimension() == 1)) {
-      ASTNodeNaturalConversionChecker{expression_node, ASTNodeDataType::build<ASTNodeDataType::double_t>()};
-    } else {
-      ASTNodeNaturalConversionChecker{expression_node, return_value_type};
-    }
+    ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{expression_node, return_value_type};
 
     function_processor->addFunctionExpressionProcessor(
       this->_getFunctionProcessor(return_value_type, node, expression_node));
@@ -486,11 +482,8 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
   if (function_image_domain.is_type<language::vector_type>()) {
     ASTNodeDataType vector_type = getVectorDataType(function_image_domain);
 
-    if ((vector_type.dimension() == 1) and (function_expression.m_data_type != ASTNodeDataType::vector_t)) {
-      ASTNodeNaturalConversionChecker{function_expression, ASTNodeDataType::build<ASTNodeDataType::double_t>()};
-    } else {
-      ASTNodeNaturalConversionChecker{function_expression, vector_type};
-    }
+    ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, vector_type};
+
     if (function_expression.is_type<language::expression_list>()) {
       Assert(vector_type.dimension() == function_expression.children.size());
 
@@ -556,12 +549,8 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
   } else if (function_image_domain.is_type<language::matrix_type>()) {
     ASTNodeDataType matrix_type = getMatrixDataType(function_image_domain);
 
-    if ((matrix_type.nbRows() == 1) and (matrix_type.nbColumns() == 1) and
-        (function_expression.m_data_type != ASTNodeDataType::vector_t)) {
-      ASTNodeNaturalConversionChecker{function_expression, ASTNodeDataType::build<ASTNodeDataType::double_t>()};
-    } else {
-      ASTNodeNaturalConversionChecker{function_expression, matrix_type};
-    }
+    ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, matrix_type};
+
     if (function_expression.is_type<language::expression_list>()) {
       Assert(matrix_type.nbRows() * matrix_type.nbColumns() == function_expression.children.size());
 
@@ -584,7 +573,7 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
       }
         // LCOV_EXCL_START
       default: {
-        throw ParseError("unexpected error: invalid vector_t dimension", std::vector{node.begin()});
+        throw ParseError("unexpected error: invalid matrix_t dimensions", std::vector{node.begin()});
       }
         // LCOV_EXCL_STOP
       }
diff --git a/tests/test_ASTNodeFunctionExpressionBuilder.cpp b/tests/test_ASTNodeFunctionExpressionBuilder.cpp
index 75ba4decb..fd0163b7a 100644
--- a/tests/test_ASTNodeFunctionExpressionBuilder.cpp
+++ b/tests/test_ASTNodeFunctionExpressionBuilder.cpp
@@ -397,6 +397,60 @@ f(x);
       CHECK_AST(data, result);
     }
 
+    SECTION("Return R^1x1 -> R^1x1")
+    {
+      std::string_view data = R"(
+let f : R^1x1 -> R^1x1, x -> x+x;
+let x : R^1x1, x = 1;
+f(x);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::name:x:NameProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Return R^2x2 -> R^2x2")
+    {
+      std::string_view data = R"(
+let f : R^2x2 -> R^2x2, x -> x+x;
+let x : R^2x2, x = (1,2,3,4);
+f(x);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::name:x:NameProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Return R^3x3 -> R^3x3")
+    {
+      std::string_view data = R"(
+let f : R^3x3 -> R^3x3, x -> x+x;
+let x : R^3x3, x = (1,2,3,4,5,6,7,8,9);
+f(x);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::name:x:NameProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
     SECTION("Return scalar -> R^1")
     {
       std::string_view data = R"(
@@ -453,6 +507,73 @@ f(1,2,3);
       CHECK_AST(data, result);
     }
 
+    SECTION("Return scalar -> R^1x1")
+    {
+      std::string_view data = R"(
+let f : R -> R^1x1, x -> x+1;
+f(1);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::integer:1:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Return tuple -> R^2x2")
+    {
+      std::string_view data = R"(
+let f : R*R*R*R -> R^2x2, (x,y,z,t) -> (x,y,z,t);
+f(1,2,3,4);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 2ul>)
+     +-(language::name:f:NameProcessor)
+     `-(language::function_argument_list:ASTNodeExpressionListProcessor)
+         +-(language::integer:1:ValueProcessor)
+         +-(language::integer:2:ValueProcessor)
+         +-(language::integer:3:ValueProcessor)
+         `-(language::integer:4:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Return tuple -> R^3x3")
+    {
+      std::string_view data = R"(
+let f : R^3*R^3*R^3 -> R^3x3, (x,y,z) -> (x[0],x[1],x[2],y[0],y[1],y[2],z[0],z[1],z[2]);
+f((1,2,3),(4,5,6),(7,8,9));
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 3ul>)
+     +-(language::name:f:NameProcessor)
+     `-(language::function_argument_list:ASTNodeExpressionListProcessor)
+         +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>)
+         |   +-(language::integer:1:ValueProcessor)
+         |   +-(language::integer:2:ValueProcessor)
+         |   `-(language::integer:3:ValueProcessor)
+         +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>)
+         |   +-(language::integer:4:ValueProcessor)
+         |   +-(language::integer:5:ValueProcessor)
+         |   `-(language::integer:6:ValueProcessor)
+         `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>)
+             +-(language::integer:7:ValueProcessor)
+             +-(language::integer:8:ValueProcessor)
+             `-(language::integer:9:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
     SECTION("Return '0' -> R^1")
     {
       std::string_view data = R"(
@@ -504,6 +625,57 @@ f(1);
       CHECK_AST(data, result);
     }
 
+    SECTION("Return '0' -> R^1x1")
+    {
+      std::string_view data = R"(
+let f : R -> R^1x1, x -> 0;
+f(1);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<1ul, double>, ZeroType>)
+     +-(language::name:f:NameProcessor)
+     `-(language::integer:1:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Return '0' -> R^2x2")
+    {
+      std::string_view data = R"(
+let f : R -> R^2x2, x -> 0;
+f(1);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<2ul, double>, ZeroType>)
+     +-(language::name:f:NameProcessor)
+     `-(language::integer:1:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Return '0' -> R^3x3")
+    {
+      std::string_view data = R"(
+let f : R -> R^3x3, x -> 0;
+f(1);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<3ul, double>, ZeroType>)
+     +-(language::name:f:NameProcessor)
+     `-(language::integer:1:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
     SECTION("Return embedded R^d compound")
     {
       std::string_view data = R"(
@@ -525,6 +697,27 @@ f(1,2,3,4);
       CHECK_AST(data, result);
     }
 
+    SECTION("Return embedded R^dxd compound")
+    {
+      std::string_view data = R"(
+let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, (x), (x,y,z,t), (x,y,z, x,x,x, t,t,t));
+f(1,2,3,4);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::function_argument_list:ASTNodeExpressionListProcessor)
+         +-(language::integer:1:ValueProcessor)
+         +-(language::integer:2:ValueProcessor)
+         +-(language::integer:3:ValueProcessor)
+         `-(language::integer:4:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
     SECTION("Return embedded R^d compound with '0'")
     {
       std::string_view data = R"(
@@ -546,6 +739,27 @@ f(1,2,3,4);
       CHECK_AST(data, result);
     }
 
+    SECTION("Return embedded R^dxd compound with '0'")
+    {
+      std::string_view data = R"(
+let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, 0, 0, (x, y, z, t, x, y, z, t, x));
+f(1,2,3,4);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::function_argument_list:ASTNodeExpressionListProcessor)
+         +-(language::integer:1:ValueProcessor)
+         +-(language::integer:2:ValueProcessor)
+         +-(language::integer:3:ValueProcessor)
+         `-(language::integer:4:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
     SECTION("Arguments '0' -> R^1")
     {
       std::string_view data = R"(
@@ -597,6 +811,57 @@ f(0);
       CHECK_AST(data, result);
     }
 
+    SECTION("Arguments '0' -> R^1x1")
+    {
+      std::string_view data = R"(
+let f : R^1x1 -> R^1x1, x -> x;
+f(0);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::integer:0:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Arguments '0' -> R^2x2")
+    {
+      std::string_view data = R"(
+let f : R^2x2 -> R^2x2, x -> x;
+f(0);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::integer:0:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
+    SECTION("Arguments '0' -> R^3x3")
+    {
+      std::string_view data = R"(
+let f : R^3x3 -> R^3x3, x -> x;
+f(0);
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::integer:0:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
     SECTION("Arguments tuple -> R^d")
     {
       std::string_view data = R"(
@@ -617,11 +882,37 @@ f((1,2,3));
       CHECK_AST(data, result);
     }
 
+    SECTION("Arguments tuple -> R^dxd")
+    {
+      std::string_view data = R"(
+let f: R^3x3 -> R, x -> x[0,0]+x[0,1]+x[0,2];
+f((1,2,3,4,5,6,7,8,9));
+)";
+
+      std::string_view result = R"(
+(root:ASTNodeListProcessor)
+ `-(language::function_evaluation:FunctionProcessor)
+     +-(language::name:f:NameProcessor)
+     `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>)
+         +-(language::integer:1:ValueProcessor)
+         +-(language::integer:2:ValueProcessor)
+         +-(language::integer:3:ValueProcessor)
+         +-(language::integer:4:ValueProcessor)
+         +-(language::integer:5:ValueProcessor)
+         +-(language::integer:6:ValueProcessor)
+         +-(language::integer:7:ValueProcessor)
+         +-(language::integer:8:ValueProcessor)
+         `-(language::integer:9:ValueProcessor)
+)";
+
+      CHECK_AST(data, result);
+    }
+
     SECTION("Arguments compound with tuple")
     {
       std::string_view data = R"(
-let f: R*R^3*R^2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0]+y[1];
-f(2,(1,2,3),(2,1.3));
+let f: R*R^3*R^2x2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0,0]+y[1,1];
+f(2,(1,2,3),(2,3,-1,1.3));
 )";
 
       std::string_view result = R"(
@@ -636,6 +927,9 @@ f(2,(1,2,3),(2,1.3));
          |   `-(language::integer:3:ValueProcessor)
          `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>)
              +-(language::integer:2:ValueProcessor)
+             +-(language::integer:3:ValueProcessor)
+             +-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>)
+             |   `-(language::integer:1:ValueProcessor)
              `-(language::real:1.3:ValueProcessor)
 )";
 
-- 
GitLab