diff --git a/src/language/ASTNodeDataTypeBuilder.cpp b/src/language/ASTNodeDataTypeBuilder.cpp index 0a6223c4f61124c79c35d6df990a09245f927ed2..7f328daf18a7b3f72de0e4b8036cf94bcc675954 100644 --- a/src/language/ASTNodeDataTypeBuilder.cpp +++ b/src/language/ASTNodeDataTypeBuilder.cpp @@ -54,15 +54,16 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) } else if (n.is<language::let_declaration>()) { n.children[0]->m_data_type = ASTNodeDataType::function_t; - n.children[1]->children[0]->m_data_type = ASTNodeDataType::typename_t; - n.children[1]->children[1]->m_data_type = ASTNodeDataType::typename_t; - // build types for compound types - for (auto& child : n.children[1]->children[0]->children) { - this->_buildNodeDataTypes(*child); - } - for (auto& child : n.children[1]->children[1]->children) { - this->_buildNodeDataTypes(*child); - } + const std::string& symbol = n.children[0]->string(); + auto [i_symbol, success] = n.m_symbol_table->find(symbol, n.children[0]->begin()); + + auto& function_table = n.m_symbol_table->functionTable(); + + uint64_t function_id = std::get<uint64_t>(i_symbol->attributes().value()); + FunctionDescriptor& function_descriptor = function_table[function_id]; + + ASTNode& parameters_domain_node = *function_descriptor.domainMappingNode().children[0]; + ASTNode& parameters_name_node = *function_descriptor.definitionNode().children[0]; { // Function data type const std::string& symbol = n.children[0]->string(); @@ -74,13 +75,13 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) i_symbol->attributes().setDataType(n.children[0]->m_data_type); } - if (n.children[1]->children[0]->children.size() != n.children[2]->children[0]->children.size()) { + if (parameters_domain_node.children.size() != parameters_name_node.children.size()) { std::ostringstream message; message << "Compound data type deduction is not yet implemented\n" - << "note: number of product spaces (" << n.children[1]->children[0]->children.size() << ") " - << rang::fgB::yellow << n.children[1]->children[0]->string() << rang::style::reset - << " differs from number of variables (" << n.children[2]->children[0]->children.size() << ") " - << rang::fgB::yellow << n.children[2]->children[0]->string() << rang::style::reset << std::ends; + << "note: number of product spaces (" << parameters_domain_node.children.size() << ") " + << rang::fgB::yellow << parameters_domain_node.string() << rang::style::reset + << " differs from number of variables (" << parameters_name_node.children.size() << ") " + << rang::fgB::yellow << parameters_name_node.string() << rang::style::reset << std::ends; throw parse_error(message.str(), n.children[0]->begin()); } @@ -110,14 +111,22 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) i_symbol->attributes().setDataType(data_type); }; - if (n.children[1]->children[0]->children.size() == 0) { - simple_type_allocator(*n.children[1]->children[0], *n.children[2]->children[0]); + if (parameters_domain_node.children.size() == 0) { + simple_type_allocator(parameters_domain_node, parameters_name_node); } else { - for (size_t i = 0; i < n.children[1]->children[0]->children.size(); ++i) { - simple_type_allocator(*n.children[1]->children[0]->children[i], *n.children[2]->children[0]->children[i]); + for (size_t i = 0; i < function_descriptor.domainMappingNode().children.size(); ++i) { + simple_type_allocator(*parameters_domain_node.children[i], *parameters_name_node.children[i]); } } + // build types for compound types + for (auto& child : parameters_domain_node.children) { + this->_buildNodeDataTypes(*child); + } + for (auto& child : parameters_name_node.children) { + this->_buildNodeDataTypes(*child); + } + n.m_data_type = ASTNodeDataType::void_t; } else if (n.is<language::name>()) { std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table; diff --git a/src/language/ASTSymbolInitializationChecker.cpp b/src/language/ASTSymbolInitializationChecker.cpp index c111a95a7aaf3f5f82a85a9318d67fc0f3e9d362..a461d3e6296416008197e2ca778f54567bb6014d 100644 --- a/src/language/ASTSymbolInitializationChecker.cpp +++ b/src/language/ASTSymbolInitializationChecker.cpp @@ -19,10 +19,15 @@ ASTSymbolInitializationChecker::_checkSymbolInitialization(ASTNode& node) const std::string& symbol = node.children[0]->string(); auto [i_symbol, found] = node.m_symbol_table->find(symbol, node.children[0]->begin()); Assert(found, "unexpected error, should have been detected through declaration checking"); - if (node.children.size() == 3) { - this->_checkSymbolInitialization(*node.children[2]); - i_symbol->attributes().setIsInitialized(); - } + + i_symbol->attributes().setIsInitialized(); + + auto& function_table = node.m_symbol_table->functionTable(); + + uint64_t function_id = std::get<uint64_t>(i_symbol->attributes().value()); + auto& function_descriptor = function_table[function_id]; + this->_checkSymbolInitialization(function_descriptor.definitionNode()); + } else if (node.is<language::function_definition>()) { const std::string& symbol = node.children[0]->string(); auto [i_symbol, found] = node.m_symbol_table->find(symbol, node.children[0]->begin()); @@ -47,7 +52,7 @@ ASTSymbolInitializationChecker::_checkSymbolInitialization(ASTNode& node) } } - if ((not node.is<language::declaration>()) and (not node.is<language::eq_op>())) { + if (not(node.is<language::declaration>() or node.is<language::let_declaration>() or node.is<language::eq_op>())) { for (auto& child : node.children) { this->_checkSymbolInitialization(*child); } diff --git a/src/language/ASTSymbolTableBuilder.cpp b/src/language/ASTSymbolTableBuilder.cpp index aaab7c1413d3f59c63533e589942b5ce38f64a3b..9220c662210f1c97a4f556a0a14592c41e956a84 100644 --- a/src/language/ASTSymbolTableBuilder.cpp +++ b/src/language/ASTSymbolTableBuilder.cpp @@ -30,9 +30,10 @@ ASTSymbolTableBuilder::buildSymbolTable(ASTNode& n, std::shared_ptr<SymbolTable> this->buildSymbolTable(*child, local_symbol_table); } - size_t function_id = symbol_table->functionTable()->add(FunctionDescriptor{}); + size_t function_id = + symbol_table->functionTable().add(FunctionDescriptor{std::move(n.children[1]), std::move(n.children[2])}); i_symbol->attributes().value() = function_id; - + n.children.resize(1); } else { n.m_symbol_table = symbol_table; if (n.has_content()) { diff --git a/src/language/FunctionTable.hpp b/src/language/FunctionTable.hpp index a5dae9c85d841e73d098da0529785de4cac23473..700439fff4d6d64f38af0c7d95e9f169d6a89507 100644 --- a/src/language/FunctionTable.hpp +++ b/src/language/FunctionTable.hpp @@ -15,13 +15,27 @@ class FunctionDescriptor { std::unique_ptr<ASTNode> m_domain_mapping_node; - std::unique_ptr<ASTNode> m_expression_node; + std::unique_ptr<ASTNode> m_definition_node; public: + auto& + domainMappingNode() + { + Assert(m_domain_mapping_node, "undefined domain mapping node"); + return *m_domain_mapping_node; + } + + auto& + definitionNode() + { + Assert(m_domain_mapping_node, "undefined expression node"); + return *m_definition_node; + } + FunctionDescriptor& operator=(FunctionDescriptor&&) = default; - FunctionDescriptor(std::unique_ptr<ASTNode>&& domain_mapping_node, std::unique_ptr<ASTNode>&& expression_node) - : m_domain_mapping_node(std::move(domain_mapping_node)), m_expression_node(std::move(expression_node)) + FunctionDescriptor(std::unique_ptr<ASTNode>&& domain_mapping_node, std::unique_ptr<ASTNode>&& definition_node) + : m_domain_mapping_node(std::move(domain_mapping_node)), m_definition_node(std::move(definition_node)) {} FunctionDescriptor(FunctionDescriptor&&) = default; diff --git a/src/language/SymbolTable.hpp b/src/language/SymbolTable.hpp index af700706fe3955520da4117ff453685b7a24edcd..c8430916b58b276455e49be568c066adc032301b 100644 --- a/src/language/SymbolTable.hpp +++ b/src/language/SymbolTable.hpp @@ -136,10 +136,16 @@ class SymbolTable std::shared_ptr<FunctionTable> m_function_table; public: - auto + const FunctionTable& functionTable() const { - return m_function_table; + return *m_function_table; + } + + FunctionTable& + functionTable() + { + return *m_function_table; } friend std::ostream&