#include <PugsParser.hpp>

#include <PugsAssert.hpp>

#include <fstream>
#include <iostream>
#include <unordered_map>
#include <variant>

#include <rang.hpp>

#include <pegtl/analyze.hpp>
#include <pegtl/contrib/parse_tree.hpp>
#include <pegtl/contrib/parse_tree_to_dot.hpp>

#include <ASTNode.hpp>

#include <ASTBuilder.hpp>
#include <PEGGrammar.hpp>
#include <SymbolTable.hpp>

#include <EscapedString.hpp>

#include <ASTNodeExpressionBuilder.hpp>

namespace language
{
namespace internal
{
// Writes the DOT statements describing the subtree rooted at n to os.
//
// Each node is identified in the DOT output by "x" followed by its address.
// Its label is made of the node name ("root " for the root node), the matched
// source text when the node has content, and the name of its data type, each
// separated by the DOT line-break escape "\n".
//
// The matched source text is user input: double quotes and backslashes are
// escaped with a backslash so that they cannot terminate the quoted label or
// be interpreted as DOT escape sequences.
//
// os: output stream receiving the DOT statements
// n:  node whose subtree is written (recursively)
void
print_dot(std::ostream& os, const Node& n)
{
  // node statement
  os << "  x" << &n << " [ label=\"";
  if (n.is_root()) {
    os << "root \\n";
  } else {
    os << n.name() << "\\n";
    if (n.has_content()) {
      for (const char c : n.string_view()) {
        if ((c == '"') or (c == '\\')) {
          os << '\\';
        }
        os << c;
      }
      os << "\\n";
    }
  }
  os << dataTypeName(n.m_data_type) << "\" ]\n";

  if (!n.children.empty()) {
    // edge statement: parent -> { child, child, ... }
    os << "  x" << &n << " -> { ";
    for (auto& child : n.children) {
      os << "x" << child.get() << ((child == n.children.back()) ? " }\n" : ", ");
    }
    // then the children subtrees
    for (auto& child : n.children) {
      print_dot(os, *child);
    }
  }
}

}   // namespace internal

// Writes the whole AST as a DOT directed graph named "parse_tree" to os.
// n must be the root node of the tree (checked through Assert).
void
print_dot(std::ostream& os, const Node& n)
{
  Assert(n.is_root());

  // graph header
  os << "digraph parse_tree\n"
     << "{\n";

  // graph body: nodes and edges
  internal::print_dot(os, n);

  // graph footer
  os << "}\n";
}

namespace internal
{
// Recursively attaches a symbol table to the visited nodes and checks symbol
// declarations.
//
// - A non-empty bloc or for_statement node gets a new SymbolTable built from
//   the enclosing one; its children are processed with that new table.
//   (An empty bloc/for_statement node is left untouched.)
// - Any other node shares the enclosing table. A declaration adds its symbol
//   (second child) to the table, a name must be found in the table.
//
// Throws parse_error if a symbol is declared twice or if a name is undefined.
void
build_symbol_table_and_check_declarations(Node& n, std::shared_ptr<SymbolTable>& symbol_table)
{
  const bool opens_scope = n.is<language::bloc>() or n.is<language::for_statement>();

  if (opens_scope) {
    if (n.children.empty()) {
      return;
    }

    auto local_table = std::make_shared<SymbolTable>(symbol_table);
    n.m_symbol_table = local_table;
    for (auto& sub_node : n.children) {
      build_symbol_table_and_check_declarations(*sub_node, local_table);
    }
    return;
  }

  n.m_symbol_table = symbol_table;

  if (n.has_content() and n.is<language::declaration>()) {
    const std::string& declared_name = n.children[1]->string();
    auto [i_symbol, success]         = symbol_table->add(declared_name);
    if (not success) {
      std::ostringstream error_message;
      error_message << "symbol '" << rang::fg::red << declared_name << rang::fg::reset << '\''
                    << " was already defined!";
      throw parse_error(error_message.str(), std::vector{n.begin()});
    }
  } else if (n.has_content() and n.is<language::name>()) {
    auto [i_symbol, found] = symbol_table->find(n.string());
    if (not found) {
      std::ostringstream error_message;
      error_message << "undefined symbol '" << rang::fg::red << n.string() << rang::fg::reset << '\'';
      throw parse_error(error_message.str(), std::vector{n.begin()});
    }
  }

  for (auto& sub_node : n.children) {
    build_symbol_table_and_check_declarations(*sub_node, symbol_table);
  }
}
}   // namespace internal

// Builds the symbol tables of the whole AST and checks symbol declarations.
// n must be the root node (checked through Assert); it receives a freshly
// created SymbolTable which is the one enclosing every other table.
void
build_symbol_table_and_check_declarations(Node& n)
{
  Assert(n.is_root());

  auto root_table  = std::make_shared<SymbolTable>();
  n.m_symbol_table = root_table;

  internal::build_symbol_table_and_check_declarations(n, root_table);

  std::cout << " - checked symbols declaration\n";
}

namespace internal
{
// Recursively checks that symbols are initialized before being read.
//
// This pass works on its own symbol tables (the ones stored in the nodes are
// not used): a non-empty bloc or for_statement creates a new SymbolTable from
// the enclosing one.
//
// - declaration: the symbol is added; if an initializer is present (third
//   child), it is checked first and the symbol is then marked initialized.
// - eq_op: the right hand side is checked first, then the left hand side
//   symbol is marked initialized.
// - name: the symbol must already be marked initialized.
// - other nodes: children are checked in order.
//
// Throws parse_error when an uninitialized symbol is read.
void
check_symbol_initialization(const Node& n, std::shared_ptr<SymbolTable>& symbol_table)
{
  if (n.is<language::bloc>() or n.is<language::for_statement>()) {
    if (n.children.empty()) {
      return;
    }
    auto scope_table = std::make_shared<SymbolTable>(symbol_table);
    for (auto& sub_node : n.children) {
      check_symbol_initialization(*sub_node, scope_table);
    }
    return;
  }

  if (n.is<language::declaration>()) {
    const std::string& declared_name = n.children[1]->string();
    auto [i_symbol, success]         = symbol_table->add(declared_name);
    Assert(success, "unexpected error, should have been detected through declaration checking");
    if (n.children.size() == 3) {
      // the initializer is evaluated before the symbol becomes initialized
      check_symbol_initialization(*n.children[2], symbol_table);
      i_symbol->second.setIsInitialized();
    }
    return;
  }

  if (n.is<language::eq_op>()) {
    // right hand side is checked first
    check_symbol_initialization(*n.children[1], symbol_table);
    // left hand side is then marked as initialized
    const std::string& assigned_name = n.children[0]->string();
    auto [i_symbol, found]           = symbol_table->find(assigned_name);
    Assert(found, "unexpected error, should have been detected through declaration checking");
    i_symbol->second.setIsInitialized();
    return;
  }

  if (n.is<language::name>()) {
    auto [i_symbol, found] = symbol_table->find(n.string());
    Assert(found, "unexpected error, should have been detected through declaration checking");
    if (not i_symbol->second.isInitialized()) {
      std::ostringstream error_message;
      error_message << "uninitialized symbol '" << rang::fg::red << n.string() << rang::fg::reset << '\'';
      throw parse_error(error_message.str(), std::vector{n.begin()});
    }
  }

  for (auto& sub_node : n.children) {
    check_symbol_initialization(*sub_node, symbol_table);
  }
}
}   // namespace internal

// Checks that every symbol of the AST is initialized before being used.
// n must be the root node (checked through Assert).
//
// A warning is first written to std::cerr since this check is known to be
// incomplete. The check itself runs on a fresh SymbolTable and reports the
// first uninitialized symbol by throwing a parse_error.
void
check_symbol_initialization(const Node& n)
{
  // the two literals are concatenated by the compiler: the first one must end
  // with a separator to keep the message readable
  std::cerr << rang::fgB::yellow << "warning:" << rang::fg::reset
            << " symbol initialization checking not finished: "
               "if and loops statements are not correctly evaluated\n";
  Assert(n.is_root());
  std::shared_ptr symbol_table = std::make_shared<SymbolTable>();
  internal::check_symbol_initialization(n, symbol_table);
  std::cout << " - checked symbols initialization\n";
}

namespace internal
{
void
build_node_data_types(Node& n)
{
  if (n.is<language::bloc>() or n.is<for_statement>()) {
    if (!n.children.empty()) {
      for (auto& child : n.children) {
        build_node_data_types(*child);
      }
    }
    n.m_data_type = DataType::void_t;
  } else {
    if (n.has_content()) {
      if (n.is<language::true_kw>() or n.is<language::false_kw>() or n.is<language::do_kw>()) {
        n.m_data_type = DataType::bool_t;
      } else if (n.is<language::real>()) {
        n.m_data_type = DataType::double_t;
      } else if (n.is<language::integer>()) {
        n.m_data_type = DataType::int_t;
      } else if (n.is<language::literal>()) {
        n.m_data_type = DataType::string_t;
      } else if (n.is<language::cout_kw>() or n.is<language::cerr_kw>() or n.is<language::clog_kw>()) {
        n.m_data_type = DataType::void_t;
      } else if (n.is<language::declaration>()) {
        auto& type_node = *(n.children[0]);
        DataType data_type{DataType::undefined_t};
        if (type_node.is<language::B_set>()) {
          data_type = DataType::bool_t;
        } else if (type_node.is<language::Z_set>()) {
          data_type = DataType::int_t;
        } else if (type_node.is<language::N_set>()) {
          data_type = DataType::unsigned_int_t;
        } else if (type_node.is<language::R_set>()) {
          data_type = DataType::double_t;
        } else if (type_node.is<language::string_type>()) {
          data_type = DataType::string_t;
        }
        if (data_type == DataType::undefined_t) {
          throw parse_error("unexpected error: invalid datatype", type_node.begin());
        }
        type_node.m_data_type      = DataType::void_t;
        n.children[1]->m_data_type = data_type;
        const std::string& symbol  = n.children[1]->string();

        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

        auto [i_symbol, found] = symbol_table->find(symbol);
        Assert(found);
        i_symbol->second.setDataType(data_type);
        n.m_data_type = data_type;
      } else if (n.is<language::name>()) {
        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;

        auto [i_symbol, found] = symbol_table->find(n.string());
        Assert(found);
        n.m_data_type = i_symbol->second.dataType();
      }
    }
    for (auto& child : n.children) {
      build_node_data_types(*child);
    }

    if (n.is<language::break_kw>() or n.is<language::continue_kw>()) {
      n.m_data_type = DataType::void_t;
    } else if (n.is<language::eq_op>() or n.is<language::multiplyeq_op>() or n.is<language::divideeq_op>() or
               n.is<language::pluseq_op>() or n.is<language::minuseq_op>() or n.is<language::bit_andeq_op>() or
               n.is<language::bit_xoreq_op>() or n.is<language::bit_oreq_op>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is<language::for_statement>()) {
      n.m_data_type = DataType::void_t;
    } else if (n.is<language::for_post>() or n.is<language::for_init>() or n.is<language::for_statement_bloc>()) {
      n.m_data_type = DataType::void_t;
    } else if (n.is<language::for_test>()) {
      n.m_data_type = DataType::bool_t;
    } else if (n.is<language::statement_bloc>()) {
      n.m_data_type = DataType::void_t;
    } else if (n.is<language::if_statement>() or n.is<language::while_statement>()) {
      n.m_data_type = DataType::void_t;
      if ((n.children[0]->m_data_type > DataType::double_t) or (n.children[0]->m_data_type < DataType::bool_t)) {
        const DataType type_0 = n.children[0]->m_data_type;
        std::ostringstream message;
        message << "Cannot convert data type to boolean value\n"
                << "note: incompatible operand '" << n.children[0]->string() << " of type ' (" << dataTypeName(type_0)
                << ')' << std::ends;
        throw parse_error(message.str(), n.children[0]->begin());
      }
    } else if (n.is<language::do_while_statement>()) {
      n.m_data_type = DataType::void_t;
      if ((n.children[1]->m_data_type > DataType::double_t) or (n.children[1]->m_data_type < DataType::bool_t)) {
        const DataType type_0 = n.children[1]->m_data_type;
        std::ostringstream message;
        message << "Cannot convert data type to boolean value\n"
                << "note: incompatible operand '" << n.children[1]->string() << " of type ' (" << dataTypeName(type_0)
                << ')' << std::ends;
        throw parse_error(message.str(), n.children[1]->begin());
      }
    } else if (n.is<language::unary_not>() or n.is<language::lesser_op>() or n.is<language::lesser_or_eq_op>() or
               n.is<language::greater_op>() or n.is<language::greater_or_eq_op>() or n.is<language::eqeq_op>() or
               n.is<language::not_eq_op>() or n.is<language::and_op>() or n.is<language::or_op>() or
               n.is<language::xor_op>() or n.is<language::bitand_op>() or n.is<language::bitor_op>()) {
      n.m_data_type = DataType::bool_t;
    } else if (n.is<language::unary_minus>() or n.is<language::unary_plus>() or n.is<language::unary_plusplus>() or
               n.is<language::unary_minusminus>()) {
      n.m_data_type = n.children[0]->m_data_type;
    } else if (n.is<language::plus_op>() or n.is<language::minus_op>() or n.is<language::multiply_op>() or
               n.is<language::divide_op>()) {
      const DataType type_0 = n.children[0]->m_data_type;
      const DataType type_1 = n.children[1]->m_data_type;

      n.m_data_type = dataTypePromotion(type_0, type_1);
      if (n.m_data_type == DataType::undefined_t) {
        std::ostringstream message;
        message << "undefined binary operator\n"
                << "note: incompatible operand types " << n.children[0]->string() << " (" << dataTypeName(type_0)
                << ") and " << n.children[1]->string() << " (" << dataTypeName(type_1) << ')' << std::ends;
        throw parse_error(message.str(), n.begin());
      }
    }
  }
}
}   // namespace internal

// Computes the data type of every node of the AST.
// n must be the root node (checked through Assert); it is typed void_t before
// the recursive pass is run.
void
build_node_data_types(Node& n)
{
  Assert(n.is_root());

  n.m_data_type = DataType::void_t;
  internal::build_node_data_types(n);

  std::cout << " - build node data types\n";
}

namespace internal
{
// Recursively checks that n and all its descendants have a defined data type.
// Throws parse_error (located at the faulty node) on the first node whose
// data type is still undefined_t.
void
check_node_data_types(const Node& n)
{
  const bool has_data_type = (n.m_data_type != DataType::undefined_t);
  if (not has_data_type) {
    throw parse_error("unexpected error: undefined datatype for AST node for " + n.name(), n.begin());
  }

  for (const auto& sub_node : n.children) {
    check_node_data_types(*sub_node);
  }
}
}   // namespace internal

// Checks that every node of the AST has a defined data type.
// n must be the root node (checked through Assert).
void
check_node_data_types(const Node& n)
{
  Assert(n.is_root());

  internal::check_node_data_types(n);

  std::cout << " - checked node data types\n";
}

namespace internal
{
// Recursively stores in m_value the value of the nodes whose value is known
// from the source text: real, integer and string literals, true/false
// keywords and for_test nodes.
//
// A bloc node processes its children with a new SymbolTable built from the
// enclosing one and is typed void_t. Other nodes process their children
// first, then set their own value when they have content.
void
build_node_values(Node& n, std::shared_ptr<SymbolTable>& symbol_table)
{
  if (n.is<language::bloc>()) {
    if (not n.children.empty()) {
      auto scope_table = std::make_shared<SymbolTable>(symbol_table);
      for (auto& sub_node : n.children) {
        build_node_values(*sub_node, scope_table);
      }
    }
    n.m_data_type = DataType::void_t;
    return;
  }

  for (auto& sub_node : n.children) {
    build_node_values(*sub_node, symbol_table);
  }

  if (not n.has_content()) {
    return;
  }

  if (n.is<language::real>()) {
    std::stringstream text(n.string());
    double real_value;
    text >> real_value;
    n.m_value = real_value;
  } else if (n.is<language::integer>()) {
    std::stringstream text(n.string());
    int64_t integer_value;
    text >> integer_value;
    n.m_value = integer_value;
  } else if (n.is<language::literal>()) {
    n.m_value = unescapeString(n.string());
  } else if (n.is<language::for_test>() or n.is<language::true_kw>()) {
    // a for_test node still present in the AST means that no test was given
    // to the for-loop: its value is always true, as for the true keyword
    n.m_value = true;
  } else if (n.is<language::false_kw>()) {
    n.m_value = false;
  }
}
}   // namespace internal

// Computes the values of the nodes of the AST that are known from the source
// text (see internal::build_node_values).
// n must be the root node (checked through Assert); it is typed void_t and
// the pass is run with a fresh SymbolTable.
void
build_node_values(Node& n)
{
  Assert(n.is_root());
  n.m_data_type                = DataType::void_t;
  std::shared_ptr symbol_table = std::make_shared<SymbolTable>();
  internal::build_node_values(n, symbol_table);
  // progress message names this pass (it used to repeat the data types one)
  std::cout << " - build node values\n";
}

namespace internal
{
// Recursively checks that break and continue keywords only appear below a
// for, while or do-while statement.
//
// n:              node to check
// is_inside_loop: true if one of the ancestors of n is a loop statement
//
// Throws parse_error located at the misplaced keyword.
void
check_break_or_continue_placement(const Node& n, bool is_inside_loop)
{
  if (n.is<language::break_kw>() or n.is<language::continue_kw>()) {
    if (is_inside_loop) {
      return;
    }
    std::ostringstream error_message;
    error_message << "unexpected '" << rang::fgB::red << n.string() << rang::fg::reset
                  << "' outside of loop or switch statement";
    throw parse_error(error_message.str(), std::vector{n.begin()});
  }

  const bool is_loop_statement =
    n.is<language::for_statement>() or n.is<language::do_while_statement>() or n.is<language::while_statement>();

  // children of a loop statement are inside a loop, other children inherit
  // the state of their parent
  const bool children_inside_loop = is_loop_statement or is_inside_loop;

  for (auto& sub_node : n.children) {
    check_break_or_continue_placement(*sub_node, children_inside_loop);
  }
}
}   // namespace internal

// Checks that no break or continue keyword is used outside of a loop.
// n must be the root node (checked through Assert).
void
check_break_or_continue_placement(const Node& n)
{
  Assert(n.is_root());

  // the root node is not enclosed by any loop
  const bool is_inside_loop = false;
  internal::check_break_or_continue_placement(n, is_inside_loop);
}

namespace internal
{
// Recursively rewrites declarations that come with an initializer as plain
// assignments: the type specifier (first child) is removed, leaving the
// symbol and the initializer as the two children, and the node becomes an
// eq_op. Declarations without initializer are left unchanged and their
// children are not visited.
void
simplify_declarations(Node& n)
{
  if (not n.is<language::declaration>()) {
    for (auto& sub_node : n.children) {
      simplify_declarations(*sub_node);
    }
    return;
  }

  if (n.children.size() != 3) {
    return;
  }

  // drops the type specifier: [type, symbol, initializer] -> [symbol, initializer]
  n.children.erase(n.children.begin());
  n.id = typeid(language::eq_op);
}
}   // namespace internal

// Replaces initialized declarations by assignments in the whole AST.
// n must be the root node (checked through Assert).
void
simplify_declarations(Node& n)
{
  Assert(n.is_root());

  Node& root_node = n;
  internal::simplify_declarations(root_node);
}

// Forward declaration: the node-list overload of print defined below calls
// the single-node overload, which is defined after it.
void print(const Node& n);

// State shared by the print functions while the tree is written to std::cout.
// prefix is the text written at the beginning of each line before the
// junction; it grows when descending into a node that has children.
std::string prefix;
// Stack of the sizes prefix had before each descent, used to restore it.
std::vector<int> last_prefix_size;
// Junction drawn before a child that is not the last one of its parent
// (box-drawing characters U+251C U+2500 U+2500).
const std::string T_junction(" \u251c\u2500\u2500");
// Junction drawn before the last child of its parent
// (box-drawing characters U+2514 U+2500 U+2500).
const std::string L_junction(" \u2514\u2500\u2500");

// Prefix extension used below a child that is not the last one: a vertical
// bar (U+2502) continues the branch of the parent.
const std::string pipe_space(" \u2502  ");
// Prefix extension used below the last child: nothing remains to be drawn.
const std::string space_space("    ");

// Prints the nodes of node_list (the children of a given node) to std::cout,
// one per line, each preceded by the current prefix and a junction.
//
// When a node has children, prefix is extended for the time of its printing
// (with a vertical bar unless the node is the last of the list) and restored
// afterwards.
template <typename NodeVector>
void
print(const NodeVector& node_list)
{
  const size_t nb_nodes = node_list.size();

  for (size_t i_node = 0; i_node < nb_nodes; ++i_node) {
    const bool is_last = (i_node + 1 == nb_nodes);

    std::cout << rang::fgB::green << prefix << (is_last ? L_junction : T_junction) << rang::fg::reset;

    const auto& node = *(node_list[i_node]);

    if (node.children.empty()) {
      // leaf: the prefix does not need to be changed
      print(node);
      continue;
    }

    // saves the prefix size to restore it once the subtree is printed
    last_prefix_size.push_back(prefix.size());
    prefix += (is_last ? space_space : pipe_space);

    print(node);

    prefix.resize(last_prefix_size.back());
    last_prefix_size.pop_back();
  }
}

// Prints the node n to std::cout as "(name:data type:value)" followed by a
// new line, then prints its children (if any).
//
// The name is "root" for the root node. The value is written as "--" when
// the node holds no value, and as an escaped quoted string when it holds a
// std::string.
void
print(const Node& n)
{
  const auto write_value = [](const auto& value) {
    using ValueType = std::decay_t<decltype(value)>;
    if constexpr (std::is_same_v<ValueType, std::monostate>) {
      std::cout << "--";
    } else if constexpr (std::is_same_v<ValueType, std::string>) {
      std::cout << '\"' << escapeString(value) << '\"';
    } else {
      std::cout << value;
    }
  };

  // node name
  std::cout << '(' << rang::fgB::yellow;
  if (not n.is_root()) {
    std::cout << n.name();
  } else {
    std::cout << "root";
  }
  std::cout << rang::fg::reset << ':';

  // data type
  std::cout << dataTypeName(n.m_data_type) << ':';

  // value
  std::cout << rang::fgB::cyan;
  std::visit(write_value, n.m_value);
  std::cout << rang::fg::reset << ")\n";

  if (n.children.empty()) {
    return;
  }
  print(n.children);
}

}   // namespace language

// Parses the file named filename, runs the successive analysis passes on the
// resulting AST and executes it.
//
// On a parse_error (thrown by the AST construction or by any of the passes),
// the error is reported on std::cerr with the faulty source line and a caret
// below it, then the program exits with status 1.
void
parser(const std::string& filename)
{
  // grammar analysis: its result is only reported, it does not stop parsing
  const size_t grammar_issues = analyze<language::grammar>();

  std::cout << rang::fgB::yellow << "grammar_issues=" << rang::fg::reset << grammar_issues << '\n';

  std::cout << rang::style::bold << "Parsing file " << rang::style::reset << rang::style::underline << filename
            << rang::style::reset << " ...\n";

  std::unique_ptr<language::Node> root_node;
  // input is declared outside of the try block since the error handler reads
  // the faulty line from it
  read_input input(filename);
  try {
    root_node = buildAST(input);
    std::cout << " - AST is built ...... [done]\n";

    // symbols must be declared before their initialization can be checked
    language::build_symbol_table_and_check_declarations(*root_node);
    language::check_symbol_initialization(*root_node);
    {
      // NOTE(review): the dot file is written before build_node_data_types
      // runs, while print_dot writes each node's data type -- confirm that
      // this ordering is intended
      std::string dot_filename{"parse_tree.dot"};
      std::ofstream fout(dot_filename);
      language::print_dot(fout, *root_node);
      std::cout << "   AST dot file: " << dot_filename << '\n';
    }

    // data types rely on the symbol tables stored in the nodes
    language::build_node_data_types(*root_node);

    language::check_node_data_types(*root_node);
    language::build_node_values(*root_node);

    language::check_break_or_continue_placement(*root_node);

    // optimizations
    language::simplify_declarations(*root_node);

    language::build_node_type(*root_node);

    language::print(*root_node);

    // execution of the program, followed by the dump of the root symbol table
    language::ExecUntilBreakOrContinue exec_all;
    root_node->execute(exec_all);
    std::cout << *(root_node->m_symbol_table) << '\n';
  }
  catch (const parse_error& e) {
    // only the first position attached to the error is reported
    const auto p = e.positions.front();
    // "file:line:byte: error: message", then the source line and a caret
    // shifted by p.byte_in_line spaces
    std::cerr << rang::style::bold << p.source << ':' << p.line << ':' << p.byte_in_line << ": " << rang::style::reset
              << rang::fgB::red << "error: " << rang::fg::reset << rang::style::bold << e.what() << rang::style::reset
              << '\n'
              << input.line_at(p) << '\n'
              << std::string(p.byte_in_line, ' ') << rang::fgB::yellow << '^' << rang::fg::reset << std::endl;
    std::exit(1);
  }

  std::cout << "Parsed: " << filename << '\n';
}
