diff --git a/src/language/ast/ASTNodeDataTypeFlattener.cpp b/src/language/ast/ASTNodeDataTypeFlattener.cpp index 51d055836393b1ee680b7d983756bee39020d992..6a06c72bb66926af546ff0a130315eacc8b54c9a 100644 --- a/src/language/ast/ASTNodeDataTypeFlattener.cpp +++ b/src/language/ast/ASTNodeDataTypeFlattener.cpp @@ -1,6 +1,7 @@ #include <language/ast/ASTNodeDataTypeFlattener.hpp> #include <language/PEGGrammar.hpp> +#include <language/utils/BuiltinFunctionEmbedder.hpp> #include <language/utils/FunctionTable.hpp> #include <language/utils/SymbolTable.hpp> @@ -33,9 +34,20 @@ ASTNodeDataTypeFlattener::ASTNodeDataTypeFlattener(ASTNode& node, FlattenedDataT } break; } + case ASTNodeDataType::builtin_function_t: { + uint64_t builtin_function_id = std::get<uint64_t>(i_function_symbol->attributes().value()); + auto builtin_function_embedder = node.m_symbol_table->builtinFunctionEmbedderTable()[builtin_function_id]; + + const auto& compound_data_type = builtin_function_embedder->getReturnDataType(); + for (auto data_type : compound_data_type.contentTypeList()) { + flattened_datatype_list.push_back({*data_type, node}); + } + + break; + } // LCOV_EXCL_START default: { - throw ParseError("unexpected function type", node.begin()); + throw ParseError{"unexpected function type", node.begin()}; } // LCOV_EXCL_STOP } diff --git a/src/language/utils/BuiltinFunctionEmbedder.hpp b/src/language/utils/BuiltinFunctionEmbedder.hpp index afb991acadcb36236bcd55a34c7840ca399ab6f0..265379a368ef3b7f847cd8de2b763a70b525f8a4 100644 --- a/src/language/utils/BuiltinFunctionEmbedder.hpp +++ b/src/language/utils/BuiltinFunctionEmbedder.hpp @@ -130,23 +130,30 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder return ast_node_data_type_from<T>; } - template <size_t I> + template <typename TupleT, size_t I> PUGS_INLINE ASTNodeDataType - _getOneParameterDataType(ArgsTuple& t) const + _getOneParameterDataType() const { - using ArgN_T = std::decay_t<decltype(std::get<I>(t))>; + using ArgN_T = std::decay_t<decltype(std::get<I>(TupleT{}))>; return _getDataType<ArgN_T>(); } template <size_t... I> - PUGS_INLINE std::vector<ASTNodeDataType> - _getParameterDataTypes(ArgsTuple t, std::index_sequence<I...>) const + PUGS_INLINE std::vector<ASTNodeDataType> _getParameterDataTypes(std::index_sequence<I...>) const { std::vector<ASTNodeDataType> parameter_type_list; - (parameter_type_list.push_back(this->_getOneParameterDataType<I>(t)), ...); + (parameter_type_list.push_back(this->_getOneParameterDataType<ArgsTuple, I>()), ...); return parameter_type_list; } + template <size_t... I> + PUGS_INLINE std::vector<std::shared_ptr<const ASTNodeDataType>> _getCompoundDataTypes(std::index_sequence<I...>) const + { + std::vector<std::shared_ptr<const ASTNodeDataType>> compound_type_list; + (compound_type_list.push_back(std::make_shared<ASTNodeDataType>(this->_getOneParameterDataType<FX, I>())), ...); + return compound_type_list; + } + template <typename T> PUGS_INLINE std::shared_ptr<IDataHandler> _createHandler(std::shared_ptr<T> data) const @@ -175,21 +182,52 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder (_check_arg<I>(), ...); } + template <typename ResultT> + PUGS_INLINE DataVariant + _resultToDataVariant(ResultT&& result) const + { + if constexpr (is_data_variant_v<std::decay_t<ResultT>>) { + return std::move(result); + } else { + return EmbeddedData(_createHandler(std::move(result))); + } + } + + PUGS_INLINE + AggregateDataVariant + _applyToAggregate(const ArgsTuple& t) const + { + auto tuple_result = std::apply(m_f, t); + std::vector<DataVariant> vector_result; + vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); + + std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); }, + tuple_result); + + return vector_result; + } + public: PUGS_INLINE ASTNodeDataType getReturnDataType() const final { - return _getDataType<FX>(); + if constexpr (is_std_tuple_v<FX>) { + constexpr size_t N = std::tuple_size_v<FX>; + using IndexSequence = std::make_index_sequence<N>; + + return ASTNodeDataType::build<ASTNodeDataType::list_t>(this->_getCompoundDataTypes(IndexSequence{})); + } else { + return this->_getDataType<FX>(); + } } PUGS_INLINE std::vector<ASTNodeDataType> getParameterDataTypes() const final { - constexpr size_t N = std::tuple_size_v<ArgsTuple>; - ArgsTuple t; + constexpr size_t N = std::tuple_size_v<ArgsTuple>; using IndexSequence = std::make_index_sequence<N>; - return this->_getParameterDataTypes(t, IndexSequence{}); + return this->_getParameterDataTypes(IndexSequence{}); } PUGS_INLINE size_t @@ -198,7 +236,6 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder return sizeof...(Args); } - public: PUGS_INLINE DataVariant apply(const std::vector<DataVariant>& x) const final @@ -210,6 +247,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder this->_copyFromVector(t, x, IndexSequence{}); if constexpr (is_data_variant_v<FX>) { return {std::apply(m_f, t)}; + } else if constexpr (is_std_tuple_v<FX>) { + return this->_applyToAggregate(t); } else if constexpr (std::is_same_v<FX, void>) { std::apply(m_f, t); return {}; @@ -245,6 +284,22 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder private: std::function<FX(void)> m_f; + template <typename TupleT, size_t I> + PUGS_INLINE ASTNodeDataType + _getOneParameterDataType() const + { + using ArgN_T = std::decay_t<decltype(std::get<I>(TupleT{}))>; + return _getDataType<ArgN_T>(); + } + + template <size_t... I> + PUGS_INLINE std::vector<std::shared_ptr<const ASTNodeDataType>> _getCompoundDataTypes(std::index_sequence<I...>) const + { + std::vector<std::shared_ptr<const ASTNodeDataType>> compound_type_list; + (compound_type_list.push_back(std::make_shared<ASTNodeDataType>(this->_getOneParameterDataType<FX, I>())), ...); + return compound_type_list; + } + template <typename T> PUGS_INLINE std::shared_ptr<IDataHandler> _createHandler(std::shared_ptr<T> data) const @@ -260,11 +315,42 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder return ast_node_data_type_from<T>; } + template <typename ResultT> + PUGS_INLINE DataVariant + _resultToDataVariant(ResultT&& result) const + { + if constexpr (is_data_variant_v<std::decay_t<ResultT>>) { + return std::move(result); + } else { + return EmbeddedData(_createHandler(std::move(result))); + } + } + + PUGS_INLINE + AggregateDataVariant + _applyToAggregate() const + { + auto tuple_result = m_f(); + std::vector<DataVariant> vector_result; + vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); + + std::apply([&](auto&&... result) { ((vector_result.emplace_back(_resultToDataVariant(result))), ...); }, + tuple_result); + + return vector_result; + } + public: PUGS_INLINE ASTNodeDataType getReturnDataType() const final { - return this->_getDataType<FX>(); + if constexpr (is_std_tuple_v<FX>) { + constexpr size_t N = std::tuple_size_v<FX>; + using IndexSequence = std::make_index_sequence<N>; + return ASTNodeDataType::build<ASTNodeDataType::list_t>(this->_getCompoundDataTypes(IndexSequence{})); + } else { + return this->_getDataType<FX>(); + } } PUGS_INLINE std::vector<ASTNodeDataType> @@ -285,6 +371,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder { if constexpr (is_data_variant_v<FX>) { return {m_f()}; + } else if constexpr (is_std_tuple_v<FX>) { + return this->_applyToAggregate(); } else if constexpr (std::is_same_v<FX, void>) { m_f(); return {}; diff --git a/src/utils/PugsTraits.hpp b/src/utils/PugsTraits.hpp index f51b455cb5ea4efa678dd980d1fec2721ab1fd2a..e1dc8b48d6b15fa05ac9b9b9378385e351fadb2b 100644 --- a/src/utils/PugsTraits.hpp +++ b/src/utils/PugsTraits.hpp @@ -3,6 +3,7 @@ #include <cstddef> #include <memory> +#include <tuple> #include <type_traits> #include <variant> #include <vector> @@ -40,7 +41,7 @@ inline constexpr bool is_shared_ptr_v = false; template <typename T> inline constexpr bool is_shared_ptr_v<std::shared_ptr<T>> = true; -// Traits is_shared_ptr +// Traits is_unique_ptr template <typename T> inline constexpr bool is_unique_ptr_v = false; @@ -76,6 +77,14 @@ inline constexpr bool is_std_vector_v = false; template <typename T> inline constexpr bool is_std_vector_v<std::vector<T>> = true; +// Traits is_std_tuple + +template <typename... T> +inline constexpr bool is_std_tuple_v = false; + +template <typename... T> +inline constexpr bool is_std_tuple_v<std::tuple<T...>> = true; + // Traits is_tiny_vector template <typename T> diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index 2651c0e94659c1bbffc0bd47e6ab946cce18e3c9..ba88ac0b0a0e3d6e7c3a16977e03437a1e14a05c 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -286,6 +286,62 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t); } + SECTION("R*R -> R*R^2*shared_double BuiltinFunctionEmbedder") + { + std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>, std::shared_ptr<double>> { + return std::make_tuple(a + b, TinyVector<2>{b, -a}, std::make_shared<double>(a - b)); + }; + + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique< + BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>, std::shared_ptr<double>>(double, double)>>(c); + + const double a = 3.2; + const double b = 1.5; + + REQUIRE(a + b == std::get<0>(c(a, b))); + REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b))); + REQUIRE(a - b == *std::get<2>(c(a, b))); + const AggregateDataVariant value_list = + std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b})); + + REQUIRE(std::get<double>(value_list[0]) == a + b); + REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a}); + auto data_type = i_embedded_c->getReturnDataType(); + REQUIRE(data_type == ASTNodeDataType::list_t); + + REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t); + REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>); + } + + SECTION("void -> N*R*shared_double BuiltinFunctionEmbedder") + { + std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<double>> { + uint64_t a = 1; + double b = 3.5; + return std::make_tuple(a, b, std::make_shared<double>(a + b)); + }; + + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = + std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double, std::shared_ptr<double>>(void)>>(c); + + REQUIRE(1ul == std::get<0>(c())); + REQUIRE(3.5 == std::get<1>(c())); + REQUIRE((1ul + 3.5) == *std::get<2>(c())); + const AggregateDataVariant value_list = + std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{})); + + REQUIRE(std::get<uint64_t>(value_list[0]) == 1ul); + REQUIRE(std::get<double>(value_list[1]) == 3.5); + + auto data_type = i_embedded_c->getReturnDataType(); + REQUIRE(data_type == ASTNodeDataType::list_t); + + REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); + REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t); + REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>); + } + SECTION("void(void) BuiltinFunctionEmbedder") { std::function c = [](void) -> void {};