#include <language/ast/ASTSymbolTableBuilder.hpp>

#include <language/PEGGrammar.hpp>
#include <language/utils/SymbolTable.hpp>

void
ASTSymbolTableBuilder::buildSymbolTable(ASTNode& n, std::shared_ptr<SymbolTable>& symbol_table)
{
  if (n.is_type<language::block>() or (n.is_type<language::for_statement>())) {
    if (!n.children.empty()) {
      std::shared_ptr block_symbol_table = std::make_shared<SymbolTable>(symbol_table);
      n.m_symbol_table                   = block_symbol_table;

      for (auto& child : n.children) {
        this->buildSymbolTable(*child, block_symbol_table);
      }
    }
  } else if (n.is_type<language::fct_declaration>()) {
    std::shared_ptr local_symbol_table =
      std::make_shared<SymbolTable>(symbol_table, std::make_shared<SymbolTable::Context>());

    n.m_symbol_table          = local_symbol_table;
    const std::string& symbol = n.children[0]->string();
    auto [i_symbol, success]  = symbol_table->add(symbol, n.children[0]->begin());
    if (not success) {
      std::ostringstream error_message;
      error_message << "symbol '" << rang::fg::red << symbol << rang::fg::reset << "' was already defined!";
      throw parse_error(error_message.str(), std::vector{n.begin()});
    }

    for (auto& child : n.children) {
      this->buildSymbolTable(*child, local_symbol_table);
    }

    size_t function_id =
      symbol_table->functionTable().add(FunctionDescriptor{symbol, std::move(n.children[1]), std::move(n.children[2])});
    i_symbol->attributes().value() = function_id;
    n.children.resize(1);
  } else {
    n.m_symbol_table = symbol_table;
    if (n.has_content()) {
      if (n.is_type<language::var_declaration>()) {
        auto register_symbol = [&](const ASTNode& argument_node) {
          auto [i_symbol, success] = symbol_table->add(argument_node.string(), argument_node.begin());
          if (not success) {
            std::ostringstream error_message;
            error_message << "symbol '" << rang::fg::red << argument_node.string() << rang::fg::reset
                          << "' was already defined!";
            throw parse_error(error_message.str(), std::vector{argument_node.begin()});
          }
        };

        if (n.children[0]->is_type<language::name>()) {
          register_symbol(*n.children[0]);
        } else {   // treats the case of list of parameters
          Assert(n.children[0]->is_type<language::name_list>());
          for (auto& child : n.children[0]->children) {
            register_symbol(*child);
          }
        }
      } else if (n.is_type<language::function_definition>()) {
        auto register_and_initialize_symbol = [&](const ASTNode& argument_node) {
          auto [i_symbol, success] = symbol_table->add(argument_node.string(), argument_node.begin());
          if (not success) {
            std::ostringstream error_message;
            error_message << "symbol '" << rang::fg::red << argument_node.string() << rang::fg::reset
                          << "' was already defined!";
            throw parse_error(error_message.str(), std::vector{argument_node.begin()});
          }
          // Symbols will be initialized at call
          i_symbol->attributes().setIsInitialized();
        };

        if (n.children[0]->is_type<language::name>()) {
          register_and_initialize_symbol(*n.children[0]);
        } else {   // treats the case of list of parameters
          Assert(n.children[0]->is_type<language::name_list>());
          for (auto& child : n.children[0]->children) {
            register_and_initialize_symbol(*child);
          }
        }
      } else if (n.is_type<language::name>()) {
        auto [i_symbol, found] = symbol_table->find(n.string(), n.begin());
        if (not found) {
          std::ostringstream error_message;
          error_message << "undefined symbol '" << rang::fg::red << n.string() << rang::fg::reset << '\'';
          throw parse_error(error_message.str(), std::vector{n.begin()});
        }
      }
    }

    for (auto& child : n.children) {
      this->buildSymbolTable(*child, symbol_table);
    }
  }
}

// Builds and checks the symbol tables of a whole AST.
// `node` must be the AST root; its own m_symbol_table is passed as the
// outermost table. Prints a progress line to std::cout once the check succeeds.
ASTSymbolTableBuilder::ASTSymbolTableBuilder(ASTNode& node)
{
  Assert(node.is_root());

  auto& root_symbol_table = node.m_symbol_table;
  buildSymbolTable(node, root_symbol_table);

  std::cout << " - checked symbols declaration" << '\n';
}