diff --git a/src/language/ast/ASTNodeDataTypeFlattener.cpp b/src/language/ast/ASTNodeDataTypeFlattener.cpp index 51d055836393b1ee680b7d983756bee39020d992..6a06c72bb66926af546ff0a130315eacc8b54c9a 100644 --- a/src/language/ast/ASTNodeDataTypeFlattener.cpp +++ b/src/language/ast/ASTNodeDataTypeFlattener.cpp @@ -1,6 +1,7 @@ #include <language/ast/ASTNodeDataTypeFlattener.hpp> #include <language/PEGGrammar.hpp> +#include <language/utils/BuiltinFunctionEmbedder.hpp> #include <language/utils/FunctionTable.hpp> #include <language/utils/SymbolTable.hpp> @@ -33,9 +34,20 @@ ASTNodeDataTypeFlattener::ASTNodeDataTypeFlattener(ASTNode& node, FlattenedDataT } break; } + case ASTNodeDataType::builtin_function_t: { + uint64_t builtin_function_id = std::get<uint64_t>(i_function_symbol->attributes().value()); + auto builtin_function_embedder = node.m_symbol_table->builtinFunctionEmbedderTable()[builtin_function_id]; + + const auto& compound_data_type = builtin_function_embedder->getReturnDataType(); + for (auto data_type : compound_data_type.contentTypeList()) { + flattened_datatype_list.push_back({*data_type, node}); + } + + break; + } // LCOV_EXCL_START default: { - throw ParseError("unexpected function type", node.begin()); + throw ParseError{"unexpected function type", node.begin()}; } // LCOV_EXCL_STOP } diff --git a/src/language/utils/BuiltinFunctionEmbedder.hpp b/src/language/utils/BuiltinFunctionEmbedder.hpp index afb991acadcb36236bcd55a34c7840ca399ab6f0..f85325c9b966f72443312258aec7973236c1f186 100644 --- a/src/language/utils/BuiltinFunctionEmbedder.hpp +++ b/src/language/utils/BuiltinFunctionEmbedder.hpp @@ -130,23 +130,30 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder return ast_node_data_type_from<T>; } - template <size_t I> + template <typename TupleT, size_t I> PUGS_INLINE ASTNodeDataType - _getOneParameterDataType(ArgsTuple& t) const + _getOneParameterDataType() const { - using ArgN_T = std::decay_t<decltype(std::get<I>(t))>; + using ArgN_T = std::decay_t<decltype(std::get<I>(TupleT{}))>; return _getDataType<ArgN_T>(); } template <size_t... I> - PUGS_INLINE std::vector<ASTNodeDataType> - _getParameterDataTypes(ArgsTuple t, std::index_sequence<I...>) const + PUGS_INLINE std::vector<ASTNodeDataType> _getParameterDataTypes(std::index_sequence<I...>) const { std::vector<ASTNodeDataType> parameter_type_list; - (parameter_type_list.push_back(this->_getOneParameterDataType<I>(t)), ...); + (parameter_type_list.push_back(this->_getOneParameterDataType<ArgsTuple, I>()), ...); return parameter_type_list; } + template <size_t... I> + PUGS_INLINE std::vector<std::shared_ptr<const ASTNodeDataType>> _getCompoundDataTypes(std::index_sequence<I...>) const + { + std::vector<std::shared_ptr<const ASTNodeDataType>> compound_type_list; + (compound_type_list.push_back(std::make_shared<ASTNodeDataType>(this->_getOneParameterDataType<FX, I>())), ...); + return compound_type_list; + } + template <typename T> PUGS_INLINE std::shared_ptr<IDataHandler> _createHandler(std::shared_ptr<T> data) const @@ -175,21 +182,40 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder (_check_arg<I>(), ...); } + PUGS_INLINE + AggregateDataVariant + _applyToAggregate(const ArgsTuple& t) const + { + auto tuple_result = std::apply(m_f, t); + std::vector<DataVariant> vector_result; + vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); + + std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); + + return vector_result; + } + public: PUGS_INLINE ASTNodeDataType getReturnDataType() const final { - return _getDataType<FX>(); + if constexpr (is_std_tuple_v<FX>) { + constexpr size_t N = std::tuple_size_v<FX>; + using IndexSequence = std::make_index_sequence<N>; + + return ASTNodeDataType::build<ASTNodeDataType::list_t>(this->_getCompoundDataTypes(IndexSequence{})); + } else { + return this->_getDataType<FX>(); + } } PUGS_INLINE std::vector<ASTNodeDataType> getParameterDataTypes() const final { - constexpr size_t N = std::tuple_size_v<ArgsTuple>; - ArgsTuple t; + constexpr size_t N = std::tuple_size_v<ArgsTuple>; using IndexSequence = std::make_index_sequence<N>; - return this->_getParameterDataTypes(t, IndexSequence{}); + return this->_getParameterDataTypes(IndexSequence{}); } PUGS_INLINE size_t @@ -198,7 +224,6 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder return sizeof...(Args); } - public: PUGS_INLINE DataVariant apply(const std::vector<DataVariant>& x) const final @@ -210,6 +235,8 @@ class BuiltinFunctionEmbedder<FX(Args...)> : public IBuiltinFunctionEmbedder this->_copyFromVector(t, x, IndexSequence{}); if constexpr (is_data_variant_v<FX>) { return {std::apply(m_f, t)}; + } else if constexpr (is_std_tuple_v<FX>) { + return this->_applyToAggregate(t); } else if constexpr (std::is_same_v<FX, void>) { std::apply(m_f, t); return {}; @@ -245,6 +272,22 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder private: std::function<FX(void)> m_f; + template <typename TupleT, size_t I> + PUGS_INLINE ASTNodeDataType + _getOneParameterDataType() const + { + using ArgN_T = std::decay_t<decltype(std::get<I>(TupleT{}))>; + return _getDataType<ArgN_T>(); + } + + template <size_t... I> + PUGS_INLINE std::vector<std::shared_ptr<const ASTNodeDataType>> _getCompoundDataTypes(std::index_sequence<I...>) const + { + std::vector<std::shared_ptr<const ASTNodeDataType>> compound_type_list; + (compound_type_list.push_back(std::make_shared<ASTNodeDataType>(this->_getOneParameterDataType<FX, I>())), ...); + return compound_type_list; + } + template <typename T> PUGS_INLINE std::shared_ptr<IDataHandler> _createHandler(std::shared_ptr<T> data) const @@ -260,11 +303,30 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder return ast_node_data_type_from<T>; } + PUGS_INLINE + AggregateDataVariant + _applyToAggregate() const + { + auto tuple_result = m_f(); + std::vector<DataVariant> vector_result; + vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>); + + std::apply([&](auto&&... result) { ((vector_result.emplace_back(std::move(result))), ...); }, tuple_result); + + return vector_result; + } + public: PUGS_INLINE ASTNodeDataType getReturnDataType() const final { - return this->_getDataType<FX>(); + if constexpr (is_std_tuple_v<FX>) { + constexpr size_t N = std::tuple_size_v<FX>; + using IndexSequence = std::make_index_sequence<N>; + return ASTNodeDataType::build<ASTNodeDataType::list_t>(this->_getCompoundDataTypes(IndexSequence{})); + } else { + return this->_getDataType<FX>(); + } } PUGS_INLINE std::vector<ASTNodeDataType> @@ -285,6 +347,8 @@ class BuiltinFunctionEmbedder<FX(void)> : public IBuiltinFunctionEmbedder { if constexpr (is_data_variant_v<FX>) { return {m_f()}; + } else if constexpr (is_std_tuple_v<FX>) { + return this->_applyToAggregate(); } else if constexpr (std::is_same_v<FX, void>) { m_f(); return {}; diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index 2651c0e94659c1bbffc0bd47e6ab946cce18e3c9..a0050aef2064e466ce550e9dae70bb1c533ef465 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -286,6 +286,59 @@ TEST_CASE("BuiltinFunctionEmbedder", "[language]") REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t); } + SECTION("R*R -> R*R^2 BuiltinFunctionEmbedder") + { + std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>> { + return std::make_tuple(a + b, TinyVector<2>{b, -a}); + }; + + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = + std::make_unique<BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>>(double, double)>>(c); + + const double a = 3.2; + const double b = 1.5; + + REQUIRE(a + b == std::get<0>(c(a, b))); + REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b))); + const AggregateDataVariant value_list = + std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b})); + + REQUIRE(std::get<double>(value_list[0]) == a + b); + REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a}); + + auto data_type = i_embedded_c->getReturnDataType(); + REQUIRE(data_type == ASTNodeDataType::list_t); + + REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t); + REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)); + } + + SECTION("void -> N*R BuiltinFunctionEmbedder") + { + std::function c = [](void) -> std::tuple<uint64_t, double> { + uint64_t a = 1; + double b = 3.5; + return std::make_tuple(a, b); + }; + + std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = + std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double>(void)>>(c); + + REQUIRE(1ul == std::get<0>(c())); + REQUIRE(3.5 == std::get<1>(c())); + const AggregateDataVariant value_list = + std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{})); + + REQUIRE(std::get<uint64_t>(value_list[0]) == 1ul); + REQUIRE(std::get<double>(value_list[1]) == 3.5); + + auto data_type = i_embedded_c->getReturnDataType(); + REQUIRE(data_type == ASTNodeDataType::list_t); + + REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t); + REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t); + } + SECTION("void(void) BuiltinFunctionEmbedder") { std::function c = [](void) -> void {};