Skip to content
Snippets Groups Projects
Commit 4e32168f authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Fix the case of EmbeddedData in compound returned type

parent 612d8e35
No related branches found
No related tags found
1 merge request!76Feature/builtin functions
...@@ -182,6 +182,17 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder ...@@ -182,6 +182,17 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder
(_check_arg<I>(), ...); (_check_arg<I>(), ...);
} }
template <typename ResultT>
PUGS_INLINE DataVariant
_resultToDataVariant(ResultT&& result) const
{
if constexpr (is_data_variant_v<std::decay_t<ResultT>>) {
return std::move(result);
} else {
return EmbeddedData(_createHandler(std::move(result)));
}
}
PUGS_INLINE PUGS_INLINE
AggregateDataVariant AggregateDataVariant
_applyToAggregate(const ArgsTuple& t) const _applyToAggregate(const ArgsTuple& t) const
...@@ -190,7 +201,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder ...@@ -190,7 +201,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder
std::vector<DataVariant> vector_result; std::vector<DataVariant> vector_result;
vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>);
std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); },
tuple_result);
return vector_result; return vector_result;
} }
...@@ -303,6 +315,17 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder ...@@ -303,6 +315,17 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder
return ast_node_data_type_from<T>; return ast_node_data_type_from<T>;
} }
template <typename ResultT>
PUGS_INLINE DataVariant
_resultToDataVariant(ResultT&& result) const
{
if constexpr (is_data_variant_v<std::decay_t<ResultT>>) {
return std::move(result);
} else {
return EmbeddedData(_createHandler(std::move(result)));
}
}
PUGS_INLINE PUGS_INLINE
AggregateDataVariant AggregateDataVariant
_applyToAggregate() const _applyToAggregate() const
...@@ -311,7 +334,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder ...@@ -311,7 +334,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder
std::vector<DataVariant> vector_result; std::vector<DataVariant> vector_result;
vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>);
std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); },
tuple_result);
return vector_result; return vector_result;
} }
......
...@@ -286,46 +286,48 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") ...@@ -286,46 +286,48 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]")
REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t); REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t);
} }
SECTION("R*R -> R*R^2 BuiltinFunctionEmbedder") SECTION("R*R -> R*R^2*shared_double BuiltinFunctionEmbedder")
{ {
std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>> { std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>, std::shared_ptr<double>> {
return std::make_tuple(a + b, TinyVector<2>{b, -a}); return std::make_tuple(a + b, TinyVector<2>{b, -a}, std::make_shared<double>(a - b));
}; };
std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<
std::make_unique<BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>>(double, double)>>(c); BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>, std::shared_ptr<double>>(double, double)>>(c);
const double a = 3.2; const double a = 3.2;
const double b = 1.5; const double b = 1.5;
REQUIRE(a + b == std::get<0>(c(a, b))); REQUIRE(a + b == std::get<0>(c(a, b)));
REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b))); REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b)));
REQUIRE(a - b == *std::get<2>(c(a, b)));
const AggregateDataVariant value_list = const AggregateDataVariant value_list =
std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b})); std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b}));
REQUIRE(std::get<double>(value_list[0]) == a + b); REQUIRE(std::get<double>(value_list[0]) == a + b);
REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a}); REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a});
auto data_type = i_embedded_c->getReturnDataType(); auto data_type = i_embedded_c->getReturnDataType();
REQUIRE(data_type == ASTNodeDataType::list_t); REQUIRE(data_type == ASTNodeDataType::list_t);
REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t); REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t);
REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2));
REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>);
} }
SECTION("void -> N*R BuiltinFunctionEmbedder") SECTION("void -> N*R*shared_double BuiltinFunctionEmbedder")
{ {
std::function c = [](void) -> std::tuple<uint64_t, double> { std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<double>> {
uint64_t a = 1; uint64_t a = 1;
double b = 3.5; double b = 3.5;
return std::make_tuple(a, b); return std::make_tuple(a, b, std::make_shared<double>(a + b));
}; };
std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double>(void)>>(c); std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double, std::shared_ptr<double>>(void)>>(c);
REQUIRE(1ul == std::get<0>(c())); REQUIRE(1ul == std::get<0>(c()));
REQUIRE(3.5 == std::get<1>(c())); REQUIRE(3.5 == std::get<1>(c()));
REQUIRE((1ul + 3.5) == *std::get<2>(c()));
const AggregateDataVariant value_list = const AggregateDataVariant value_list =
std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{})); std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{}));
...@@ -337,6 +339,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") ...@@ -337,6 +339,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]")
REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t);
REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t); REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t);
REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>);
} }
SECTION("void(void) BuiltinFunctionEmbedder") SECTION("void(void) BuiltinFunctionEmbedder")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment