diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp index d424c2a5c5d3b3584d1312a871c48ddc12c0e2d0..780145e145d2325187615d1f679dfdc0e0d29f72 100644 --- a/src/language/ASTNodeFunctionExpressionBuilder.cpp +++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp @@ -6,6 +6,101 @@ #include <node_processor/FunctionProcessor.hpp> +template <typename SymbolType> +PUGS_INLINE std::unique_ptr<INodeProcessor> +ASTNodeFunctionExpressionBuilder::_getArgumentProcessor(ASTNode& argument_node, SymbolType& parameter_symbol) +{ + auto get_function_argument_processor_for_parameter_type = + [&](const auto& argument_v) -> std::unique_ptr<INodeProcessor> { + using ArgumentT = std::decay_t<decltype(argument_v)>; + switch (parameter_symbol.attributes().dataType()) { + case ASTNodeDataType::bool_t: { + return std::make_unique<FunctionArgumentProcessor<ArgumentT, bool>>(argument_node, parameter_symbol); + } + case ASTNodeDataType::unsigned_int_t: { + return std::make_unique<FunctionArgumentProcessor<ArgumentT, uint64_t>>(argument_node, parameter_symbol); + } + case ASTNodeDataType::int_t: { + return std::make_unique<FunctionArgumentProcessor<ArgumentT, int64_t>>(argument_node, parameter_symbol); + } + case ASTNodeDataType::double_t: { + return std::make_unique<FunctionArgumentProcessor<ArgumentT, double>>(argument_node, parameter_symbol); + } + default: { + throw parse_error("unexpected error: undefined parameter type for function", std::vector{argument_node.begin()}); + } + } + }; + + auto get_function_argument_processor_for_argument_type = [&]() { + switch (argument_node.m_data_type) { + case ASTNodeDataType::bool_t: { + return get_function_argument_processor_for_parameter_type(bool{}); + } + case ASTNodeDataType::unsigned_int_t: { + return get_function_argument_processor_for_parameter_type(uint64_t{}); + } + case ASTNodeDataType::int_t: { + return get_function_argument_processor_for_parameter_type(int64_t{}); + } + case ASTNodeDataType::double_t: { + return get_function_argument_processor_for_parameter_type(double{}); + } + default: { + throw parse_error("unexpected error: undefined argument type for function", std::vector{argument_node.begin()}); + } + } + }; + + return get_function_argument_processor_for_argument_type(); +} + +PUGS_INLINE +std::unique_ptr<INodeProcessor> +ASTNodeFunctionExpressionBuilder::_getArgumentProcessor(FunctionDescriptor& function_descriptor, ASTNode& argument_node) +{ + SymbolTable::Symbol& symbol{[&]() -> SymbolTable::Symbol& { + ASTNode& definition_node = function_descriptor.definitionNode(); + ASTNode& argument_variable = *definition_node.children[0]; + + if (argument_variable.is<language::name>()) { + auto [i_symbol, found] = + argument_variable.m_symbol_table->find(argument_variable.string(), argument_variable.begin()); + Assert(found); + + return *i_symbol; + } else { + throw parse_error("argument list not implemented yet!", function_descriptor.definitionNode().begin()); + } + }()}; + + return std::make_unique<FunctionArgumentProcessor<double, double>>(argument_node, symbol); +} + +PUGS_INLINE +void +ASTNodeFunctionExpressionBuilder::_buildArgumentProcessors(FunctionDescriptor& function_descriptor, + ASTNode& node, + FunctionProcessor& function_processor) +{ + ASTNode& definition_node = function_descriptor.definitionNode(); + ASTNode& parameter_variables = *definition_node.children[0]; + + ASTNode& argument_nodes = *node.children[1]; + + if (argument_nodes.is<language::expression_list>()) { + throw parse_error("argument list not implemented yet!", argument_nodes.begin()); + } else { + Assert(parameter_variables.is<language::name>(), "unexpected parameter type!"); + + auto [i_parameter_symbol, found] = + parameter_variables.m_symbol_table->find(parameter_variables.string(), parameter_variables.begin()); + Assert(found); + + function_processor.addArgumentProcessor(this->_getArgumentProcessor(argument_nodes, *i_parameter_symbol)); + } +} + PUGS_INLINE std::unique_ptr<INodeProcessor> ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType expression_value_type, @@ -59,9 +154,9 @@ ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType ex ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node) { - auto [i_symbol, found] = node.m_symbol_table->find(node.children[0]->string(), node.begin()); + auto [i_function_symbol, found] = node.m_symbol_table->find(node.children[0]->string(), node.begin()); Assert(found); - uint64_t function_id = std::get<uint64_t>(i_symbol->attributes().value()); + uint64_t function_id = std::get<uint64_t>(i_function_symbol->attributes().value()); FunctionDescriptor& function_descriptor = node.m_symbol_table->functionTable()[function_id]; @@ -70,13 +165,15 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node std::vector{function_descriptor.definitionNode().children[1]->begin()}); } #warning compute the right value type - const ASTNodeDataType return_value_type = ASTNodeDataType::double_t; + const ASTNodeDataType return_value_type = ASTNodeDataType::double_t; + const ASTNodeDataType expression_value_type = function_descriptor.definitionNode().children[1]->m_data_type; std::unique_ptr function_processor = std::make_unique<FunctionProcessor>(); + this->_buildArgumentProcessors(function_descriptor, node, *function_processor); + function_processor->addFunctionExpressionProcessor( - this->_getFunctionProcessor(function_descriptor.definitionNode().children[1]->m_data_type, return_value_type, - node)); + this->_getFunctionProcessor(expression_value_type, return_value_type, node)); node.m_node_processor = std::move(function_processor); } diff --git a/src/language/ASTNodeFunctionExpressionBuilder.hpp b/src/language/ASTNodeFunctionExpressionBuilder.hpp index dae92b7594ad05187daaef25014ca10e307a5843..f3c9a4fec65052eface478281a3597124556feae 100644 --- a/src/language/ASTNodeFunctionExpressionBuilder.hpp +++ b/src/language/ASTNodeFunctionExpressionBuilder.hpp @@ -4,9 +4,25 @@ #include <ASTNode.hpp> #include <node_processor/INodeProcessor.hpp> +class FunctionProcessor; +class FunctionDescriptor; + class ASTNodeFunctionExpressionBuilder { private: + template <typename SymbolType> + PUGS_INLINE std::unique_ptr<INodeProcessor> _getArgumentProcessor(ASTNode& argument_node, + SymbolType& parameter_symbol); + + PUGS_INLINE + std::unique_ptr<INodeProcessor> _getArgumentProcessor(FunctionDescriptor& function_descriptor, + ASTNode& argument_node); + + PUGS_INLINE + void _buildArgumentProcessors(FunctionDescriptor& function_descriptor, + ASTNode& node, + FunctionProcessor& function_processor); + PUGS_INLINE std::unique_ptr<INodeProcessor> _getFunctionProcessor(const ASTNodeDataType expression_value_type, const ASTNodeDataType return_value_type, diff --git a/src/language/node_processor/FunctionProcessor.hpp b/src/language/node_processor/FunctionProcessor.hpp index 2fae6c16dced9f30490da4fafbe7d2854122d674..26cff7a0b68a15f8cad8737a8dd18e630c14c4e8 100644 --- a/src/language/node_processor/FunctionProcessor.hpp +++ b/src/language/node_processor/FunctionProcessor.hpp @@ -10,65 +10,59 @@ #include <node_processor/INodeProcessor.hpp> -template <typename ReturnType, typename ExpressionValueType> -class FunctionExpressionProcessor final : public INodeProcessor +template <typename ProvidedValueType, typename ExpectedValueType> +class FunctionArgumentProcessor final : public INodeProcessor { private: - ASTNode& m_node; - - FunctionDescriptor& m_function_descriptor; + ASTNode& m_provided_value_node; + ASTNodeDataVariant& m_symbol_value; + public: void - _executeArguments(ExecUntilBreakOrContinue& exec_policy) + execute(ExecUntilBreakOrContinue& exec_policy) { - // Compute arguments values - ASTNode& arguments_values = *m_node.children[1]; - arguments_values.execute(exec_policy); - - // Copy arguments to function arguments - ASTNode& definition_node = m_function_descriptor.definitionNode(); - ASTNode& arguments = *definition_node.children[0]; - - if (arguments.is<language::name>()) { - Assert(arguments_values.children.size() == 0); + m_provided_value_node.execute(exec_policy); - auto [i_symbol, found] = arguments.m_symbol_table->find(arguments.string(), arguments.begin()); - Assert(found); - - if (i_symbol->attributes().dataType() == arguments_values.m_data_type) { - i_symbol->attributes().value() = arguments_values.m_value; - } else { - throw parse_error("argument type conversion not implemented yet!", arguments_values.begin()); - } + if constexpr (std::is_same_v<ExpectedValueType, ProvidedValueType>) { + m_symbol_value = m_provided_value_node.m_value; } else { - throw parse_error("argument list not implemented yet!", arguments_values.begin()); + m_symbol_value = static_cast<ExpectedValueType>(std::get<ProvidedValueType>(m_provided_value_node.m_value)); } } + FunctionArgumentProcessor(ASTNode& provided_value_node, SymbolTable::Symbol& argument_symbol) + : m_provided_value_node{provided_value_node}, m_symbol_value{argument_symbol.attributes().value()} + {} +}; + +template <typename ReturnType, typename ExpressionValueType> +class FunctionExpressionProcessor final : public INodeProcessor +{ + private: + ASTNode& m_node; + ASTNode& m_function_expression; + public: void execute(ExecUntilBreakOrContinue& exec_policy) { - this->_executeArguments(exec_policy); - - ASTNode& definition_node = m_function_descriptor.definitionNode(); - ASTNode& function_expression = *definition_node.children[1]; - function_expression.execute(exec_policy); + m_function_expression.execute(exec_policy); if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) { - m_node.m_value = function_expression.m_value; + m_node.m_value = m_function_expression.m_value; } else { - m_node.m_value = static_cast<ReturnType>(std::get<ExpressionValueType>(function_expression.m_value)); + m_node.m_value = static_cast<ReturnType>(std::get<ExpressionValueType>(m_function_expression.m_value)); } } FunctionExpressionProcessor(ASTNode& node) - : m_node{node}, m_function_descriptor{[&]() -> FunctionDescriptor& { + : m_node{node}, m_function_expression{[&]() -> ASTNode& { auto [i_symbol, found] = m_node.m_symbol_table->find(m_node.children[0]->string(), m_node.begin()); Assert(found); uint64_t function_id = std::get<uint64_t>(i_symbol->attributes().value()); - return m_node.m_symbol_table->functionTable()[function_id]; + FunctionDescriptor& function_descriptor = m_node.m_symbol_table->functionTable()[function_id]; + return *function_descriptor.definitionNode().children[1]; }()} {} };