#include <language/ast/ASTBuilder.hpp>

using namespace TAO_PEGTL_NAMESPACE;

#include <language/PEGGrammar.hpp>
#include <language/ast/ASTNode.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/PugsAssert.hpp>

#include <pegtl/contrib/parse_tree.hpp>

#include <algorithm>

struct ASTBuilder::rearrange : parse_tree::apply<ASTBuilder::rearrange>
{
  // Turns a flat node whose children read [x0, op1, x1, op2, x2, ...] into left-nested
  // binary operator nodes: op2(op1(x0, x1), x2). A node with a single child is replaced
  // by that child. After each fold, "a - (-b)" becomes "a + b" and "a + (-b)" becomes "a - b".
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    if (n->children.size() == 1) {
      n = std::move(n->children.back());
      return;
    }

    // Detach the right-most operand and the operator just before it.
    auto& operands = n->children;
    auto right_operand = std::move(operands.back());
    operands.pop_back();
    auto binary_operator = std::move(operands.back());
    operands.pop_back();

    // The remaining prefix (n itself) becomes the left operand of the detached operator.
    binary_operator->children.emplace_back(std::move(n));
    binary_operator->children.emplace_back(std::move(right_operand));
    n = std::move(binary_operator);

    // Recursively fold the left operand, which still holds the remaining [x0, op1, ...] list.
    transform(n->children.front(), st...);

    fold_negated_rhs(n);
  }

  // For a plus_op / minus_op node whose right operand is a unary_minus, absorbs the
  // negation into the operator (swapping plus_op <-> minus_op) and keeps only the operand
  // of the unary_minus. Any other node is left untouched.
  static void
  fold_negated_rhs(std::unique_ptr<ASTNode>& n)
  {
    const bool is_minus = n->is_type<language::minus_op>();
    if (not is_minus and not n->is_type<language::plus_op>()) {
      return;
    }

    Assert(n->children.size() == 2);
    auto& rhs = n->children[1];
    if (not rhs->is_type<language::unary_minus>()) {
      return;
    }

    if (is_minus) {
      n->set_type<language::plus_op>();
    } else {
      n->set_type<language::minus_op>();
    }
    rhs = std::move(rhs->children[0]);
  }
};

struct ASTBuilder::simplify_unary : parse_tree::apply<ASTBuilder::simplify_unary>
{
  // Normalizes unary-level nodes once their children are built:
  //  - unwraps single-child unary_expression / type_expression / name_subscript_expression nodes,
  //  - removes double negations ("- -x" -> "x", "not not x" -> "x"),
  //  - drops a leading unary_plus, and makes a leading prefix operator (unary_minus, unary_not,
  //    unary_minusminus, unary_plusplus) the parent of its operand,
  //  - makes a trailing post_minusminus / post_plusplus the parent of the remaining expression,
  //  - rewrites "expr <subscript_expression>" so that the subscript_expression node becomes the
  //    parent, with the subscripted expression as its first child followed by its index children.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    if (n->children.size() == 1) {
      if (n->is_type<language::unary_expression>() or n->is_type<language::type_expression>() or
          n->is_type<language::name_subscript_expression>()) {
        // Pure wrapper: replace it by its only child and simplify that child.
        n = std::move(n->children.back());
        transform(n, st...);
      } else if (n->is_type<language::unary_minus>()) {
        auto& child = n->children[0];
        if (child->is_type<language::unary_minus>()) {
          // - (- x) -> x
          n = std::move(child->children[0]);
          transform(n, st...);
        }
      } else if (n->is_type<language::unary_not>()) {
        auto& child = n->children[0];
        if (child->is_type<language::unary_not>()) {
          // not (not x) -> x
          n = std::move(child->children[0]);
          transform(n, st...);
        }
      }
    } else if (n->children.size() == 2) {
      if (n->children[0]->is_type<language::unary_plus>()) {
        // + x -> x
        n = std::move(n->children[1]);
        transform(n, st...);
      } else if (n->children[0]->is_type<language::unary_minus>() or n->children[0]->is_type<language::unary_not>() or
                 n->children[0]->is_type<language::unary_minusminus>() or
                 n->children[0]->is_type<language::unary_plusplus>()) {
        // [op, x] -> op(x)
        auto expression     = std::move(n->children[1]);
        auto unary_operator = std::move(n->children[0]);
        unary_operator->children.emplace_back(std::move(expression));
        n = std::move(unary_operator);
        transform(n, st...);
      }
    }

    if (n->is_type<language::unary_expression>()) {
      const size_t child_nb = n->children.size();
      if (child_nb > 1) {
        if (n->children[child_nb - 1]->is_type<language::post_minusminus>() or
            n->children[child_nb - 1]->is_type<language::post_plusplus>()) {
          // [x..., post_op] -> post_op([x...]); the detached expression is simplified in place.
          auto unary_operator = std::move(n->children[child_nb - 1]);
          n->children.pop_back();

          unary_operator->children.emplace_back(std::move(n));

          n = std::move(unary_operator);
          transform(n->children[0], st...);
        }
      }
    }

    if (n->is_type<language::unary_expression>() or n->is_type<language::name_subscript_expression>()) {
      const size_t child_nb = n->children.size();
      if (child_nb > 1) {
        if (n->children[1]->is_type<language::subscript_expression>()) {
          // [expr, subscript, rest...] -> [subscript(expr, indices...), rest...]
          auto expression = std::move(n->children[0]);
          n->children.erase(n->children.begin());

          auto& array_subscript_expression = n->children[0];
          auto& subscript_children         = array_subscript_expression->children;
          subscript_children.emplace_back(std::move(expression));

          // Bring the just-appended subscripted expression in front of the index children.
          // With a single index this is a plain swap; rotating (instead of swapping the first
          // two children) also keeps several index children in their original order and is a
          // no-op instead of an out-of-bounds access when the subscript had no child.
          std::rotate(subscript_children.begin(), subscript_children.end() - 1, subscript_children.end());

          array_subscript_expression->m_begin = array_subscript_expression->children[0]->m_begin;

          // Handles chained subscripts / remaining children of the wrapper.
          transform(n, st...);
        }
      }
    }
  }
};

struct ASTBuilder::simplify_node_list : parse_tree::apply<ASTBuilder::simplify_node_list>
{
  // A name_list, lvalue_list or function_argument_list holding exactly one element is
  // replaced by that element (which is then simplified again); longer lists are kept.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    const bool is_list = n->is_type<language::name_list>() or n->is_type<language::lvalue_list>() or
                         n->is_type<language::function_argument_list>();

    if (is_list and (n->children.size() == 1)) {
      n = std::move(n->children.front());
      transform(n, st...);
    }
  }
};

struct ASTBuilder::simplify_statement_block : parse_tree::apply<ASTBuilder::simplify_statement_block>
{
  // Removes a statement_block / block wrapper around a single statement. When that single
  // statement is a var_declaration the wrapper is kept and re-tagged as language::block
  // (presumably so the declaration keeps its own block scope — confirm with the symbol-table pass).
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    const bool is_block_node = n->is_type<language::statement_block>() or n->is_type<language::block>();
    if (not is_block_node or (n->children.size() != 1)) {
      return;
    }

    if (n->children[0]->is_type<language::var_declaration>()) {
      n->set_type<language::block>();
      return;
    }

    n = std::move(n->children.back());
    transform(n, st...);
  }
};

struct ASTBuilder::simplify_for_statement_block : parse_tree::apply<ASTBuilder::simplify_for_statement_block>
{
  // Replaces a for_statement_block / block holding exactly one child by that child, then
  // simplifies the result again. Unlike simplify_statement_block, a lone var_declaration
  // is unwrapped too. NOTE(review): confirm this difference is intended.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    const bool wraps_single_child =
      (n->is_type<language::for_statement_block>() or n->is_type<language::block>()) and (n->children.size() == 1);
    if (not wraps_single_child) {
      return;
    }

    n = std::move(n->children.front());
    transform(n, st...);
  }
};

struct ASTBuilder::simplify_for_init : parse_tree::apply<ASTBuilder::simplify_for_init>
{
  // A for_init node is expected to hold at most one child. When it holds exactly one,
  // that child replaces the wrapper; an empty for_init is left unchanged.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    Assert(n->children.size() <= 1);

    auto& init_children = n->children;
    if (init_children.size() != 1) {
      return;
    }
    n = std::move(init_children.front());
  }
};

struct ASTBuilder::simplify_for_test : parse_tree::apply<ASTBuilder::simplify_for_test>
{
  // A for_test node is expected to hold at most one child (the loop condition). When it
  // holds exactly one, that child replaces the wrapper; an empty for_test is left unchanged.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    Assert(n->children.size() <= 1);

    auto& test_children = n->children;
    if (test_children.size() != 1) {
      return;
    }
    n = std::move(test_children.front());
  }
};

struct ASTBuilder::simplify_for_post : parse_tree::apply<ASTBuilder::simplify_for_post>
{
  // A for_post node is expected to hold at most one child. When it holds exactly one,
  // that child replaces the wrapper; an empty for_post is left unchanged.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    Assert(n->children.size() <= 1);

    auto& post_children = n->children;
    if (post_children.size() != 1) {
      return;
    }
    n = std::move(post_children.front());
  }
};

struct ASTBuilder::simplify_stream_statement : parse_tree::apply<ASTBuilder::simplify_stream_statement>
{
  // Re-roots an ostream_statement on its first child: every following child is appended,
  // in order, to the first child, which then replaces the statement node.
  // Assumes the statement has at least one child (the stream keyword node) — TODO confirm
  // this is guaranteed by the grammar.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    auto& stream = n->children[0];
    for (auto item = n->children.begin() + 1; item != n->children.end(); ++item) {
      stream->children.emplace_back(std::move(*item));
    }
    n = std::move(stream);
  }
};

// Parse-tree selector for language::grammar. Each group below lists the grammar rules that
// produce a node in the tree, and how that node is processed once its children are built
// (see the PEGTL parse_tree documentation for selector semantics):
//  - store_content:  the node is kept together with the source text it matched,
//  - remove_content: the node is kept without retaining its matched text,
//  - discard_empty:  the node is removed when it ends up with no children,
//  - ASTBuilder::*:  the node is kept and handed to the matching transform defined above.
// Rules that appear in no group do not produce nodes of their own.
template <typename Rule>
using selector = parse_tree::selector<
  Rule,
  // Literals, names, type keywords, stream keywords and statement/declaration nodes whose
  // source text is retained.
  parse_tree::store_content::on<language::import_instruction,
                                language::module_name,
                                language::true_kw,
                                language::false_kw,
                                language::integer,
                                language::real,
                                language::literal,
                                language::name,
                                language::B_set,
                                language::N_set,
                                language::Z_set,
                                language::R_set,
                                language::type_name_id,
                                language::tuple_expression,
                                language::vector_type,
                                language::string_type,
                                language::cout_kw,
                                language::cerr_kw,
                                language::clog_kw,
                                language::var_declaration,
                                language::fct_declaration,
                                language::type_mapping,
                                language::function_definition,
                                language::expression_list,
                                language::if_statement,
                                language::do_while_statement,
                                language::while_statement,
                                language::for_statement,
                                language::function_evaluation,
                                language::break_kw,
                                language::continue_kw>,
  // Flat "operand op operand op ..." lists folded into left-nested binary operator nodes.
  ASTBuilder::rearrange::on<language::logical_or,
                            language::logical_and,
                            language::bitwise_xor,
                            language::equality,
                            language::compare,
                            language::sum,
                            language::product,
                            language::affectation,
                            language::expression>,
  // Unary operators, wrappers and subscripts normalized by ASTBuilder::simplify_unary.
  ASTBuilder::simplify_unary::on<language::unary_minus,
                                 language::unary_plus,
                                 language::unary_not,
                                 language::subscript_expression,
                                 language::tuple_type_specifier,
                                 language::type_expression,
                                 language::unary_expression,
                                 language::name_subscript_expression>,
  // Operator nodes: only their type matters, so their matched text is not kept.
  parse_tree::remove_content::on<language::plus_op,
                                 language::minus_op,
                                 language::multiply_op,
                                 language::divide_op,
                                 language::lesser_op,
                                 language::lesser_or_eq_op,
                                 language::greater_op,
                                 language::greater_or_eq_op,
                                 language::eqeq_op,
                                 language::not_eq_op,
                                 language::and_op,
                                 language::or_op,
                                 language::xor_op,
                                 language::eq_op,
                                 language::multiplyeq_op,
                                 language::divideeq_op,
                                 language::pluseq_op,
                                 language::minuseq_op,
                                 language::unary_plusplus,
                                 language::unary_minusminus,
                                 language::post_minusminus,
                                 language::post_plusplus>,
  // Single-statement for bodies are unwrapped.
  ASTBuilder::simplify_for_statement_block::on<language::for_statement_block>,
  // Whitespace/comments, stray semicolons and empty blocks vanish when they have no children.
  parse_tree::discard_empty::on<language::ignored, language::semicol, language::block>,
  // One-element lists are replaced by their element.
  ASTBuilder::simplify_node_list::on<language::name_list, language::lvalue_list, language::function_argument_list>,
  // Single-statement blocks are unwrapped (except a lone var_declaration).
  ASTBuilder::simplify_statement_block::on<language::statement_block>,
  // Optional parts of a for statement are replaced by their content when present.
  ASTBuilder::simplify_for_init::on<language::for_init>,
  ASTBuilder::simplify_for_test::on<language::for_test>,
  ASTBuilder::simplify_for_post::on<language::for_post>,
  // Output statements are re-rooted on their stream keyword node.
  ASTBuilder::simplify_stream_statement::on<language::ostream_statement>>;

// Parses `input` with language::grammar into an ASTNode tree shaped by `selector`, and
// attaches a fresh, empty SymbolTable to the root node.
//
// InputT: a PEGTL input type (explicitly instantiated below for read_input<> and string_input<>).
// Returns the root of the AST. parse_tree::parse yields nullptr when the grammar fails to match
// without raising; that nullptr is returned unchanged instead of being dereferenced.
template <typename InputT>
std::unique_ptr<ASTNode>
ASTBuilder::build(InputT& input)
{
  std::unique_ptr root_node = parse_tree::parse<language::grammar, ASTNode, selector, nothing, language::errors>(input);

  if (not root_node) {
    // No tree was produced: there is no node to attach a symbol table to.
    return root_node;
  }

  // build initial symbol tables
  std::shared_ptr symbol_table = std::make_shared<SymbolTable>();

  root_node->m_symbol_table = symbol_table;

  return root_node;
}

template std::unique_ptr<ASTNode> ASTBuilder::build(read_input<>& input);
template std::unique_ptr<ASTNode> ASTBuilder::build(string_input<>& input);