diff --git a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp index 487259530b7c3e4872371e9623517c0546d31e4f..412906b9d67291ed7901d3ac00943cc8a9107449 100644 --- a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -231,6 +231,9 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } case ASTNodeDataType::double_t: { return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, double>>(argument_number); + } + case ASTNodeDataType::function_t: { + return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, FunctionSymbolId>>(argument_number); } // LCOV_EXCL_START default: { @@ -240,7 +243,15 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } } case ASTNodeDataType::list_t: { - return std::make_unique<FunctionListArgumentConverter<ParameterContentT, ParameterContentT>>(argument_number); + if constexpr (std::is_same_v<ParameterContentT, FunctionSymbolId>) { + const ASTNode& parent_node = argument_node_sub_data_type.m_parent_node; + auto symbol_table = parent_node.m_symbol_table; + + return std::make_unique<FunctionListArgumentConverter<FunctionSymbolId, FunctionSymbolId>>(argument_number, + symbol_table); + } else { + return std::make_unique<FunctionListArgumentConverter<ParameterContentT, ParameterContentT>>(argument_number); + } } case ASTNodeDataType::type_id_t: { return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, EmbeddedData>>(argument_number); @@ -257,6 +268,12 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData case ASTNodeDataType::double_t: { return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, double>>(argument_number); } + case ASTNodeDataType::function_t: { + const ASTNode& parent_node = argument_node_sub_data_type.m_parent_node; + auto symbol_table = parent_node.m_symbol_table; + + return std::make_unique<FunctionArgumentToTupleFunctionSymbolIdConverter>(argument_number, symbol_table); + } case ASTNodeDataType::vector_t: { switch (arg_data_type.dimension()) { case 1: { @@ -398,6 +415,9 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData case ASTNodeDataType::double_t: { return get_function_argument_to_tuple_converter(double{}); } + case ASTNodeDataType::function_t: { + return get_function_argument_to_tuple_converter(FunctionSymbolId{}); + } case ASTNodeDataType::vector_t: { switch (parameter_type.contentType().dimension()) { case 1: { diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index af08e7806a4fd5daa01a7aecc98a84e52291040f..b13eb222b08cd2c0f5dde5e72d4de5626ef7758b 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -242,6 +242,7 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter DataVariant convert(ExecutionPolicy& exec_policy, DataVariant&& value) { + static_assert(not std::is_same_v<ContentType, FunctionSymbolId>); using TupleType = std::vector<ContentType>; if constexpr (std::is_same_v<ContentType, ProvidedValueType>) { std::visit( @@ -304,6 +305,57 @@ class FunctionListArgumentConverter final : public IFunctionArgumentConverter FunctionListArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} }; +template <> +class FunctionListArgumentConverter<FunctionSymbolId, FunctionSymbolId> final : public IFunctionArgumentConverter +{ + private: + size_t m_argument_id; + std::shared_ptr<SymbolTable> m_symbol_table; + + public: + DataVariant + convert(ExecutionPolicy& exec_policy, DataVariant&& value) + { + using TupleType = std::vector<FunctionSymbolId>; + std::visit( + [&](auto&& v) { + using ValueT = std::decay_t<decltype(v)>; + if constexpr (std::is_same_v<ValueT, AggregateDataVariant>) { + TupleType list_value; + list_value.reserve(v.size()); + for (size_t i = 0; i < v.size(); ++i) { + std::visit( + [&](auto&& vi) { + using Vi_T = std::decay_t<decltype(vi)>; + if constexpr (std::is_same_v<Vi_T, uint64_t>) { + list_value.emplace_back(FunctionSymbolId{vi, m_symbol_table}); + } else { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<Vi_T>() + "' to '" + + demangle<FunctionSymbolId>() + "'"); + // LCOV_EXCL_STOP + } + }, + (v[i])); + } + exec_policy.currentContext()[m_argument_id] = std::move(list_value); + } else { + // LCOV_EXCL_START + throw UnexpectedError(std::string{"invalid conversion of '"} + demangle<ValueT>() + "' to '" + + demangle<FunctionSymbolId>() + "' list"); + // LCOV_EXCL_STOP + } + }, + value); + + return {}; + } + + FunctionListArgumentConverter(size_t argument_id, const std::shared_ptr<SymbolTable>& symbol_table) + : m_argument_id{argument_id}, m_symbol_table{symbol_table} + {} +}; + class FunctionArgumentToFunctionSymbolIdConverter final : public IFunctionArgumentConverter { private: @@ -324,4 +376,25 @@ class FunctionArgumentToFunctionSymbolIdConverter final : public IFunctionArgume {} }; +class FunctionArgumentToTupleFunctionSymbolIdConverter final : public IFunctionArgumentConverter +{ + private: + size_t m_argument_id; + std::shared_ptr<SymbolTable> m_symbol_table; + + public: + DataVariant + convert(ExecutionPolicy& exec_policy, DataVariant&& value) + { + exec_policy.currentContext()[m_argument_id] = + std::vector{FunctionSymbolId{std::get<uint64_t>(value), m_symbol_table}}; + + return {}; + } + + FunctionArgumentToTupleFunctionSymbolIdConverter(size_t argument_id, const std::shared_ptr<SymbolTable>& symbol_table) + : m_argument_id{argument_id}, m_symbol_table{symbol_table} + {} +}; + #endif // FUNCTION_ARGUMENT_CONVERTER_HPP diff --git a/src/language/utils/DataVariant.hpp b/src/language/utils/DataVariant.hpp index 6add9e52853bcbf58ee42e4f2869649dab43d37d..964044c6121478b4247c089a45eaf3ebe025195e 100644 --- a/src/language/utils/DataVariant.hpp +++ b/src/language/utils/DataVariant.hpp @@ -40,7 +40,8 @@ using DataVariant = std::variant<std::monostate, std::vector<TinyMatrix<3>>, std::vector<EmbeddedData>, AggregateDataVariant, - FunctionSymbolId>; + FunctionSymbolId, + std::vector<FunctionSymbolId>>; template <typename T, typename...> inline constexpr bool is_data_variant_v = is_variant<T, DataVariant>::value;