#include <language/ast/ASTNodeDataTypeBuilder.hpp>

#include <language/PEGGrammar.hpp>
#include <language/ast/ASTNodeNaturalConversionChecker.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/PugsAssert.hpp>

// Assigns data types to the variable(s) introduced by a declaration.
//
// type_node: the type specifier of the declaration. Either a
//   language::type_expression (product of spaces, handled recursively child
//   by child) or a single type specifier (B, Z, N, R, vector type, tuple,
//   string or a type identifier).
// name_node: the declared name, or the list of names when type_node is a
//   type expression.
//
// For a single type specifier, the computed type is stored in
// name_node.m_data_type and in the attributes of the matching symbol-table
// entry. For a tuple specifier, the content node (children[0]) also gets its
// own data type.
//
// Returns the data type associated with type_node (typename_t for a type
// expression).
//
// Throws parse_error when
//  - the number of product spaces differs from the number of names,
//  - a type identifier is unknown or refers to a symbol that is not a type,
//  - a name list is given for a single type specifier.
// Throws UnexpectedError when a tuple has an unsupported content type.
ASTNodeDataType
ASTNodeDataTypeBuilder::_buildDeclarationNodeDataTypes(ASTNode& type_node, ASTNode& name_node) const
{
  ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
  if (type_node.is_type<language::type_expression>()) {
    if (type_node.children.size() != name_node.children.size()) {
      std::ostringstream message;
      message << "number of product spaces (" << type_node.children.size() << ") " << rang::fgB::yellow
              << type_node.string() << rang::style::reset << rang::style::bold << " differs from number of variables ("
              << name_node.children.size() << ") " << rang::fgB::yellow << name_node.string() << rang::style::reset
              << std::ends;
      throw parse_error(message.str(), name_node.begin());
    }

    // Each sub-type is paired with the name found at the same position
    for (size_t i = 0; i < type_node.children.size(); ++i) {
      auto& sub_type_node = *type_node.children[i];
      auto& sub_name_node = *name_node.children[i];
      _buildDeclarationNodeDataTypes(sub_type_node, sub_name_node);
    }
    data_type = ASTNodeDataType::typename_t;
  } else {
    if (type_node.is_type<language::B_set>()) {
      data_type = ASTNodeDataType::bool_t;
    } else if (type_node.is_type<language::Z_set>()) {
      data_type = ASTNodeDataType::int_t;
    } else if (type_node.is_type<language::N_set>()) {
      data_type = ASTNodeDataType::unsigned_int_t;
    } else if (type_node.is_type<language::R_set>()) {
      data_type = ASTNodeDataType::double_t;
    } else if (type_node.is_type<language::vector_type>()) {
      data_type = getVectorDataType(type_node);
    } else if (type_node.is_type<language::tuple_type_specifier>()) {
      // The tuple content type is described by the first child of the specifier
      const auto& content_node = type_node.children[0];

      if (content_node->is_type<language::type_name_id>()) {
        const std::string& type_name_id = content_node->string();

        auto& symbol_table = *type_node.m_symbol_table;

        const auto [i_type_symbol, found] = symbol_table.find(type_name_id, content_node->begin());
        if (not found) {
          throw parse_error("undefined type identifier", std::vector{content_node->begin()});
        } else if (i_type_symbol->attributes().dataType() != ASTNodeDataType::type_name_id_t) {
          std::ostringstream os;
          os << "invalid type identifier, '" << type_name_id << "' was previously defined as a '"
             << dataTypeName(i_type_symbol->attributes().dataType()) << "'" << std::ends;
          throw parse_error(os.str(), std::vector{content_node->begin()});
        }

        content_node->m_data_type = ASTNodeDataType{ASTNodeDataType::type_id_t, type_name_id};
      } else if (content_node->is_type<language::B_set>()) {
        content_node->m_data_type = ASTNodeDataType::bool_t;
      } else if (content_node->is_type<language::Z_set>()) {
        content_node->m_data_type = ASTNodeDataType::int_t;
      } else if (content_node->is_type<language::N_set>()) {
        content_node->m_data_type = ASTNodeDataType::unsigned_int_t;
      } else if (content_node->is_type<language::R_set>()) {
        content_node->m_data_type = ASTNodeDataType::double_t;
      } else if (content_node->is_type<language::vector_type>()) {
        // The vector type is carried by the content node: type_node is the
        // enclosing tuple specifier here, not a vector_type node.
        content_node->m_data_type = getVectorDataType(*content_node);
      } else if (content_node->is_type<language::string_type>()) {
        content_node->m_data_type = ASTNodeDataType::string_t;
      } else {
        throw UnexpectedError("unexpected content type in tuple");
      }

      data_type = ASTNodeDataType{ASTNodeDataType::tuple_t, content_node->m_data_type};
    } else if (type_node.is_type<language::string_type>()) {
      data_type = ASTNodeDataType::string_t;
    } else if (type_node.is_type<language::type_name_id>()) {
      const std::string& type_name_id = type_node.string();

      auto& symbol_table = *type_node.m_symbol_table;

      auto [i_type_symbol, found] = symbol_table.find(type_name_id, type_node.begin());
      if (not found) {
        throw parse_error("undefined type identifier", std::vector{type_node.begin()});
      } else if (i_type_symbol->attributes().dataType() != ASTNodeDataType::type_name_id_t) {
        std::ostringstream os;
        os << "invalid type identifier, '" << type_name_id << "' was previously defined as a '"
           << dataTypeName(i_type_symbol->attributes().dataType()) << "'" << std::ends;
        throw parse_error(os.str(), std::vector{type_node.begin()});
      }

      data_type = ASTNodeDataType{ASTNodeDataType::type_id_t, type_name_id};
    }

    // A single space can only type a single variable
    if (name_node.is_type<language::name_list>()) {
      throw parse_error("unexpected variable list for single space", std::vector{name_node.begin()});
    }

    Assert(name_node.is_type<language::name>());
    name_node.m_data_type = data_type;

    const std::string& symbol = name_node.string();

    std::shared_ptr<SymbolTable>& symbol_table = name_node.m_symbol_table;

    // Propagate the type to the symbol-table entry of the declared variable
    auto [i_symbol, found] = symbol_table->find(symbol, name_node.begin());
    Assert(found);
    i_symbol->attributes().setDataType(name_node.m_data_type);
  }

  // Any type specifier that is not handled above leaves data_type undefined
  Assert(data_type != ASTNodeDataType::undefined_t);
  return data_type;
}

void
ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
{
  if (n.is_type<language::block>() or n.is_type<language::for_statement>()) {
    for (auto& child : n.children) {
      this->_buildNodeDataTypes(*child);
    }

    if (n.is_type<language::for_statement>()) {
      const ASTNode& test_node = *n.children[1];

      if (not n.children[1]->is_type<language::for_test>()) {
        ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::bool_t};
      }   // in the case of empty for_test (not simplified node), nothing to check!
    }

    n.m_data_type = ASTNodeDataType::void_t;
  } else {
    if (n.has_content()) {
      if (n.is_type<language::import_instruction>()) {
        n.m_data_type = ASTNodeDataType::void_t;
      } else if (n.is_type<language::module_name>()) {
        n.m_data_type = ASTNodeDataType::string_t;

      } else if (n.is_type<language::true_kw>() or n.is_type<language::false_kw>()) {
        n.m_data_type = ASTNodeDataType::bool_t;
      } else if (n.is_type<language::real>()) {
        n.m_data_type = ASTNodeDataType::double_t;
      } else if (n.is_type<language::integer>()) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else if (n.is_type<language::vector_type>()) {
        n.m_data_type = getVectorDataType(n);

      } else if (n.is_type<language::tuple_expression>()) {
        n.m_data_type = ASTNodeDataType::list_t;

      } else if (n.is_type<language::literal>()) {
        n.m_data_type = ASTNodeDataType::string_t;
      } else if (n.is_type<language::cout_kw>() or n.is_type<language::cerr_kw>() or n.is_type<language::clog_kw>()) {
        n.m_data_type = ASTNodeDataType::void_t;
      } else if (n.is_type<language::var_declaration>()) {
        auto& name_node = *(n.children[0]);
        auto& type_node = *(n.children[1]);

        type_node.m_data_type = _buildDeclarationNodeDataTypes(type_node, name_node);
        n.m_data_type         = type_node.m_data_type;
      } else if (n.is_type<language::fct_declaration>()) {
        n.children[0]->m_data_type = ASTNodeDataType::function_t;

        const std::string& symbol = n.children[0]->string();
        auto [i_symbol, success]  = n.m_symbol_table->find(symbol, n.children[0]->begin());

        auto& function_table = n.m_symbol_table->functionTable();

        uint64_t function_id                    = std::get<uint64_t>(i_symbol->attributes().value());
        FunctionDescriptor& function_descriptor = function_table[function_id];

        ASTNode& parameters_domain_node = *function_descriptor.domainMappingNode().children[0];
        ASTNode& parameters_name_node   = *function_descriptor.definitionNode().children[0];

        {   // Function data type
          const std::string& symbol = n.children[0]->string();

          std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

          auto [i_symbol, found] = symbol_table->find(symbol, n.children[0]->begin());
          Assert(found);
          i_symbol->attributes().setDataType(n.children[0]->m_data_type);
        }

        const size_t nb_parameter_domains =
          (parameters_domain_node.is_type<language::type_expression>()) ? parameters_domain_node.children.size() : 1;
        const size_t nb_parameter_names =
          (parameters_name_node.children.size() > 0) ? parameters_name_node.children.size() : 1;

        if (nb_parameter_domains != nb_parameter_names) {
          std::ostringstream message;
          message << "note: number of product spaces (" << nb_parameter_domains << ") " << rang::fgB::yellow
                  << parameters_domain_node.string() << rang::style::reset << rang::style::bold
                  << " differs from number of variables (" << nb_parameter_names << ") " << rang::fgB::yellow
                  << parameters_name_node.string() << rang::style::reset << std::ends;
          throw parse_error(message.str(), parameters_domain_node.begin());
        }

        auto simple_type_allocator = [&](const ASTNode& type_node, ASTNode& symbol_node) {
          Assert(symbol_node.is_type<language::name>());
          ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
          if (type_node.is_type<language::B_set>()) {
            data_type = ASTNodeDataType::bool_t;
          } else if (type_node.is_type<language::Z_set>()) {
            data_type = ASTNodeDataType::int_t;
          } else if (type_node.is_type<language::N_set>()) {
            data_type = ASTNodeDataType::unsigned_int_t;
          } else if (type_node.is_type<language::R_set>()) {
            data_type = ASTNodeDataType::double_t;
          } else if (type_node.is_type<language::vector_type>()) {
            data_type = getVectorDataType(type_node);
          } else if (type_node.is_type<language::string_type>()) {
            data_type = ASTNodeDataType::string_t;
          }

          // LCOV_EXCL_START
          if (data_type == ASTNodeDataType::undefined_t) {
            throw parse_error("invalid parameter type", type_node.begin());
          }
          // LCOV_EXCL_STOP

          symbol_node.m_data_type   = data_type;
          const std::string& symbol = symbol_node.string();

          std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

          auto [i_symbol, found] = symbol_table->find(symbol, symbol_node.begin());
          Assert(found);
          i_symbol->attributes().setDataType(data_type);
        };

        if (nb_parameter_domains == 1) {
          simple_type_allocator(parameters_domain_node, parameters_name_node);
        } else {
          for (size_t i = 0; i < nb_parameter_domains; ++i) {
            simple_type_allocator(*parameters_domain_node.children[i], *parameters_name_node.children[i]);
          }
          parameters_name_node.m_data_type = ASTNodeDataType::list_t;
        }

        // build types for compound types
        for (auto& child : parameters_domain_node.children) {
          this->_buildNodeDataTypes(*child);
        }
        for (auto& child : parameters_name_node.children) {
          this->_buildNodeDataTypes(*child);
        }

        ASTNode& image_domain_node     = *function_descriptor.domainMappingNode().children[1];
        ASTNode& image_expression_node = *function_descriptor.definitionNode().children[1];

        this->_buildNodeDataTypes(image_domain_node);
        for (auto& child : image_domain_node.children) {
          this->_buildNodeDataTypes(*child);
        }

        const size_t nb_image_domains =
          (image_domain_node.is_type<language::type_expression>()) ? image_domain_node.children.size() : 1;
        const size_t nb_image_expressions =
          (image_expression_node.is_type<language::expression_list>()) ? image_expression_node.children.size() : 1;

        if (nb_image_domains != nb_image_expressions) {
          if (image_domain_node.is_type<language::vector_type>()) {
            ASTNodeDataType image_type = getVectorDataType(image_domain_node);
            if (image_type.dimension() != nb_image_expressions) {
              std::ostringstream message;
              message << "expecting " << image_type.dimension() << " scalar expressions or an "
                      << dataTypeName(image_type) << ", found " << nb_image_expressions << " scalar expressions"
                      << std::ends;
              throw parse_error(message.str(), image_domain_node.begin());
            }
          } else {
            std::ostringstream message;
            message << "number of image spaces (" << nb_image_domains << ") " << rang::fgB::yellow
                    << image_domain_node.string() << rang::style::reset << rang::style::bold
                    << " differs from number of expressions (" << nb_image_expressions << ") " << rang::fgB::yellow
                    << image_expression_node.string() << rang::style::reset << std::ends;
            throw parse_error(message.str(), image_domain_node.begin());
          }
        }

        auto check_image_type = [&](const ASTNode& image_node) {
          ASTNodeDataType value_type{ASTNodeDataType::undefined_t};
          if (image_node.is_type<language::B_set>()) {
            value_type = ASTNodeDataType::bool_t;
          } else if (image_node.is_type<language::Z_set>()) {
            value_type = ASTNodeDataType::int_t;
          } else if (image_node.is_type<language::N_set>()) {
            value_type = ASTNodeDataType::unsigned_int_t;
          } else if (image_node.is_type<language::R_set>()) {
            value_type = ASTNodeDataType::double_t;
          } else if (image_node.is_type<language::vector_type>()) {
            value_type = getVectorDataType(image_node);
          } else if (image_node.is_type<language::string_type>()) {
            value_type = ASTNodeDataType::string_t;
          }

          // LCOV_EXCL_START
          if (value_type == ASTNodeDataType::undefined_t) {
            throw parse_error("invalid value type", image_node.begin());
          }
          // LCOV_EXCL_STOP
        };

        if (image_domain_node.is_type<language::type_expression>()) {
          for (size_t i = 0; i < image_domain_node.children.size(); ++i) {
            check_image_type(*image_domain_node.children[i]);
          }
          image_domain_node.m_data_type = ASTNodeDataType::typename_t;
        } else {
          check_image_type(image_domain_node);
        }

        n.m_data_type = ASTNodeDataType::void_t;
        return;
      } else if (n.is_type<language::name>()) {
        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

        auto [i_symbol, found] = symbol_table->find(n.string(), n.begin());
        Assert(found);
        n.m_data_type = i_symbol->attributes().dataType();
      }
    }
    for (auto& child : n.children) {
      this->_buildNodeDataTypes(*child);
    }

    if (n.is_type<language::break_kw>() or n.is_type<language::continue_kw>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::eq_op>() or n.is_type<language::multiplyeq_op>() or
               n.is_type<language::divideeq_op>() or n.is_type<language::pluseq_op>() or
               n.is_type<language::minuseq_op>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is_type<language::type_mapping>() or n.is_type<language::function_definition>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::for_post>() or n.is_type<language::for_init>() or
               n.is_type<language::for_statement_block>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::for_test>()) {
      n.m_data_type = ASTNodeDataType::bool_t;
    } else if (n.is_type<language::statement_block>()) {
      n.m_data_type = ASTNodeDataType::void_t;
    } else if (n.is_type<language::if_statement>() or n.is_type<language::while_statement>()) {
      n.m_data_type = ASTNodeDataType::void_t;

      const ASTNode& test_node = *n.children[0];
      ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::bool_t};

    } else if (n.is_type<language::do_while_statement>()) {
      n.m_data_type = ASTNodeDataType::void_t;

      const ASTNode& test_node = *n.children[1];
      ASTNodeNaturalConversionChecker{test_node, ASTNodeDataType::bool_t};

    } else if (n.is_type<language::unary_not>()) {
      n.m_data_type = ASTNodeDataType::bool_t;

      const ASTNode& operand_node = *n.children[0];
      ASTNodeNaturalConversionChecker{operand_node, ASTNodeDataType::bool_t};

    } else if (n.is_type<language::lesser_op>() or n.is_type<language::lesser_or_eq_op>() or
               n.is_type<language::greater_op>() or n.is_type<language::greater_or_eq_op>() or
               n.is_type<language::eqeq_op>() or n.is_type<language::not_eq_op>()) {
      n.m_data_type = ASTNodeDataType::bool_t;
    } else if (n.is_type<language::and_op>() or n.is_type<language::or_op>() or n.is_type<language::xor_op>()) {
      n.m_data_type = ASTNodeDataType::bool_t;

      const ASTNode& lhs_node = *n.children[0];
      ASTNodeNaturalConversionChecker{lhs_node, ASTNodeDataType::bool_t};

      const ASTNode& rhs_node = *n.children[1];
      ASTNodeNaturalConversionChecker{rhs_node, ASTNodeDataType::bool_t};

    } else if (n.is_type<language::unary_minus>()) {
      n.m_data_type = n.children[0]->m_data_type;
      if ((n.children[0]->m_data_type == ASTNodeDataType::unsigned_int_t) or
          (n.children[0]->m_data_type == ASTNodeDataType::bool_t)) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else {
        n.m_data_type = n.children[0]->m_data_type;
      }
    } else if (n.is_type<language::unary_plusplus>() or n.is_type<language::unary_minusminus>() or
               n.is_type<language::post_plusplus>() or n.is_type<language::post_minusminus>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is_type<language::plus_op>() or n.is_type<language::minus_op>() or
               n.is_type<language::multiply_op>() or n.is_type<language::divide_op>()) {
      const ASTNodeDataType type_0 = n.children[0]->m_data_type;
      const ASTNodeDataType type_1 = n.children[1]->m_data_type;
      if ((type_0 == ASTNodeDataType::bool_t) and (type_1 == ASTNodeDataType::bool_t)) {
        n.m_data_type = ASTNodeDataType::int_t;
      } else {
        n.m_data_type = dataTypePromotion(type_0, type_1);
      }
      if (n.m_data_type == ASTNodeDataType::undefined_t) {
        std::ostringstream message;
        message << "undefined binary operator\n"
                << "note: incompatible operand types " << n.children[0]->string() << " (" << dataTypeName(type_0)
                << ") and " << n.children[1]->string() << " (" << dataTypeName(type_1) << ')' << std::ends;
        throw parse_error(message.str(), n.begin());
      }
    } else if (n.is_type<language::function_evaluation>()) {
      if (n.children[0]->m_data_type == ASTNodeDataType::function_t) {
        const std::string& function_name  = n.children[0]->string();
        auto [i_function_symbol, success] = n.m_symbol_table->find(function_name, n.children[0]->begin());

        auto& function_table = n.m_symbol_table->functionTable();

        uint64_t function_id                    = std::get<uint64_t>(i_function_symbol->attributes().value());
        FunctionDescriptor& function_descriptor = function_table[function_id];

        ASTNode& image_domain_node = *function_descriptor.domainMappingNode().children[1];

        ASTNodeDataType data_type{ASTNodeDataType::undefined_t};
        if (image_domain_node.is_type<language::type_expression>()) {
          data_type = image_domain_node.m_data_type;
        } else {
          if (image_domain_node.is_type<language::B_set>()) {
            data_type = ASTNodeDataType::bool_t;
          } else if (image_domain_node.is_type<language::Z_set>()) {
            data_type = ASTNodeDataType::int_t;
          } else if (image_domain_node.is_type<language::N_set>()) {
            data_type = ASTNodeDataType::unsigned_int_t;
          } else if (image_domain_node.is_type<language::R_set>()) {
            data_type = ASTNodeDataType::double_t;
          } else if (image_domain_node.is_type<language::vector_type>()) {
            data_type = getVectorDataType(image_domain_node);
          } else if (image_domain_node.is_type<language::string_type>()) {
            data_type = ASTNodeDataType::string_t;
          }
        }

        Assert(data_type != ASTNodeDataType::undefined_t);   // LCOV_EXCL_LINE

        n.m_data_type = data_type;
      } else if (n.children[0]->m_data_type == ASTNodeDataType::builtin_function_t) {
        const std::string builtin_function_name = n.children[0]->string();
        auto& symbol_table                      = *n.m_symbol_table;

        auto [i_symbol, success] = symbol_table.find(builtin_function_name, n.begin());
        Assert(success);

        uint64_t builtin_function_id   = std::get<uint64_t>(i_symbol->attributes().value());
        auto builtin_function_embedder = symbol_table.builtinFunctionEmbedderTable()[builtin_function_id];
        Assert(builtin_function_embedder);

        n.m_data_type = builtin_function_embedder->getReturnDataType();
      } else {
        std::ostringstream message;
        message << "invalid function call\n"
                << "note: '" << n.children[0]->string() << "' (type: " << dataTypeName(n.children[0]->m_data_type)
                << ") is not a function!" << std::ends;
        throw parse_error(message.str(), n.begin());
      }
    } else if (n.is_type<language::subscript_expression>()) {
      Assert(n.children.size() == 2, "invalid number of sub-expressions in array subscript expression");
      auto& array_expression = *n.children[0];
      auto& index_expression = *n.children[1];

      ASTNodeNaturalConversionChecker{index_expression, ASTNodeDataType::int_t};
      if (array_expression.m_data_type != ASTNodeDataType::vector_t) {
        std::ostringstream message;
        message << "invalid types '" << rang::fgB::yellow << dataTypeName(array_expression.m_data_type)
                << rang::style::reset << '[' << dataTypeName(index_expression.m_data_type) << ']'
                << "' for array subscript" << std::ends;

        throw parse_error(message.str(), n.begin());
      } else {
        n.m_data_type = ASTNodeDataType::double_t;
      }
    } else if (n.is_type<language::B_set>() or n.is_type<language::Z_set>() or n.is_type<language::N_set>() or
               n.is_type<language::R_set>() or n.is_type<language::string_type>() or
               n.is_type<language::vector_type>() or n.is_type<language::type_name_id>()) {
      n.m_data_type = ASTNodeDataType::typename_t;
    } else if (n.is_type<language::name_list>() or n.is_type<language::lvalue_list>() or
               n.is_type<language::function_argument_list>() or n.is_type<language::expression_list>()) {
      n.m_data_type = ASTNodeDataType::list_t;
    }
  }
}

// Builds the data types of a whole AST.
//
// node must be the root of the tree (asserted). The root itself is typed
// void_t, then the tree is traversed, and finally the definition node of each
// entry of the function table is traversed as well. A progress line is
// written to std::cout once done.
ASTNodeDataTypeBuilder::ASTNodeDataTypeBuilder(ASTNode& node)
{
  Assert(node.is_root());
  node.m_data_type = ASTNodeDataType::void_t;

  // Main program tree
  this->_buildNodeDataTypes(node);

  // Function definitions registered in the function table
  FunctionTable& functions = node.m_symbol_table->functionTable();
  for (size_t i_function = 0; i_function < functions.size(); ++i_function) {
    this->_buildNodeDataTypes(functions[i_function].definitionNode());
  }

  std::cout << " - build node data types\n";
}
