Skip to content
Snippets Groups Projects
Commit db24c19b authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Merge branch 'fix/function-arg-conversion' into 'develop'

Fix function argument conversion of B, N, Z and R to R^1 or R^1x1

See merge request !157
parents 30229104 e955e475
No related branches found
No related tags found
1 merge request!157Fix function argument conversion of B, N, Z and R to R^1 or R^1x1
...@@ -15,7 +15,7 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy ...@@ -15,7 +15,7 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy
{ {
const size_t parameter_id = std::get<size_t>(parameter_symbol.attributes().value()); const size_t parameter_id = std::get<size_t>(parameter_symbol.attributes().value());
ASTNodeNaturalConversionChecker{node_sub_data_type, parameter_symbol.attributes().dataType()}; ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{node_sub_data_type, parameter_symbol.attributes().dataType()};
auto get_function_argument_converter_for = auto get_function_argument_converter_for =
[&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> { [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> {
...@@ -78,13 +78,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy ...@@ -78,13 +78,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy
// LCOV_EXCL_STOP // LCOV_EXCL_STOP
} }
} }
case ASTNodeDataType::bool_t: {
if ((parameter_v.dimension() == 1)) {
return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, bool>>(parameter_id);
} else {
// LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument dimension",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::int_t: { case ASTNodeDataType::int_t: {
if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { if ((parameter_v.dimension() == 1)) {
return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, int64_t>>(parameter_id);
} else if (node_sub_data_type.m_parent_node.is_type<language::integer>()) {
if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) {
return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, ZeroType>>(parameter_id); return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, ZeroType>>(parameter_id);
} }
} }
[[fallthrough]]; // LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument type",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
case ASTNodeDataType::unsigned_int_t: {
if ((parameter_v.dimension() == 1)) {
return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, uint64_t>>(parameter_id);
} else {
// LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument dimension",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::double_t: {
if ((parameter_v.dimension() == 1)) {
return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, double>>(parameter_id);
} else {
// LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument dimension",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
} }
// LCOV_EXCL_START // LCOV_EXCL_START
default: { default: {
...@@ -110,13 +145,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy ...@@ -110,13 +145,48 @@ ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_sy
// LCOV_EXCL_STOP // LCOV_EXCL_STOP
} }
} }
case ASTNodeDataType::bool_t: {
if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) {
return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, bool>>(parameter_id);
} else {
// LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument type",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::int_t: { case ASTNodeDataType::int_t: {
if (node_sub_data_type.m_parent_node.is_type<language::integer>()) { if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) {
return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, int64_t>>(parameter_id);
} else if (node_sub_data_type.m_parent_node.is_type<language::integer>()) {
if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) { if (std::stoi(node_sub_data_type.m_parent_node.string()) == 0) {
return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(parameter_id); return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, ZeroType>>(parameter_id);
} }
} }
[[fallthrough]]; // LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument type",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
case ASTNodeDataType::unsigned_int_t: {
if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) {
return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, uint64_t>>(parameter_id);
} else {
// LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument type",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::double_t: {
if ((parameter_v.numberOfRows() == 1) and (parameter_v.numberOfColumns() == 1)) {
return std::make_unique<FunctionTinyMatrixArgumentConverter<ParameterT, double>>(parameter_id);
} else {
// LCOV_EXCL_START
throw ParseError("unexpected error: invalid argument type",
std::vector{node_sub_data_type.m_parent_node.begin()});
// LCOV_EXCL_STOP
}
} }
// LCOV_EXCL_START // LCOV_EXCL_START
default: { default: {
......
...@@ -116,11 +116,24 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver ...@@ -116,11 +116,24 @@ class FunctionTinyVectorArgumentConverter final : public IFunctionArgumentConver
value); value);
} else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero};
} else if constexpr (std::is_same_v<ExpectedValueType, TinyVector<1>>) {
if constexpr (std::is_same_v<ProvidedValueType, bool>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else if constexpr (std::is_same_v<ProvidedValueType, int64_t>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else if constexpr (std::is_same_v<ProvidedValueType, uint64_t>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else if constexpr (std::is_same_v<ProvidedValueType, double>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else { } else {
static_assert(std::is_same_v<ExpectedValueType, TinyVector<1>>); static_assert(std::is_same_v<ExpectedValueType, TinyVector<1>>);
exec_policy.currentContext()[m_argument_id] = exec_policy.currentContext()[m_argument_id] =
std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value)));
} }
} else {
throw UnexpectedError(std::string{"cannot convert '"} + demangle<ProvidedValueType>() + "' to '" +
demangle<ExpectedValueType>() + "'");
}
return {}; return {};
} }
...@@ -165,11 +178,25 @@ class FunctionTinyMatrixArgumentConverter final : public IFunctionArgumentConver ...@@ -165,11 +178,25 @@ class FunctionTinyMatrixArgumentConverter final : public IFunctionArgumentConver
value); value);
} else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) { } else if constexpr (std::is_same_v<ProvidedValueType, ZeroType>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero}; exec_policy.currentContext()[m_argument_id] = ExpectedValueType{ZeroType::zero};
} else if constexpr (std::is_same_v<ExpectedValueType, TinyMatrix<1>>) {
if constexpr (std::is_same_v<ProvidedValueType, bool>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else if constexpr (std::is_same_v<ProvidedValueType, int64_t>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else if constexpr (std::is_same_v<ProvidedValueType, uint64_t>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else if constexpr (std::is_same_v<ProvidedValueType, double>) {
exec_policy.currentContext()[m_argument_id] = ExpectedValueType(std::get<ProvidedValueType>(value));
} else { } else {
static_assert(std::is_same_v<ExpectedValueType, TinyMatrix<1>>); static_assert(std::is_same_v<ExpectedValueType, TinyMatrix<1>>);
exec_policy.currentContext()[m_argument_id] = exec_policy.currentContext()[m_argument_id] =
std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value))); std::move(static_cast<ExpectedValueType>(std::get<ProvidedValueType>(value)));
} }
} else {
throw UnexpectedError(std::string{"cannot convert '"} + demangle<ProvidedValueType>() + "' to '" +
demangle<ExpectedValueType>() + "'");
}
return {}; return {};
} }
......
...@@ -370,13 +370,30 @@ cat("foo", 2.5e-3); ...@@ -370,13 +370,30 @@ cat("foo", 2.5e-3);
let f : R^1 -> R^1, x -> x+x; let f : R^1 -> R^1, x -> x+x;
let x : R^1, x = 1; let x : R^1, x = 1;
f(x); f(x);
let n:N, n=1;
f(true);
f(n);
f(2);
f(1.4);
)"; )";
std::string_view result = R"( std::string_view result = R"(
(root:ASTNodeListProcessor) (root:ASTNodeListProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::name:x:NameProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::true_kw:ValueProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::name:n:NameProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::integer:2:ValueProcessor)
`-(language::function_evaluation:FunctionProcessor) `-(language::function_evaluation:FunctionProcessor)
+-(language::name:f:NameProcessor) +-(language::name:f:NameProcessor)
`-(language::name:x:NameProcessor) `-(language::real:1.4:ValueProcessor)
)"; )";
CHECK_AST(data, result); CHECK_AST(data, result);
...@@ -424,13 +441,30 @@ f(x); ...@@ -424,13 +441,30 @@ f(x);
let f : R^1x1 -> R^1x1, x -> x+x; let f : R^1x1 -> R^1x1, x -> x+x;
let x : R^1x1, x = 1; let x : R^1x1, x = 1;
f(x); f(x);
let n:N, n=1;
f(true);
f(n);
f(2);
f(1.4);
)"; )";
std::string_view result = R"( std::string_view result = R"(
(root:ASTNodeListProcessor) (root:ASTNodeListProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::name:x:NameProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::true_kw:ValueProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::name:n:NameProcessor)
+-(language::function_evaluation:FunctionProcessor)
| +-(language::name:f:NameProcessor)
| `-(language::integer:2:ValueProcessor)
`-(language::function_evaluation:FunctionProcessor) `-(language::function_evaluation:FunctionProcessor)
+-(language::name:f:NameProcessor) +-(language::name:f:NameProcessor)
`-(language::name:x:NameProcessor) `-(language::real:1.4:ValueProcessor)
)"; )";
CHECK_AST(data, result); CHECK_AST(data, result);
...@@ -1115,6 +1149,170 @@ prev(3 + .24); ...@@ -1115,6 +1149,170 @@ prev(3 + .24);
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> Z"}); CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> Z"});
} }
SECTION("B -> R^2")
{
std::string_view data = R"(
let f : R^2 -> R^2, x -> x;
f(true);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^2"});
}
SECTION("N -> R^2")
{
std::string_view data = R"(
let f : R^2 -> R^2, x -> x;
let n : N, n = 2;
f(n);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^2"});
}
SECTION("Z -> R^2")
{
std::string_view data = R"(
let f : R^2 -> R^2, x -> x;
f(-2);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^2"});
}
SECTION("R -> R^2")
{
std::string_view data = R"(
let f : R^2 -> R^2, x -> x;
f(1.3);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^2"});
}
SECTION("B -> R^3")
{
std::string_view data = R"(
let f : R^3 -> R^3, x -> x;
f(true);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^3"});
}
SECTION("N -> R^3")
{
std::string_view data = R"(
let f : R^3 -> R^3, x -> x;
let n : N, n = 2;
f(n);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^3"});
}
SECTION("Z -> R^3")
{
std::string_view data = R"(
let f : R^3 -> R^3, x -> x;
f(-2);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^3"});
}
SECTION("R -> R^3")
{
std::string_view data = R"(
let f : R^3 -> R^3, x -> x;
f(1.3);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^3"});
}
SECTION("B -> R^2x2")
{
std::string_view data = R"(
let f : R^2x2 -> R^2x2, x -> x;
f(true);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^2x2"});
}
SECTION("N -> R^2x2")
{
std::string_view data = R"(
let f : R^2x2 -> R^2x2, x -> x;
let n : N, n = 2;
f(n);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^2x2"});
}
SECTION("Z -> R^2x2")
{
std::string_view data = R"(
let f : R^2x2 -> R^2x2, x -> x;
f(-2);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^2x2"});
}
SECTION("R -> R^2x2")
{
std::string_view data = R"(
let f : R^2x2 -> R^2x2, x -> x;
f(1.3);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^2x2"});
}
SECTION("B -> R^3x3")
{
std::string_view data = R"(
let f : R^3x3 -> R^3x3, x -> x;
f(true);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: B -> R^3x3"});
}
SECTION("N -> R^3x3")
{
std::string_view data = R"(
let f : R^3x3 -> R^3x3, x -> x;
let n : N, n = 2;
f(n);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: N -> R^3x3"});
}
SECTION("Z -> R^3x3")
{
std::string_view data = R"(
let f : R^3x3 -> R^3x3, x -> x;
f(-2);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: Z -> R^3x3"});
}
SECTION("R -> R^3x3")
{
std::string_view data = R"(
let f : R^3x3 -> R^3x3, x -> x;
f(1.3);
)";
CHECK_EXPRESSION_BUILDER_THROWS_WITH(data, std::string{"invalid implicit conversion: R -> R^3x3"});
}
} }
SECTION("arguments invalid tuple -> R^d conversion") SECTION("arguments invalid tuple -> R^d conversion")
......
...@@ -374,6 +374,43 @@ let fx:R^1, fx = f(x); ...@@ -374,6 +374,43 @@ let fx:R^1, fx = f(x);
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{3})); CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{3}));
} }
SECTION(" R^1 -> R^1 called with B argument")
{
std::string_view data = R"(
let f : R^1 -> R^1, x -> 2*x;
let fx:R^1, fx = f(true);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{true}));
}
SECTION(" R^1 -> R^1 called with N argument")
{
std::string_view data = R"(
let f : R^1 -> R^1, x -> 2*x;
let n:N, n = 3;
let fx:R^1, fx = f(n);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{3}));
}
SECTION(" R^1 -> R^1 called with Z argument")
{
std::string_view data = R"(
let f : R^1 -> R^1, x -> 2*x;
let fx:R^1, fx = f(-2);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{-2}));
}
SECTION(" R^1 -> R^1 called with R argument")
{
std::string_view data = R"(
let f : R^1 -> R^1, x -> 2*x;
let fx:R^1, fx = f(1.3);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyVector<1>{1.3}));
}
SECTION(" R^2 -> R^2") SECTION(" R^2 -> R^2")
{ {
std::string_view data = R"( std::string_view data = R"(
...@@ -439,6 +476,43 @@ let fx:R^1x1, fx = f(x); ...@@ -439,6 +476,43 @@ let fx:R^1x1, fx = f(x);
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3})); CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3}));
} }
SECTION(" R^1x1 -> R^1x1 called with B argument")
{
std::string_view data = R"(
let f : R^1x1 -> R^1x1, x -> 2*x;
let fx:R^1x1, fx = f(true);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{true}));
}
SECTION(" R^1x1 -> R^1x1 called with N argument")
{
std::string_view data = R"(
let f : R^1x1 -> R^1x1, x -> 2*x;
let n:N, n = 3;
let fx:R^1x1, fx = f(n);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{3}));
}
SECTION(" R^1x1 -> R^1x1 called with Z argument")
{
std::string_view data = R"(
let f : R^1x1 -> R^1x1, x -> 2*x;
let fx:R^1x1, fx = f(-4);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{-4}));
}
SECTION(" R^1x1 -> R^1x1 called with R argument")
{
std::string_view data = R"(
let f : R^1x1 -> R^1x1, x -> 2*x;
let fx:R^1x1, fx = f(-2.3);
)";
CHECK_FUNCTION_EVALUATION_RESULT(data, "fx", (2 * TinyMatrix<1>{-2.3}));
}
SECTION(" R^2x2 -> R^2x2") SECTION(" R^2x2 -> R^2x2")
{ {
std::string_view data = R"( std::string_view data = R"(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment