From 1f9ac713198c035a7c158940cb14a6bd5878601a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com> Date: Fri, 27 Nov 2020 10:59:55 +0100 Subject: [PATCH] Use natural conversion checkers for R^dxd functions - deal with arguments as well as return value - add few missing tests --- .../ast/ASTNodeFunctionExpressionBuilder.cpp | 23 +- .../test_ASTNodeFunctionExpressionBuilder.cpp | 298 +++++++++++++++++- 2 files changed, 302 insertions(+), 19 deletions(-) diff --git a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp index 0d00f1c82..6484b85b0 100644 --- a/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeFunctionExpressionBuilder.cpp @@ -467,11 +467,7 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node const ASTNodeDataType return_value_type = image_domain_node.m_data_type.contentType(); - if ((return_value_type == ASTNodeDataType::vector_t) and (return_value_type.dimension() == 1)) { - ASTNodeNaturalConversionChecker{expression_node, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{expression_node, return_value_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{expression_node, return_value_type}; function_processor->addFunctionExpressionProcessor( this->_getFunctionProcessor(return_value_type, node, expression_node)); @@ -486,11 +482,8 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node if (function_image_domain.is_type<language::vector_type>()) { ASTNodeDataType vector_type = getVectorDataType(function_image_domain); - if ((vector_type.dimension() == 1) and (function_expression.m_data_type != ASTNodeDataType::vector_t)) { - ASTNodeNaturalConversionChecker{function_expression, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{function_expression, vector_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, vector_type}; + if (function_expression.is_type<language::expression_list>()) { Assert(vector_type.dimension() == function_expression.children.size()); @@ -556,12 +549,8 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node } else if (function_image_domain.is_type<language::matrix_type>()) { ASTNodeDataType matrix_type = getMatrixDataType(function_image_domain); - if ((matrix_type.nbRows() == 1) and (matrix_type.nbColumns() == 1) and - (function_expression.m_data_type != ASTNodeDataType::vector_t)) { - ASTNodeNaturalConversionChecker{function_expression, ASTNodeDataType::build<ASTNodeDataType::double_t>()}; - } else { - ASTNodeNaturalConversionChecker{function_expression, matrix_type}; - } + ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{function_expression, matrix_type}; + if (function_expression.is_type<language::expression_list>()) { Assert(matrix_type.nbRows() * matrix_type.nbColumns() == function_expression.children.size()); @@ -584,7 +573,7 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node } // LCOV_EXCL_START default: { - throw ParseError("unexpected error: invalid vector_t dimension", std::vector{node.begin()}); + throw ParseError("unexpected error: invalid matrix_t dimensions", std::vector{node.begin()}); } // LCOV_EXCL_STOP } diff --git a/tests/test_ASTNodeFunctionExpressionBuilder.cpp b/tests/test_ASTNodeFunctionExpressionBuilder.cpp index 75ba4decb..fd0163b7a 100644 --- a/tests/test_ASTNodeFunctionExpressionBuilder.cpp +++ b/tests/test_ASTNodeFunctionExpressionBuilder.cpp @@ -397,6 +397,60 @@ f(x); CHECK_AST(data, result); } + SECTION("Return R^1x1 -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> x+x; +let x : R^1x1, x = 1; +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return R^2x2 -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x+x; +let x : R^2x2, x = (1,2,3,4); +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return R^3x3 -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x+x; +let x : R^3x3, x = (1,2,3,4,5,6,7,8,9); +f(x); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::name:x:NameProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return scalar -> R^1") { std::string_view data = R"( @@ -453,6 +507,73 @@ f(1,2,3); CHECK_AST(data, result); } + SECTION("Return scalar -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> x+1; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return tuple -> R^2x2") + { + std::string_view data = R"( +let f : R*R*R*R -> R^2x2, (x,y,z,t) -> (x,y,z,t); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 2ul>) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return tuple -> R^3x3") + { + std::string_view data = R"( +let f : R^3*R^3*R^3 -> R^3x3, (x,y,z) -> (x[0],x[1],x[2],y[0],y[1],y[2],z[0],z[1],z[2]); +f((1,2,3),(4,5,6),(7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:TupleToTinyMatrixProcessor<FunctionProcessor, 3ul>) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:1:ValueProcessor) + | +-(language::integer:2:ValueProcessor) + | `-(language::integer:3:ValueProcessor) + +-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + | +-(language::integer:4:ValueProcessor) + | +-(language::integer:5:ValueProcessor) + | `-(language::integer:6:ValueProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return '0' -> R^1") { std::string_view data = R"( @@ -504,6 +625,57 @@ f(1); CHECK_AST(data, result); } + SECTION("Return '0' -> R^1x1") + { + std::string_view data = R"( +let f : R -> R^1x1, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<1ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return '0' -> R^2x2") + { + std::string_view data = R"( +let f : R -> R^2x2, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<2ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Return '0' -> R^3x3") + { + std::string_view data = R"( +let f : R -> R^3x3, x -> 0; +f(1); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionExpressionProcessor<TinyMatrix<3ul, double>, ZeroType>) + +-(language::name:f:NameProcessor) + `-(language::integer:1:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return embedded R^d compound") { std::string_view data = R"( @@ -525,6 +697,27 @@ f(1,2,3,4); CHECK_AST(data, result); } + SECTION("Return embedded R^dxd compound") + { + std::string_view data = R"( +let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, (x), (x,y,z,t), (x,y,z, x,x,x, t,t,t)); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Return embedded R^d compound with '0'") { std::string_view data = R"( @@ -546,6 +739,27 @@ f(1,2,3,4); CHECK_AST(data, result); } + SECTION("Return embedded R^dxd compound with '0'") + { + std::string_view data = R"( +let f : R*R*R*R -> R*R^1x1*R^2x2*R^3x3, (x,y,z,t) -> (t, 0, 0, (x, y, z, t, x, y, z, t, x)); +f(1,2,3,4); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::function_argument_list:ASTNodeExpressionListProcessor) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + `-(language::integer:4:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments '0' -> R^1") { std::string_view data = R"( @@ -597,6 +811,57 @@ f(0); CHECK_AST(data, result); } + SECTION("Arguments '0' -> R^1x1") + { + std::string_view data = R"( +let f : R^1x1 -> R^1x1, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Arguments '0' -> R^2x2") + { + std::string_view data = R"( +let f : R^2x2 -> R^2x2, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + + SECTION("Arguments '0' -> R^3x3") + { + std::string_view data = R"( +let f : R^3x3 -> R^3x3, x -> x; +f(0); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::integer:0:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments tuple -> R^d") { std::string_view data = R"( @@ -617,11 +882,37 @@ f((1,2,3)); CHECK_AST(data, result); } + SECTION("Arguments tuple -> R^dxd") + { + std::string_view data = R"( +let f: R^3x3 -> R, x -> x[0,0]+x[0,1]+x[0,2]; +f((1,2,3,4,5,6,7,8,9)); +)"; + + std::string_view result = R"( +(root:ASTNodeListProcessor) + `-(language::function_evaluation:FunctionProcessor) + +-(language::name:f:NameProcessor) + `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) + +-(language::integer:1:ValueProcessor) + +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::integer:4:ValueProcessor) + +-(language::integer:5:ValueProcessor) + +-(language::integer:6:ValueProcessor) + +-(language::integer:7:ValueProcessor) + +-(language::integer:8:ValueProcessor) + `-(language::integer:9:ValueProcessor) +)"; + + CHECK_AST(data, result); + } + SECTION("Arguments compound with tuple") { std::string_view data = R"( -let f: R*R^3*R^2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0]+y[1]; -f(2,(1,2,3),(2,1.3)); +let f: R*R^3*R^2x2->R, (t,x,y) -> t*(x[0]+x[1]+x[2])*y[0,0]+y[1,1]; +f(2,(1,2,3),(2,3,-1,1.3)); )"; std::string_view result = R"( @@ -636,6 +927,9 @@ f(2,(1,2,3),(2,1.3)); | `-(language::integer:3:ValueProcessor) `-(language::tuple_expression:TupleToVectorProcessor<ASTNodeExpressionListProcessor>) +-(language::integer:2:ValueProcessor) + +-(language::integer:3:ValueProcessor) + +-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>) + | `-(language::integer:1:ValueProcessor) `-(language::real:1.3:ValueProcessor) )"; -- GitLab