From 4e32168f5c1a3c860afb0c559547ffbbcd3e15f6 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Tue, 23 Feb 2021 11:29:20 +0100 Subject: [PATCH] Fix the case of EmbeddedData in compound returned type --- .../utils/BuiltinFunctionEmbedder.hpp | 28 +++++++++++++++++-- tests/test_BuiltinFunctionEmbedder.cpp | 23 ++++++++------- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/language/utils/BuiltinFunctionEmbedder.hpp b/src/language/utils/BuiltinFunctionEmbedder.hpp index f85325c9b..265379a36 100644 --- a/src/language/utils/BuiltinFunctionEmbedder.hpp +++ b/src/language/utils/BuiltinFunctionEmbedder.hpp @@ -182,6 +182,17 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder (_check_arg<I>(), ...); } + template <typename ResultT> + PUGS_INLINE DataVariant + _resultToDataVariant(ResultT&& result) const + { + if constexpr (is_data_variant_v<std::decay_t<ResultT>>) { + return std::move(result); + } else { + return EmbeddedData(_createHandler(std::move(result))); + } + } + PUGS_INLINE AggregateDataVariant _applyToAggregate(const ArgsTuple& t) const @@ -190,7 +201,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder std::vector<DataVariant> vector_result; vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); - std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); + std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); }, + tuple_result); return vector_result; } @@ -303,6 +315,17 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder return ast_node_data_type_from<T>; } + template <typename ResultT> + PUGS_INLINE DataVariant + _resultToDataVariant(ResultT&& result) const + { + if constexpr (is_data_variant_v<std::decay_t<ResultT>>) { + return std::move(result); + } else { + return EmbeddedData(_createHandler(std::move(result))); + } + } + PUGS_INLINE AggregateDataVariant _applyToAggregate() const @@ -311,7 +334,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder std::vector<DataVariant> vector_result; vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); - std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); + std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); }, + tuple_result); return vector_result; } diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index a0050aef2..ba88ac0b0 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -286,46 +286,48 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t); } - SECTION("R*R -> R*R^2 BuiltinFunctionEmbedder") + SECTION("R*R -> R*R^2*shared_double BuiltinFunctionEmbedder") { - std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>> { - return std::make_tuple(a + b, TinyVector<2>{b, -a}); + std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>, std::shared_ptr<double>> { + return std::make_tuple(a + b, TinyVector<2>{b, -a}, std::make_shared<double>(a - b)); }; - std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = - std::make_unique<BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>>(double, double)>>(c); + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique< + BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>, std::shared_ptr<double>>(double, double)>>(c); const double a = 3.2; const double b = 1.5; REQUIRE(a + b == std::get<0>(c(a, b))); REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b))); + REQUIRE(a - b == *std::get<2>(c(a, b))); const AggregateDataVariant value_list = std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b})); REQUIRE(std::get<double>(value_list[0]) == a + b); REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a}); - auto data_type = i_embedded_c->getReturnDataType(); REQUIRE(data_type == ASTNodeDataType::list_t); REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t); REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>); } - SECTION("void -> N*R BuiltinFunctionEmbedder") + SECTION("void -> N*R*shared_double BuiltinFunctionEmbedder") { - std::function c = [](void) -> std::tuple<uint64_t, double> { + std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<double>> { uint64_t a = 1; double b = 3.5; - return std::make_tuple(a, b); + return std::make_tuple(a, b, std::make_shared<double>(a + b)); }; std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = - std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double>(void)>>(c); + std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double, std::shared_ptr<double>>(void)>>(c); REQUIRE(1ul == std::get<0>(c())); REQUIRE(3.5 == std::get<1>(c())); + REQUIRE((1ul + 3.5) == *std::get<2>(c())); const AggregateDataVariant value_list = std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{})); @@ -337,6 +339,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t); + REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>); } SECTION("void(void) BuiltinFunctionEmbedder") -- GitLab