From 4e32168f5c1a3c860afb0c559547ffbbcd3e15f6 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Tue, 23 Feb 2021 11:29:20 +0100
Subject: [PATCH] Fix the case of EmbeddedData in compound returned type

---
 .../utils/BuiltinFunctionEmbedder.hpp         | 28 +++++++++++++++++--
 tests/test_BuiltinFunctionEmbedder.cpp        | 23 ++++++++-------
 2 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/src/language/utils/BuiltinFunctionEmbedder.hpp b/src/language/utils/BuiltinFunctionEmbedder.hpp
index f85325c9b..265379a36 100644
--- a/src/language/utils/BuiltinFunctionEmbedder.hpp
+++ b/src/language/utils/BuiltinFunctionEmbedder.hpp
@@ -182,6 +182,17 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder
     (_check_arg<I>(), ...);
   }
 
+  template <typename ResultT>
+  PUGS_INLINE DataVariant
+  _resultToDataVariant(ResultT&& result) const
+  {
+    if constexpr (is_data_variant_v<std::decay_t<ResultT>>) {
+      return std::move(result);
+    } else {
+      return EmbeddedData(_createHandler(std::move(result)));
+    }
+  }
+
   PUGS_INLINE
   AggregateDataVariant
   _applyToAggregate(const ArgsTuple& t) const
@@ -190,7 +201,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder
     std::vector<DataVariant> vector_result;
     vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>);
 
-    std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result);
+    std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); },
+               tuple_result);
 
     return vector_result;
   }
@@ -303,6 +315,17 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder
     return ast_node_data_type_from<T>;
   }
 
+  template <typename ResultT>
+  PUGS_INLINE DataVariant
+  _resultToDataVariant(ResultT&& result) const
+  {
+    if constexpr (is_data_variant_v<std::decay_t<ResultT>>) {
+      return std::move(result);
+    } else {
+      return EmbeddedData(_createHandler(std::move(result)));
+    }
+  }
+
   PUGS_INLINE
   AggregateDataVariant
   _applyToAggregate() const
@@ -311,7 +334,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder
     std::vector<DataVariant> vector_result;
     vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>);
 
-    std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result);
+    std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); },
+               tuple_result);
 
     return vector_result;
   }
diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp
index a0050aef2..ba88ac0b0 100644
--- a/tests/test_BuiltinFunctionEmbedder.cpp
+++ b/tests/test_BuiltinFunctionEmbedder.cpp
@@ -286,46 +286,48 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]")
     REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t);
   }
 
-  SECTION("R*R -> R*R^2 BuiltinFunctionEmbedder")
+  SECTION("R*R -> R*R^2*shared_double BuiltinFunctionEmbedder")
   {
-    std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>> {
-      return std::make_tuple(a + b, TinyVector<2>{b, -a});
+    std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>, std::shared_ptr<double>> {
+      return std::make_tuple(a + b, TinyVector<2>{b, -a}, std::make_shared<double>(a - b));
     };
 
-    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
-      std::make_unique<BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>>(double, double)>>(c);
+    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<
+      BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>, std::shared_ptr<double>>(double, double)>>(c);
 
     const double a = 3.2;
     const double b = 1.5;
 
     REQUIRE(a + b == std::get<0>(c(a, b)));
     REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b)));
+    REQUIRE(a - b == *std::get<2>(c(a, b)));
     const AggregateDataVariant value_list =
       std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b}));
 
     REQUIRE(std::get<double>(value_list[0]) == a + b);
     REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a});
-
     auto data_type = i_embedded_c->getReturnDataType();
     REQUIRE(data_type == ASTNodeDataType::list_t);
 
     REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t);
     REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2));
+    REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>);
   }
 
-  SECTION("void -> N*R BuiltinFunctionEmbedder")
+  SECTION("void -> N*R*shared_double BuiltinFunctionEmbedder")
   {
-    std::function c = [](void) -> std::tuple<uint64_t, double> {
+    std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<double>> {
       uint64_t a = 1;
       double b   = 3.5;
-      return std::make_tuple(a, b);
+      return std::make_tuple(a, b, std::make_shared<double>(a + b));
     };
 
     std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
-      std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double>(void)>>(c);
+      std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double, std::shared_ptr<double>>(void)>>(c);
 
     REQUIRE(1ul == std::get<0>(c()));
     REQUIRE(3.5 == std::get<1>(c()));
+    REQUIRE((1ul + 3.5) == *std::get<2>(c()));
     const AggregateDataVariant value_list =
       std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{}));
 
@@ -337,6 +339,7 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]")
 
     REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t);
     REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t);
+    REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>);
   }
 
   SECTION("void(void) BuiltinFunctionEmbedder")
-- 
GitLab