#include <ASTNodeExpressionBuilder.hpp>

#include <ASTNodeAffectationExpressionBuilder.hpp>
#include <ASTNodeBinaryOperatorExpressionBuilder.hpp>
#include <ASTNodeIncDecExpressionBuilder.hpp>
#include <ASTNodeUnaryOperatorExpressionBuilder.hpp>

#include <PEGGrammar.hpp>
#include <SymbolTable.hpp>

#include <Demangle.hpp>

class ASTNodeList final : public INodeProcessor
{
  ASTNode& m_node;

 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    for (auto& child : m_node.children) {
      child->execute(exec_policy);
    }
  }

  ASTNodeList(ASTNode& node) : m_node{node} {}
};

class NoProcess final : public INodeProcessor
{
 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  PUGS_INLINE
  void
  execute(ExecUntilBreakOrContinue&)
  {
    ;
  }

  NoProcess() = default;
};

class IfStatement final : public INodeProcessor
{
  ASTNode& m_node;

 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    m_node.children[0]->execute(exec_policy);
    const bool is_true = static_cast<bool>(std::visit(
      [](const auto& value) -> bool {
        using T = std::decay_t<decltype(value)>;
        if constexpr (std::is_arithmetic_v<T>) {
          return value;
        } else {
          return false;
        }
      },
      m_node.children[0]->m_value));
    if (is_true) {
      Assert(m_node.children[1] != nullptr);
      m_node.children[1]->execute(exec_policy);
    } else {
      if (m_node.children.size() == 3) {
        // else statement
        Assert(m_node.children[2] != nullptr);
        m_node.children[2]->execute(exec_policy);
      }
    }
  }

  IfStatement(ASTNode& node) : m_node{node} {}
};

class DoWhileStatement final : public INodeProcessor
{
  ASTNode& m_node;

 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    bool continuation_test = true;
    ExecUntilBreakOrContinue exec_until_jump;
    do {
      m_node.children[0]->execute(exec_until_jump);
      if (not exec_until_jump.exec()) {
        if (exec_until_jump.jumpType() == ExecUntilBreakOrContinue::JumpType::break_jump) {
          break;
        } else if (exec_until_jump.jumpType() == ExecUntilBreakOrContinue::JumpType::continue_jump) {
          exec_until_jump = ExecUntilBreakOrContinue{};   // getting ready for next loop traversal
        }
      }
      m_node.children[1]->execute(exec_policy);
      continuation_test = static_cast<bool>(std::visit(
        [](const auto& value) -> bool {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_arithmetic_v<T>) {
            return value;
          } else {
            return false;
          }
        },
        m_node.children[1]->m_value));
    } while (continuation_test);
  }

  DoWhileStatement(ASTNode& node) : m_node{node} {}
};

class WhileStatement final : public INodeProcessor
{
  ASTNode& m_node;

 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    ExecUntilBreakOrContinue exec_until_jump;
    while ([&]() {
      m_node.children[0]->execute(exec_policy);
      return static_cast<bool>(std::visit(
        [](const auto& value) -> bool {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_arithmetic_v<T>) {
            return value;
          } else {
            return false;
          }
        },
        m_node.children[0]->m_value));
    }()) {
      m_node.children[1]->execute(exec_until_jump);
      if (not exec_until_jump.exec()) {
        if (exec_until_jump.jumpType() == ExecUntilBreakOrContinue::JumpType::break_jump) {
          break;
        } else if (exec_until_jump.jumpType() == ExecUntilBreakOrContinue::JumpType::continue_jump) {
          exec_until_jump = ExecUntilBreakOrContinue{};   // getting ready for next loop traversal
        }
      }
    }
  }

  WhileStatement(ASTNode& node) : m_node{node} {}
};

class ForStatement final : public INodeProcessor
{
  ASTNode& m_node;

 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    ExecUntilBreakOrContinue exec_until_jump;
    m_node.children[0]->execute(exec_policy);
    while ([&]() {
      m_node.children[1]->execute(exec_policy);
      return static_cast<bool>(std::visit(
        [](const auto& value) -> bool {
          using T = std::decay_t<decltype(value)>;
          if constexpr (std::is_arithmetic_v<T>) {
            return value;
          } else {
            return false;
          }
        },
        m_node.children[1]->m_value));
    }()) {
      m_node.children[3]->execute(exec_until_jump);
      if (not exec_until_jump.exec()) {
        if (exec_until_jump.jumpType() == ExecUntilBreakOrContinue::JumpType::break_jump) {
          break;
        } else if (exec_until_jump.jumpType() == ExecUntilBreakOrContinue::JumpType::continue_jump) {
          exec_until_jump = ExecUntilBreakOrContinue{};   // getting ready for next loop traversal
        }
      }

      m_node.children[2]->execute(exec_policy);
    }
  }

  ForStatement(ASTNode& node) : m_node{node} {}
};

class NameExpression final : public INodeProcessor
{
  ASTNode& m_node;
  ASTNodeDataVariant* p_value{nullptr};

 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue&)
  {
    m_node.m_value = *p_value;
  }

  NameExpression(ASTNode& node) : m_node{node}
  {
    const std::string& symbol = m_node.string();
    auto [i_symbol, found]    = m_node.m_symbol_table->find(symbol, m_node.begin());
    Assert(found);
    p_value = &(i_symbol->second.value());
  }
};

class BreakExpression final : public INodeProcessor
{
 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    exec_policy = ExecUntilBreakOrContinue(ExecUntilBreakOrContinue::JumpType::break_jump);
  }

  BreakExpression() = default;
};

class ContinueExpression final : public INodeProcessor
{
 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    exec_policy = ExecUntilBreakOrContinue(ExecUntilBreakOrContinue::JumpType::continue_jump);
  }

  ContinueExpression() = default;
};

class OStreamObject final : public INodeProcessor
{
  ASTNode& m_node;
  std::ostream& m_os;

 public:
  std::string
  describe() const
  {
    return demangle<decltype(*this)>();
  }

  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    for (size_t i = 0; i < m_node.children.size(); ++i) {
      m_node.children[i]->execute(exec_policy);
      std::visit(
        [&](auto&& value) {
          using ValueT = std::decay_t<decltype(value)>;
          if constexpr (not std::is_same_v<std::monostate, ValueT>) {
            if constexpr (std::is_same_v<bool, ValueT>) {
              m_os << std::boolalpha << value;
            } else {
              m_os << value;
            }
          }
        },
        m_node.children[i]->m_value);
    }
  }

  OStreamObject(ASTNode& node, std::ostream& os) : m_node{node}, m_os(os)
  {
    ;
  }
};

// Attaches a node processor to `n` according to its grammar type, then
// recurses into its children.
//
// Throws parse_error (located at n.begin()) when the node type is not
// handled; in that case the children are not visited.
void
ASTNodeExpressionBuilder::_buildExpression(ASTNode& n)
{
  if (n.is<language::bloc>() or n.is<language::statement_bloc>() or n.is<language::for_statement_bloc>()) {
    // every kind of block just runs its children in sequence
    n.m_node_processor = std::make_unique<ASTNodeList>(n);

  } else if (n.is<language::real>() or n.is<language::integer>() or n.is<language::literal>() or
             n.is<language::true_kw>() or n.is<language::false_kw>() or n.is<language::for_init>() or
             n.is<language::for_post>() or n.is<language::for_test>()) {
    // these nodes get a processor that does nothing at execution
    n.m_node_processor = std::make_unique<NoProcess>();

  } else if (n.is<language::name>()) {
    n.m_node_processor = std::make_unique<NameExpression>(n);

  } else if (n.is<language::eq_op>() or n.is<language::multiplyeq_op>() or n.is<language::divideeq_op>() or
             n.is<language::pluseq_op>() or n.is<language::minuseq_op>()) {
    // dedicated builder, used as a temporary (presumably it installs the
    // processor on `n` from its constructor -- same for the builders below)
    ASTNodeAffectationExpressionBuilder{n};

  } else if (n.is<language::unary_minus>() or n.is<language::unary_not>()) {
    ASTNodeUnaryOperatorExpressionBuilder{n};

  } else if (n.is<language::unary_minusminus>() or n.is<language::unary_plusplus>() or
             n.is<language::post_minusminus>() or n.is<language::post_plusplus>()) {
    ASTNodeIncDecExpressionBuilder{n};

  } else if (n.is<language::multiply_op>() or n.is<language::divide_op>() or n.is<language::plus_op>() or
             n.is<language::minus_op>() or n.is<language::or_op>() or n.is<language::and_op>() or
             n.is<language::xor_op>() or n.is<language::greater_op>() or n.is<language::greater_or_eq_op>() or
             n.is<language::lesser_op>() or n.is<language::lesser_or_eq_op>() or n.is<language::eqeq_op>() or
             n.is<language::not_eq_op>()) {
    ASTNodeBinaryOperatorExpressionBuilder{n};

  } else if (n.is<language::cout_kw>()) {
    n.m_node_processor = std::make_unique<OStreamObject>(n, std::cout);
  } else if (n.is<language::cerr_kw>()) {
    n.m_node_processor = std::make_unique<OStreamObject>(n, std::cerr);
  } else if (n.is<language::clog_kw>()) {
    n.m_node_processor = std::make_unique<OStreamObject>(n, std::clog);

  } else if (n.is<language::if_statement>()) {
    n.m_node_processor = std::make_unique<IfStatement>(n);
  } else if (n.is<language::do_while_statement>()) {
    n.m_node_processor = std::make_unique<DoWhileStatement>(n);
  } else if (n.is<language::while_statement>()) {
    n.m_node_processor = std::make_unique<WhileStatement>(n);
  } else if (n.is<language::for_statement>()) {
    n.m_node_processor = std::make_unique<ForStatement>(n);

  } else if (n.is<language::break_kw>()) {
    n.m_node_processor = std::make_unique<BreakExpression>();
  } else if (n.is<language::continue_kw>()) {
    n.m_node_processor = std::make_unique<ContinueExpression>();

  } else {
    std::ostringstream message;
    message << "undefined node type '" << rang::fgB::red << n.name() << rang::fg::reset << "'";
    throw parse_error{message.str(), std::vector{n.begin()}};
  }

  // the node itself is handled: proceed with its sub-tree
  for (auto& sub_node : n.children) {
    this->_buildExpression(*sub_node);
  }
}

// Builds the node processors of a whole tree: `n` must be the root node
// (checked through Assert). A progress line is written to std::cout once
// the tree has been processed.
ASTNodeExpressionBuilder::ASTNodeExpressionBuilder(ASTNode& n)
{
  Assert(n.is_root());

  // the root itself simply runs the top-level nodes in sequence
  n.m_node_processor = std::make_unique<ASTNodeList>(n);

  for (auto& top_level_node : n.children) {
    this->_buildExpression(*top_level_node);
  }

  std::cout << " - build node types\n";
}