#include <ASTNodeDataTypeBuilder.hpp>

#include <PEGGrammar.hpp>
#include <PugsAssert.hpp>
#include <SymbolTable.hpp>

void
ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n)
{
  if (n.is<language::block>() or n.is<language::for_statement>()) {
    for (auto& child : n.children) {
      this->_buildNodeDataTypes(*child);
    }
    n.m_data_type = ASTNodeDataType::void_t;
  } else {
    if (n.has_content()) {
      if (n.is<language::true_kw>() or n.is<language::false_kw>()) {
        n.m_data_type = ASTNodeDataType::bool_t;
      } else if (n.is<language::real>()) {
        n.m_data_type = ASTNodeDataType::double_t;
      } else if (n.is<language::integer>()) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else if (n.is<language::literal>()) {
        n.m_data_type = ASTNodeDataType::string_t;
      } else if (n.is<language::cout_kw>() or n.is<language::cerr_kw>() or n.is<language::clog_kw>()) {
        n.m_data_type = ASTNodeDataType::void_t;
      } else if (n.is<language::declaration>()) {
        auto& type_node = *(n.children[0]);
        ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
        if (type_node.is<language::B_set>()) {
          data_type = ASTNodeDataType::bool_t;
        } else if (type_node.is<language::Z_set>()) {
          data_type = ASTNodeDataType::int_t;
        } else if (type_node.is<language::N_set>()) {
          data_type = ASTNodeDataType::unsigned_int_t;
        } else if (type_node.is<language::R_set>()) {
          data_type = ASTNodeDataType::double_t;
        } else if (type_node.is<language::string_type>()) {
          data_type = ASTNodeDataType::string_t;
        }

        Assert(data_type != ASTNodeDataType::undefined_t);   // LCOV_EXCL_LINE

        type_node.m_data_type      = ASTNodeDataType::typename_t;
        n.children[1]->m_data_type = data_type;
        const std::string& symbol  = n.children[1]->string();

        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

        auto [i_symbol, found] = symbol_table->find(symbol, n.children[1]->begin());
        Assert(found);
        i_symbol->attributes().setDataType(data_type);
        n.m_data_type = data_type;
      } else if (n.is<language::let_declaration>()) {
        n.children[0]->m_data_type = ASTNodeDataType::function_t;

        const std::string& symbol = n.children[0]->string();
        auto [i_symbol, success]  = n.m_symbol_table->find(symbol, n.children[0]->begin());

        auto& function_table = n.m_symbol_table->functionTable();

        uint64_t function_id                    = std::get<uint64_t>(i_symbol->attributes().value());
        FunctionDescriptor& function_descriptor = function_table[function_id];

        ASTNode& parameters_domain_node = *function_descriptor.domainMappingNode().children[0];
        ASTNode& parameters_name_node   = *function_descriptor.definitionNode().children[0];

        {   // Function data type
          const std::string& symbol = n.children[0]->string();

          std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

          auto [i_symbol, found] = symbol_table->find(symbol, n.children[0]->begin());
          Assert(found);
          i_symbol->attributes().setDataType(n.children[0]->m_data_type);
        }

        if (parameters_domain_node.children.size() != parameters_name_node.children.size()) {
          std::ostringstream message;
          message << "Compound data type deduction is not yet implemented\n"
                  << "note: number of product spaces (" << parameters_domain_node.children.size() << ") "
                  << rang::fgB::yellow << parameters_domain_node.string() << rang::style::reset
                  << " differs  from number of variables (" << parameters_name_node.children.size() << ") "
                  << rang::fgB::yellow << parameters_name_node.string() << rang::style::reset << std::ends;
          throw parse_error(message.str(), n.children[0]->begin());
        }

        auto simple_type_allocator = [&](const ASTNode& type_node, ASTNode& symbol_node) {
          ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
          if (type_node.is<language::B_set>()) {
            data_type = ASTNodeDataType::bool_t;
          } else if (type_node.is<language::Z_set>()) {
            data_type = ASTNodeDataType::int_t;
          } else if (type_node.is<language::N_set>()) {
            data_type = ASTNodeDataType::unsigned_int_t;
          } else if (type_node.is<language::R_set>()) {
            data_type = ASTNodeDataType::double_t;
          } else if (type_node.is<language::string_type>()) {
            data_type = ASTNodeDataType::string_t;
          }

          Assert(data_type != ASTNodeDataType::undefined_t);   // LCOV_EXCL_LINE

          symbol_node.m_data_type   = data_type;
          const std::string& symbol = symbol_node.string();

          std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

          auto [i_symbol, found] = symbol_table->find(symbol, symbol_node.begin());
          Assert(found);
          i_symbol->attributes().setDataType(data_type);
        };

        if (parameters_domain_node.children.size() == 0) {
          simple_type_allocator(parameters_domain_node, parameters_name_node);
        } else {
          for (size_t i = 0; i < function_descriptor.domainMappingNode().children.size(); ++i) {
            simple_type_allocator(*parameters_domain_node.children[i], *parameters_name_node.children[i]);
          }
        }

        // build types for compound types
        for (auto& child : parameters_domain_node.children) {
          this->_buildNodeDataTypes(*child);
        }
        for (auto& child : parameters_name_node.children) {
          this->_buildNodeDataTypes(*child);
        }

        n.m_data_type = ASTNodeDataType::void_t;
      } else if (n.is<language::name>()) {
        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

        auto [i_symbol, found] = symbol_table->find(n.string(), n.begin());
        Assert(found);
        n.m_data_type = i_symbol->attributes().dataType();
      }
    }
    for (auto& child : n.children) {
      this->_buildNodeDataTypes(*child);
    }

    if (n.is<language::break_kw>() or n.is<language::continue_kw>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is<language::eq_op>() or n.is<language::multiplyeq_op>() or n.is<language::divideeq_op>() or
               n.is<language::pluseq_op>() or n.is<language::minuseq_op>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is<language::type_mapping>() or n.is<language::function_definition>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is<language::for_post>() or n.is<language::for_init>() or n.is<language::for_statement_block>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is<language::for_test>()) {
      n.m_data_type = ASTNodeDataType::bool_t;
    } else if (n.is<language::statement_block>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is<language::if_statement>() or n.is<language::while_statement>()) {
      n.m_data_type = ASTNodeDataType::void_t;
      if ((n.children[0]->m_data_type > ASTNodeDataType::double_t) or
          (n.children[0]->m_data_type < ASTNodeDataType::bool_t)) {
        const ASTNodeDataType type_0 = n.children[0]->m_data_type;
        std::ostringstream message;
        message << "Cannot convert data type to boolean value\n"
                << "note: incompatible operand '" << n.children[0]->string() << "' of type " << dataTypeName(type_0)
                << std::ends;
        throw parse_error(message.str(), n.children[0]->begin());
      }
    } else if (n.is<language::do_while_statement>()) {
      n.m_data_type = ASTNodeDataType::void_t;
      if ((n.children[1]->m_data_type > ASTNodeDataType::double_t) or
          (n.children[1]->m_data_type < ASTNodeDataType::bool_t)) {
        const ASTNodeDataType type_0 = n.children[1]->m_data_type;
        std::ostringstream message;
        message << "Cannot convert data type to boolean value\n"
                << "note: incompatible operand '" << n.children[1]->string() << "' of type " << dataTypeName(type_0)
                << std::ends;
        throw parse_error(message.str(), n.children[1]->begin());
      }
    } else if (n.is<language::unary_not>() or n.is<language::lesser_op>() or n.is<language::lesser_or_eq_op>() or
               n.is<language::greater_op>() or n.is<language::greater_or_eq_op>() or n.is<language::eqeq_op>() or
               n.is<language::not_eq_op>() or n.is<language::and_op>() or n.is<language::or_op>() or
               n.is<language::xor_op>()) {
      n.m_data_type = ASTNodeDataType::bool_t;
    } else if (n.is<language::unary_minus>()) {
      n.m_data_type = n.children[0]->m_data_type;
      if ((n.children[0]->m_data_type == ASTNodeDataType::unsigned_int_t) or
          (n.children[0]->m_data_type == ASTNodeDataType::bool_t)) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else {
        n.m_data_type = n.children[0]->m_data_type;
      }
    } else if (n.is<language::unary_plusplus>() or n.is<language::unary_minusminus>() or
               n.is<language::post_plusplus>() or n.is<language::post_minusminus>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is<language::plus_op>() or n.is<language::minus_op>() or n.is<language::multiply_op>() or
               n.is<language::divide_op>()) {
      const ASTNodeDataType type_0 = n.children[0]->m_data_type;
      const ASTNodeDataType type_1 = n.children[1]->m_data_type;
      if ((type_0 == ASTNodeDataType::bool_t) and (type_1 == ASTNodeDataType::bool_t)) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else {
        n.m_data_type = dataTypePromotion(type_0, type_1);
      }
      if (n.m_data_type == ASTNodeDataType::undefined_t) {
        std::ostringstream message;
        message << "undefined binary operator\n"
                << "note: incompatible operand types " << n.children[0]->string() << " (" << dataTypeName(type_0)
                << ") and " << n.children[1]->string() << " (" << dataTypeName(type_1) << ')' << std::ends;
        throw parse_error(message.str(), n.begin());
      }
    } else if (n.is<language::function_evaluation>()) {
      if (n.children[0]->m_data_type != ASTNodeDataType::function_t) {
        std::ostringstream message;
        message << "invalid function call\n"
                << "note: '" << n.children[0]->string() << "' is not a function!" << std::ends;
        throw parse_error(message.str(), n.begin());
      }
      std::cout << rang::fgB::red << "returned type of function evaluation is incorrect" << rang::style::reset << "\n";
      n.m_data_type = ASTNodeDataType::double_t;
    } else if (n.is<language::B_set>() or n.is<language::Z_set>() or n.is<language::N_set>() or
               n.is<language::R_set>() or n.is<language::string_type>()) {
      n.m_data_type = ASTNodeDataType::typename_t;
    } else if (n.is<language::name_list>() or n.is<language::expression_list>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    }
  }
}

ASTNodeDataTypeBuilder::ASTNodeDataTypeBuilder(ASTNode& node)
{
  // Typing must start from the top of the tree.
  Assert(node.is_root());

  // The root is given the void type before its subtree is visited.
  node.m_data_type = ASTNodeDataType::void_t;
  _buildNodeDataTypes(node);

  // Progress trace for this compilation step.
  std::cout << " - build node data types\n";
}
