#include <language/ast/ASTBuilder.hpp>

#include <language/PEGGrammar.hpp>
#include <language/ast/ASTNode.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/PugsAssert.hpp>

#include <pegtl/contrib/parse_tree.hpp>

#include <algorithm>

struct ASTBuilder::rearrange : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::rearrange>
{
  // Turns the flat node of a binary-operator list, whose children are
  //   operand, op, operand, op, ..., operand
  // into a left-leaning tree op(op(operand, operand), operand).
  // A node holding a single child is simply replaced by that child.
  //
  // Afterwards, a negated right operand of '+' / '-' is folded:
  //   x - (-y)  ->  x + y
  //   x + (-y)  ->  x - y
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    if (n->children.size() == 1) {
      n = std::move(n->children.back());
      return;
    }

    // The last two children are the right operand and its operator; whatever
    // remains in 'n' becomes the left operand and is rearranged recursively.
    auto right_operand = std::move(n->children.back());
    n->children.pop_back();
    auto binary_operator = std::move(n->children.back());
    n->children.pop_back();

    binary_operator->children.emplace_back(std::move(n));
    binary_operator->children.emplace_back(std::move(right_operand));
    n = std::move(binary_operator);

    transform(n->children.front(), st...);

    const bool is_minus = n->is_type<language::minus_op>();
    const bool is_plus  = n->is_type<language::plus_op>();
    if (not(is_minus or is_plus)) {
      return;
    }

    Assert(n->children.size() == 2);
    auto& rhs = n->children[1];
    if (not rhs->is_type<language::unary_minus>()) {
      return;
    }

    // Flip the operator and drop the unary minus around the right operand
    if (is_minus) {
      n->set_type<language::plus_op>();
    } else {
      n->set_type<language::minus_op>();
    }
    auto negated_operand = std::move(rhs->children[0]);
    rhs                  = std::move(negated_operand);
  }
};

struct ASTBuilder::simplify_unary : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::simplify_unary>
{
  // Normalizes unary, postfix and subscript expression nodes.
  //
  // Rewrites (each one re-runs transform() on the rewritten node):
  //  - a unary_expression, type_expression or name_subscript_expression with a
  //    single child is replaced by that child;
  //  - unary_minus(unary_minus(x)) and unary_not(unary_not(x)) collapse to x;
  //  - unary_expression [unary_plus, x] becomes x;
  //  - unary_expression [op, x], with op one of unary_minus, unary_not,
  //    unary_minusminus or unary_plusplus, becomes op(x);
  //  - a unary_expression whose last child is post_minusminus/post_plusplus
  //    becomes post_op(<the unary_expression without that child>);
  //  - [x, subscript_expression(i...), ...] becomes
  //    [subscript_expression(x, i...), ...], one subscript at a time, so that
  //    several subscript children nest left to right:
  //    [x, sub(i), sub(j)] -> sub(sub(x, i), j).
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    if (n->children.size() == 1) {
      if (n->is_type<language::unary_expression>() or n->is_type<language::type_expression>() or
          n->is_type<language::name_subscript_expression>()) {
        n = std::move(n->children.back());
        transform(n, st...);
      } else if (n->is_type<language::unary_minus>()) {
        auto& child = n->children[0];
        if (child->is_type<language::unary_minus>()) {
          n = std::move(child->children[0]);
          transform(n, st...);
        }
      } else if (n->is_type<language::unary_not>()) {
        auto& child = n->children[0];
        if (child->is_type<language::unary_not>()) {
          n = std::move(child->children[0]);
          transform(n, st...);
        }
      }
    } else if ((n->children.size() == 2) and (n->is_type<language::unary_expression>())) {
      if (n->children[0]->is_type<language::unary_plus>()) {
        n = std::move(n->children[1]);
        transform(n, st...);
      } else if (n->children[0]->is_type<language::unary_minus>() or n->children[0]->is_type<language::unary_not>() or
                 n->children[0]->is_type<language::unary_minusminus>() or
                 n->children[0]->is_type<language::unary_plusplus>()) {
        // The prefix operator becomes the parent of its operand
        auto expression     = std::move(n->children[1]);
        auto unary_operator = std::move(n->children[0]);
        unary_operator->children.emplace_back(std::move(expression));
        n = std::move(unary_operator);
        transform(n, st...);
      }
    }

    if (n->is_type<language::unary_expression>()) {
      const size_t child_nb = n->children.size();
      if (child_nb > 1) {
        if (n->children[child_nb - 1]->is_type<language::post_minusminus>() or
            n->children[child_nb - 1]->is_type<language::post_plusplus>()) {
          // The postfix operator becomes the parent of what remains of the expression
          auto unary_operator = std::move(n->children[child_nb - 1]);
          n->children.pop_back();

          unary_operator->children.emplace_back(std::move(n));

          n = std::move(unary_operator);
          transform(n->children[0], st...);
        }
      }
    }

    if (n->is_type<language::unary_expression>() or n->is_type<language::name_subscript_expression>()) {
      if ((n->children.size() > 1) and n->children[1]->is_type<language::subscript_expression>()) {
        // Move the subscripted expression (children[0]) into the subscript node.
        std::swap(n->children[0], n->children[1]);
        n->children[0]->emplace_back(std::move(n->children[1]));

        // Remove the now-empty slot at index 1. pop_back() would instead drop the
        // *last* child (any further subscript) and leave a null child behind.
        n->children.erase(n->children.begin() + 1);

        // The subscripted expression was appended after the indices: rotate it to
        // the front so the subscript node reads (expression, index_1, ..., index_k).
        auto& subscript_children = n->children[0]->children;
        std::rotate(subscript_children.begin(), subscript_children.end() - 1, subscript_children.end());

        // Any remaining subscript is folded on top of this one; a lone remaining
        // child is then unwrapped by the single-child rule above.
        transform(n, st...);
      }
    }
  }
};

struct ASTBuilder::simplify_node_list : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::simplify_node_list>
{
  // Replaces a name_list, lvalue_list or function_argument_list node holding a
  // single element by that element (repeatedly, as long as the pattern applies).
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    const bool is_list_node = n->is_type<language::name_list>() or n->is_type<language::lvalue_list>() or
                              n->is_type<language::function_argument_list>();

    if (is_list_node and (n->children.size() == 1)) {
      n = std::move(n->children.back());
      transform(n, st...);
    }
  }
};

struct ASTBuilder::simplify_statement_block
  : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::simplify_statement_block>
{
  // Simplifies a statement_block/block that holds exactly one child:
  //  - if that child is a var_declaration, the node is kept and retyped as a
  //    block (presumably so the declaration keeps its own scope — TODO confirm);
  //  - otherwise the node is replaced by its child, which is simplified in turn.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&... st)
  {
    const bool is_block_node = n->is_type<language::statement_block>() or n->is_type<language::block>();
    if (not is_block_node or (n->children.size() != 1)) {
      return;
    }

    if (n->children[0]->is_type<language::var_declaration>()) {
      n->set_type<language::block>();
      return;
    }

    n = std::move(n->children.back());
    transform(n, st...);
  }
};

struct ASTBuilder::simplify_for_statement_block
  : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::simplify_for_statement_block>
{
  // Strips every enclosing for_statement_block/block wrapper that holds a single
  // child, leaving that child in place of the wrapper.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    while ((n->is_type<language::for_statement_block>() or n->is_type<language::block>()) and
           (n->children.size() == 1)) {
      n = std::move(n->children.back());
    }
  }
};

struct ASTBuilder::simplify_for_init : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::simplify_for_init>
{
  // A for_init node holds at most one child. When present, that child replaces
  // the node; an empty for_init node is left untouched.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    const size_t nb_children = n->children.size();
    Assert(nb_children <= 1);
    if (nb_children != 1) {
      return;
    }
    n = std::move(n->children.front());
  }
};

struct ASTBuilder::simplify_for_test : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::simplify_for_test>
{
  // A for_test node holds at most one child. When present, that child replaces
  // the node; an empty for_test node is left untouched.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    const size_t nb_children = n->children.size();
    Assert(nb_children <= 1);
    if (nb_children != 1) {
      return;
    }
    n = std::move(n->children.front());
  }
};

struct ASTBuilder::simplify_for_post : TAO_PEGTL_NAMESPACE::parse_tree::apply<ASTBuilder::simplify_for_post>
{
  // A for_post node holds at most one child. When present, that child replaces
  // the node; an empty for_post node is left untouched.
  template <typename... States>
  static void
  transform(std::unique_ptr<ASTNode>& n, States&&...)
  {
    const size_t nb_children = n->children.size();
    Assert(nb_children <= 1);
    if (nb_children != 1) {
      return;
    }
    n = std::move(n->children.front());
  }
};

template <typename Rule>
using selector = TAO_PEGTL_NAMESPACE::parse_tree::selector<
  Rule,
  TAO_PEGTL_NAMESPACE::parse_tree::store_content::on<language::import_instruction,
                                                     language::module_name,
                                                     language::true_kw,
                                                     language::false_kw,
                                                     language::integer,
                                                     language::real,
                                                     language::literal,
                                                     language::name,
                                                     language::B_set,
                                                     language::N_set,
                                                     language::Z_set,
                                                     language::R_set,
                                                     language::type_name_id,
                                                     language::tuple_expression,
                                                     language::vector_type,
                                                     language::matrix_type,
                                                     language::string_type,
                                                     language::var_declaration,
                                                     language::fct_declaration,
                                                     language::type_mapping,
                                                     language::function_definition,
                                                     language::expression_list,
                                                     language::if_statement,
                                                     language::do_while_statement,
                                                     language::while_statement,
                                                     language::for_statement,
                                                     language::function_evaluation,
                                                     language::break_kw,
                                                     language::continue_kw>,
  ASTBuilder::rearrange::on<language::logical_or,
                            language::logical_and,
                            language::bitwise_xor,
                            language::equality,
                            language::compare,
                            language::sum,
                            language::shift,
                            language::product,
                            language::affectation,
                            language::expression>,
  ASTBuilder::simplify_unary::on<language::unary_minus,
                                 language::unary_plus,
                                 language::unary_not,
                                 language::subscript_expression,
                                 language::tuple_type_specifier,
                                 language::type_expression,
                                 language::unary_expression,
                                 language::name_subscript_expression>,
  TAO_PEGTL_NAMESPACE::parse_tree::remove_content::on<language::plus_op,
                                                      language::minus_op,
                                                      language::shift_left_op,
                                                      language::shift_right_op,
                                                      language::multiply_op,
                                                      language::divide_op,
                                                      language::lesser_op,
                                                      language::lesser_or_eq_op,
                                                      language::greater_op,
                                                      language::greater_or_eq_op,
                                                      language::eqeq_op,
                                                      language::not_eq_op,
                                                      language::and_op,
                                                      language::or_op,
                                                      language::xor_op,
                                                      language::eq_op,
                                                      language::multiplyeq_op,
                                                      language::divideeq_op,
                                                      language::pluseq_op,
                                                      language::minuseq_op,
                                                      language::unary_plusplus,
                                                      language::unary_minusminus,
                                                      language::post_minusminus,
                                                      language::post_plusplus>,
  ASTBuilder::simplify_for_statement_block::on<language::for_statement_block>,
  TAO_PEGTL_NAMESPACE::parse_tree::discard_empty::on<language::ignored, language::semicol, language::block>,
  ASTBuilder::simplify_node_list::on<language::name_list, language::lvalue_list, language::function_argument_list>,
  ASTBuilder::simplify_statement_block::on<language::statement_block>,
  ASTBuilder::simplify_for_init::on<language::for_init>,
  ASTBuilder::simplify_for_test::on<language::for_test>,
  ASTBuilder::simplify_for_post::on<language::for_post>>;

template <typename InputT>
std::unique_ptr<ASTNode>
ASTBuilder::build(InputT& input)
{
  // Parse 'input' against language::grammar. Only the nodes chosen by 'selector'
  // are kept (and rewritten by the transform() hooks above); language::errors is
  // the PEGTL control class used during the parse.
  std::unique_ptr root_node =
    TAO_PEGTL_NAMESPACE::parse_tree::parse<language::grammar, ASTNode, selector, TAO_PEGTL_NAMESPACE::nothing,
                                           language::errors>(input);

  // parse_tree::parse() returns a null pointer (instead of throwing) when the
  // grammar fails to match without any error being raised. Report it as a parse
  // error rather than dereferencing null below.
  if (not root_node) {
    throw TAO_PEGTL_NAMESPACE::parse_error("unable to parse input", input);
  }

  // build initial symbol tables: the root node receives a fresh, empty SymbolTable
  std::shared_ptr symbol_table = std::make_shared<SymbolTable>();

  root_node->m_symbol_table = symbol_table;

  return root_node;
}

template std::unique_ptr<ASTNode> ASTBuilder::build(TAO_PEGTL_NAMESPACE::read_input<>& input);
template std::unique_ptr<ASTNode> ASTBuilder::build(TAO_PEGTL_NAMESPACE::string_input<>& input);
