#include <ASTNodeDataTypeBuilder.hpp>

#include <PEGGrammar.hpp>
#include <PugsAssert.hpp>
#include <SymbolTable.hpp>

#include <CFunctionEmbedder.hpp>

void
ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n)
{
  if (n.is_type<language::block>() or n.is_type<language::for_statement>()) {
    for (auto& child : n.children) {
      this->_buildNodeDataTypes(*child);
    }
    n.m_data_type = ASTNodeDataType::void_t;
  } else {
    if (n.has_content()) {
      if (n.is_type<language::import_instruction>()) {
        n.m_data_type = ASTNodeDataType::void_t;
      } else if (n.is_type<language::module_name>()) {
        n.m_data_type = ASTNodeDataType::string_t;

      } else if (n.is_type<language::true_kw>() or n.is_type<language::false_kw>()) {
        n.m_data_type = ASTNodeDataType::bool_t;
      } else if (n.is_type<language::real>()) {
        n.m_data_type = ASTNodeDataType::double_t;
      } else if (n.is_type<language::integer>()) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else if (n.is_type<language::literal>()) {
        n.m_data_type = ASTNodeDataType::string_t;
      } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) {
        n.m_data_type = ASTNodeDataType::void_t;
      } else if (n.is_type<language::declaration>()) {
        auto& type_node = *(n.children[0]);
        ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
        if (type_node.is_type<language::B_set>()) {
          data_type = ASTNodeDataType::bool_t;
        } else if (type_node.is_type<language::Z_set>()) {
          data_type = ASTNodeDataType::int_t;
        } else if (type_node.is_type<language::N_set>()) {
          data_type = ASTNodeDataType::unsigned_int_t;
        } else if (type_node.is_type<language::R_set>()) {
          data_type = ASTNodeDataType::double_t;
        } else if (type_node.is_type<language::string_type>()) {
          data_type = ASTNodeDataType::string_t;
        }

        Assert(data_type != ASTNodeDataType::undefined_t);   // LCOV_EXCL_LINE

        type_node.m_data_type      = ASTNodeDataType::typename_t;
        n.children[1]->m_data_type = data_type;
        const std::string& symbol  = n.children[1]->string();

        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

        auto [i_symbol, found] = symbol_table->find(symbol, n.children[1]->begin());
        Assert(found);
        i_symbol->attributes().setDataType(data_type);
        n.m_data_type = data_type;
      } else if (n.is_type<language::let_declaration>()) {
        n.children[0]->m_data_type = ASTNodeDataType::function_t;

        const std::string& symbol = n.children[0]->string();
        auto [i_symbol, success]  = n.m_symbol_table->find(symbol, n.children[0]->begin());

        auto& function_table = n.m_symbol_table->functionTable();

        uint64_t function_id                    = std::get<uint64_t>(i_symbol->attributes().value());
        FunctionDescriptor& function_descriptor = function_table[function_id];

        ASTNode& parameters_domain_node = *function_descriptor.domainMappingNode().children[0];
        ASTNode& parameters_name_node   = *function_descriptor.definitionNode().children[0];

        {   // Function data type
          const std::string& symbol = n.children[0]->string();

          std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

          auto [i_symbol, found] = symbol_table->find(symbol, n.children[0]->begin());
          Assert(found);
          i_symbol->attributes().setDataType(n.children[0]->m_data_type);
        }

        const size_t nb_parameter_domains =
          (parameters_domain_node.children.size() > 0) ? parameters_domain_node.children.size() : 1;
        const size_t nb_parameter_names =
          (parameters_name_node.children.size() > 0) ? parameters_name_node.children.size() : 1;

        if (nb_parameter_domains != nb_parameter_names) {
          std::ostringstream message;
          message << "Compound data type deduction is not yet implemented\n"
                  << "note: number of product spaces (" << nb_parameter_domains << ") " << rang::fgB::yellow
                  << parameters_domain_node.string() << rang::style::reset << rang::style::bold
                  << " differs from number of variables (" << nb_parameter_names << ") " << rang::fgB::yellow
                  << parameters_name_node.string() << rang::style::reset << std::ends;
          throw parse_error(message.str(), n.children[0]->begin());
        }

        auto simple_type_allocator = [&](const ASTNode& type_node, ASTNode& symbol_node) {
          Assert(symbol_node.is_type<language::name>());
          ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
          if (type_node.is_type<language::B_set>()) {
            data_type = ASTNodeDataType::bool_t;
          } else if (type_node.is_type<language::Z_set>()) {
            data_type = ASTNodeDataType::int_t;
          } else if (type_node.is_type<language::N_set>()) {
            data_type = ASTNodeDataType::unsigned_int_t;
          } else if (type_node.is_type<language::R_set>()) {
            data_type = ASTNodeDataType::double_t;
          }

          if (data_type == ASTNodeDataType::undefined_t) {
            throw parse_error("invalid parameter type", type_node.begin());
          }

          symbol_node.m_data_type   = data_type;
          const std::string& symbol = symbol_node.string();

          std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

          auto [i_symbol, found] = symbol_table->find(symbol, symbol_node.begin());
          Assert(found);
          i_symbol->attributes().setDataType(data_type);
        };

        if (parameters_domain_node.children.size() == 0) {
          simple_type_allocator(parameters_domain_node, parameters_name_node);
        } else {
          for (size_t i = 0; i < function_descriptor.domainMappingNode().children.size(); ++i) {
            simple_type_allocator(*parameters_domain_node.children[i], *parameters_name_node.children[i]);
          }
        }

        // build types for compound types
        for (auto& child : parameters_domain_node.children) {
          this->_buildNodeDataTypes(*child);
        }
        for (auto& child : parameters_name_node.children) {
          this->_buildNodeDataTypes(*child);
        }

        ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1];
        this->_buildNodeDataTypes(image_domain_node);
        for (auto& child : image_domain_node.children) {
          this->_buildNodeDataTypes(*child);
        }

        auto check_image_type = [&](const ASTNode& image_node) {
          ASTNodeDataType value_type{ASTNodeDataType::undefined_t};
          if (image_node.is_type<language::B_set>()) {
            value_type = ASTNodeDataType::bool_t;
          } else if (image_node.is_type<language::Z_set>()) {
            value_type = ASTNodeDataType::int_t;
          } else if (image_node.is_type<language::N_set>()) {
            value_type = ASTNodeDataType::unsigned_int_t;
          } else if (image_node.is_type<language::R_set>()) {
            value_type = ASTNodeDataType::double_t;
          }

          if (value_type == ASTNodeDataType::undefined_t) {
            throw parse_error("invalid value type", image_node.begin());
          }
        };

        if (image_domain_node.children.size() == 0) {
          check_image_type(image_domain_node);
        } else {
          for (size_t i = 0; i < image_domain_node.children.size(); ++i) {
            check_image_type(*image_domain_node.children[i]);
          }
        }

        n.m_data_type = ASTNodeDataType::void_t;
      } else if (n.is_type<language::name>()) {
        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

        auto [i_symbol, found] = symbol_table->find(n.string(), n.begin());
        Assert(found);
        n.m_data_type = i_symbol->attributes().dataType();
      }
    }
    for (auto& child : n.children) {
      this->_buildNodeDataTypes(*child);
    }

    if (n.is_type<language::break_kw>() or n.is_type<language::continue_kw>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::eq_op>() or n.is_type<language::multiplyeq_op>() or
               n.is_type<language::divideeq_op>() or n.is_type<language::pluseq_op>() or
               n.is_type<language::minuseq_op>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is_type<language::type_mapping>() or n.is_type<language::function_definition>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::for_post>() or n.is_type<language::for_init>() or
               n.is_type<language::for_statement_block>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::for_test>()) {
      n.m_data_type = ASTNodeDataType::bool_t;
    } else if (n.is_type<language::statement_block>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::if_statement>() or n.is_type<language::while_statement>()) {
      n.m_data_type = ASTNodeDataType::void_t;
      if ((n.children[0]->m_data_type > ASTNodeDataType::double_t) or
          (n.children[0]->m_data_type < ASTNodeDataType::bool_t)) {
        const ASTNodeDataType type_0 = n.children[0]->m_data_type;
        std::ostringstream message;
        message << "Cannot convert data type to boolean value\n"
                << "note: incompatible operand '" << n.children[0]->string() << "' of type " << dataTypeName(type_0)
                << std::ends;
        throw parse_error(message.str(), n.children[0]->begin());
      }
    } else if (n.is_type<language::do_while_statement>()) {
      n.m_data_type = ASTNodeDataType::void_t;
      if ((n.children[1]->m_data_type > ASTNodeDataType::double_t) or
          (n.children[1]->m_data_type < ASTNodeDataType::bool_t)) {
        const ASTNodeDataType type_0 = n.children[1]->m_data_type;
        std::ostringstream message;
        message << "Cannot convert data type to boolean value\n"
                << "note: incompatible operand '" << n.children[1]->string() << "' of type " << dataTypeName(type_0)
                << std::ends;
        throw parse_error(message.str(), n.children[1]->begin());
      }
    } else if (n.is_type<language::unary_not>() or n.is_type<language::lesser_op>() or
               n.is_type<language::lesser_or_eq_op>() or n.is_type<language::greater_op>() or
               n.is_type<language::greater_or_eq_op>() or n.is_type<language::eqeq_op>() or
               n.is_type<language::not_eq_op>() or n.is_type<language::and_op>() or n.is_type<language::or_op>() or
               n.is_type<language::xor_op>()) {
      n.m_data_type = ASTNodeDataType::bool_t;
    } else if (n.is_type<language::unary_minus>()) {
      n.m_data_type = n.children[0]->m_data_type;
      if ((n.children[0]->m_data_type == ASTNodeDataType::unsigned_int_t) or
          (n.children[0]->m_data_type == ASTNodeDataType::bool_t)) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else {
        n.m_data_type = n.children[0]->m_data_type;
      }
    } else if (n.is_type<language::unary_plusplus>() or n.is_type<language::unary_minusminus>() or
               n.is_type<language::post_plusplus>() or n.is_type<language::post_minusminus>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is_type<language::plus_op>() or n.is_type<language::minus_op>() or
               n.is_type<language::multiply_op>() or n.is_type<language::divide_op>()) {
      const ASTNodeDataType type_0 = n.children[0]->m_data_type;
      const ASTNodeDataType type_1 = n.children[1]->m_data_type;
      if ((type_0 == ASTNodeDataType::bool_t) and (type_1 == ASTNodeDataType::bool_t)) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else {
        n.m_data_type = dataTypePromotion(type_0, type_1);
      }
      if (n.m_data_type == ASTNodeDataType::undefined_t) {
        std::ostringstream message;
        message << "undefined binary operator\n"
                << "note: incompatible operand types " << n.children[0]->string() << " (" << dataTypeName(type_0)
                << ") and " << n.children[1]->string() << " (" << dataTypeName(type_1) << ')' << std::ends;
        throw parse_error(message.str(), n.begin());
      }
    } else if (n.is_type<language::function_evaluation>()) {
      if (n.children[0]->m_data_type == ASTNodeDataType::function_t) {
        const std::string& function_name  = n.children[0]->string();
        auto [i_function_symbol, success] = n.m_symbol_table->find(function_name, n.children[0]->begin());

        auto& function_table = n.m_symbol_table->functionTable();

        uint64_t function_id                    = std::get<uint64_t>(i_function_symbol->attributes().value());
        FunctionDescriptor& function_descriptor = function_table[function_id];

        ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1];

        Assert(image_domain_node.m_data_type == ASTNodeDataType::typename_t);

        ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
        if (image_domain_node.children.size() > 0) {
          throw parse_error("compound data type is not implemented yet", image_domain_node.begin());
        } else {
          if (image_domain_node.is_type<language::B_set>()) {
            data_type = ASTNodeDataType::bool_t;
          } else if (image_domain_node.is_type<language::Z_set>()) {
            data_type = ASTNodeDataType::int_t;
          } else if (image_domain_node.is_type<language::N_set>()) {
            data_type = ASTNodeDataType::unsigned_int_t;
          } else if (image_domain_node.is_type<language::R_set>()) {
            data_type = ASTNodeDataType::double_t;
          }
        }

        Assert(data_type != ASTNodeDataType::undefined_t);   // LCOV_EXCL_LINE

        n.m_data_type = data_type;
      } else if (n.children[0]->m_data_type == ASTNodeDataType::c_function_t) {
        const std::string c_function_name = n.children[0]->string();
        auto& symbol_table                = *n.m_symbol_table;

        auto [i_symbol, success] = symbol_table.find(c_function_name, n.begin());
        Assert(success);

        uint64_t c_function_id   = std::get<uint64_t>(i_symbol->attributes().value());
        auto c_function_embedder = symbol_table.cFunctionEbedderTable()[c_function_id];
        Assert(c_function_embedder);

        n.m_data_type = c_function_embedder->getReturnDataType();
      } else {
        std::ostringstream message;
        message << "invalid function call\n"
                << "note: '" << n.children[0]->string() << "' (type: " << dataTypeName(n.children[0]->m_data_type)
                << ") is not a function!" << std::ends;
        throw parse_error(message.str(), n.begin());
      }
    } else if (n.is_type<language::B_set>() or n.is_type<language::Z_set>() or n.is_type<language::N_set>() or
               n.is_type<language::R_set>() or n.is_type<language::string_type>()) {
      n.m_data_type = ASTNodeDataType::typename_t;
    } else if (n.is_type<language::name_list>() or n.is_type<language::expression_list>() or
               n.is_type<language::function_argument_list>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    }
  }
}

// Builds the data types of the whole AST.
//
// `node` must be the root of the tree. The root is typed void_t, then data
// types are built for the tree itself and, afterwards, for the definition node
// of every function registered in the function table of the root's symbol
// table. A progress line is written to std::cout once done.
ASTNodeDataTypeBuilder::ASTNodeDataTypeBuilder(ASTNode& node)
{
  Assert(node.is_root());
  node.m_data_type = ASTNodeDataType::void_t;

  this->_buildNodeDataTypes(node);

  FunctionTable& registered_functions = node.m_symbol_table->functionTable();

  // The table size is read again at each iteration, as in a plain indexed loop
  size_t i_function = 0;
  while (i_function < registered_functions.size()) {
    FunctionDescriptor& descriptor = registered_functions[i_function];
    this->_buildNodeDataTypes(descriptor.definitionNode());
    ++i_function;
  }

  std::cout << " - build node data types\n";
}
