diff --git a/src/language/utils/BuiltinFunctionEmbedder.hpp b/src/language/utils/BuiltinFunctionEmbedder.hpp index f85325c9b966f72443312258aec7973236c1f186..265379a368ef3b7f847cd8de2b763a70b525f8a4 100644 --- a/src/language/utils/BuiltinFunctionEmbedder.hpp +++ b/src/language/utils/BuiltinFunctionEmbedder.hpp @@ -182,6 +182,17 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder (_check_arg<I>(), ...); } + template <typename ResultT> + PUGS_INLINE DataVariant + _resultToDataVariant(ResultT&& result) const + { + if constexpr (is_data_variant_v<std::decay_t<ResultT>>) { + return std::move(result); + } else { + return EmbeddedData(_createHandler(std::move(result))); + } + } + PUGS_INLINE AggregateDataVariant _applyToAggregate(const ArgsTuple& t) const @@ -190,7 +201,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder std::vector<DataVariant> vector_result; vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); - std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); + std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); }, + tuple_result); return vector_result; } @@ -303,6 +315,17 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder return ast_node_data_type_from<T>; } + template <typename ResultT> + PUGS_INLINE DataVariant + _resultToDataVariant(ResultT&& result) const + { + if constexpr (is_data_variant_v<std::decay_t<ResultT>>) { + return std::move(result); + } else { + return EmbeddedData(_createHandler(std::move(result))); + } + } + PUGS_INLINE AggregateDataVariant _applyToAggregate() const @@ -311,7 +334,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder std::vector<DataVariant> vector_result; vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); - std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); + std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); }, + tuple_result); return vector_result; } diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index a0050aef2064e466ce550e9dae70bb1c533ef465..ba88ac0b0a0e3d6e7c3a16977e03437a1e14a05c 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -286,46 +286,48 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t); } - SECTION("R*R -> R*R^2 BuiltinFunctionEmbedder") + SECTION("R*R -> R*R^2*shared_double BuiltinFunctionEmbedder") { - std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>> { - return std::make_tuple(a + b, TinyVector<2>{b, -a}); + std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>, std::shared_ptr<double>> { + return std::make_tuple(a + b, TinyVector<2>{b, -a}, std::make_shared<double>(a - b)); }; - std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = - std::make_unique<BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>>(double, double)>>(c); + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique< + BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>, std::shared_ptr<double>>(double, double)>>(c); const double a = 3.2; const double b = 1.5; REQUIRE(a + b == std::get<0>(c(a, b))); REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b))); + REQUIRE(a - b == *std::get<2>(c(a, b))); const AggregateDataVariant value_list = std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b})); REQUIRE(std::get<double>(value_list[0]) == a + b); REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a}); - auto data_type = i_embedded_c->getReturnDataType(); REQUIRE(data_type == ASTNodeDataType::list_t); REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t); REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>); } - SECTION("void -> N*R BuiltinFunctionEmbedder") + SECTION("void -> N*R*shared_double BuiltinFunctionEmbedder") { - std::function c = [](void) -> std::tuple<uint64_t, double> { + std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<double>> { uint64_t a = 1; double b = 3.5; - return std::make_tuple(a, b); + return std::make_tuple(a, b, std::make_shared<double>(a + b)); }; std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = - std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double>(void)>>(c); + std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double, std::shared_ptr<double>>(void)>>(c); REQUIRE(1ul == std::get<0>(c())); REQUIRE(3.5 == std::get<1>(c())); + REQUIRE((1ul + 3.5) == *std::get<2>(c())); const AggregateDataVariant value_list = std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{})); @@ -337,6 +339,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t); + REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>); } SECTION("void(void) BuiltinFunctionEmbedder")