#include <ASTNodeBinaryOperatorExpressionBuilder.hpp>
#include <PEGGrammar.hpp>
#include <SymbolTable.hpp>

#include <type_traits>

namespace language
{
/// Operator functor family. The primary template is intentionally left
/// undefined: each supported operator tag gets an explicit specialization whose
/// `eval(lhs, rhs)` applies the matching built-in C++ operator. The trailing
/// `decltype` return type keeps `eval` SFINAE-friendly (it is only viable when
/// the C++ expression is well formed for the given operand types).
template <typename Op>
struct BinOp;

/// Logical conjunction: `lhs and rhs`.
template <>
struct BinOp<language::and_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs and rhs)
  {
    return lhs and rhs;
  }
};

/// Logical disjunction: `lhs or rhs`.
template <>
struct BinOp<language::or_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs or rhs)
  {
    return lhs or rhs;
  }
};

/// Exclusive or: `lhs xor rhs`. Note that `xor` is the C++ alternative token
/// for `^`, so on integers this is a bitwise operation.
template <>
struct BinOp<language::xor_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs xor rhs)
  {
    return lhs xor rhs;
  }
};

/// Bitwise and: `lhs & rhs`.
template <>
struct BinOp<language::bitand_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs & rhs)
  {
    return lhs & rhs;
  }
};

/// Bitwise or: `lhs | rhs`.
template <>
struct BinOp<language::bitor_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs | rhs)
  {
    return lhs | rhs;
  }
};

/// Equality test: `lhs == rhs`.
template <>
struct BinOp<language::eqeq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs == rhs)
  {
    return lhs == rhs;
  }
};

/// Inequality test: `lhs != rhs`.
template <>
struct BinOp<language::not_eq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs != rhs)
  {
    return lhs != rhs;
  }
};

/// Strict less-than: `lhs < rhs`.
template <>
struct BinOp<language::lesser_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs < rhs)
  {
    return lhs < rhs;
  }
};

/// Less-than-or-equal: `lhs <= rhs`.
template <>
struct BinOp<language::lesser_or_eq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs <= rhs)
  {
    return lhs <= rhs;
  }
};

/// Strict greater-than: `lhs > rhs`.
template <>
struct BinOp<language::greater_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs > rhs)
  {
    return lhs > rhs;
  }
};

/// Greater-than-or-equal: `lhs >= rhs`.
template <>
struct BinOp<language::greater_or_eq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs >= rhs)
  {
    return lhs >= rhs;
  }
};

/// Addition: `lhs + rhs` (usual arithmetic conversions apply).
template <>
struct BinOp<language::plus_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs + rhs)
  {
    return lhs + rhs;
  }
};

/// Subtraction: `lhs - rhs` (usual arithmetic conversions apply).
template <>
struct BinOp<language::minus_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs - rhs)
  {
    return lhs - rhs;
  }
};

/// Multiplication: `lhs * rhs` (usual arithmetic conversions apply).
template <>
struct BinOp<language::multiply_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs * rhs)
  {
    return lhs * rhs;
  }
};

/// Division: `lhs / rhs` (usual arithmetic conversions apply).
/// NOTE(review): no check is made here; for integral operands a zero divisor
/// (or INT64_MIN / -1) is undefined behavior in C++.
template <>
struct BinOp<language::divide_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE auto
  eval(const LhsT& lhs, const RhsT& rhs) -> decltype(lhs / rhs)
  {
    return lhs / rhs;
  }
};

/// Node processor evaluating `children[0] <op> children[1]` for operand types
/// fixed at compile time.
///
/// @tparam BinaryOpT operator tag selecting the BinOp<> specialization
/// @tparam A_DataT   DataVariant alternative read from m_node.children[0]
/// @tparam B_DataT   DataVariant alternative read from m_node.children[1]
///
/// Unsupported operand combinations (see _is_defined) are rejected when the
/// processor is constructed, by throwing a parse_error at the node position.
template <typename BinaryOpT, typename A_DataT, typename B_DataT>
class BinaryExpressionProcessor final : public INodeProcessor
{
  Node& m_node;

  // Operators for which mixed signed/unsigned integral operands are evaluated
  // with the unsigned operand converted to its signed counterpart.
  static constexpr bool _is_sign_sensitive_op =
    std::is_same_v<BinaryOpT, language::and_op> or std::is_same_v<BinaryOpT, language::or_op> or
    std::is_same_v<BinaryOpT, language::xor_op> or std::is_same_v<BinaryOpT, language::bitand_op> or
    std::is_same_v<BinaryOpT, language::bitor_op> or std::is_same_v<BinaryOpT, language::eqeq_op> or
    std::is_same_v<BinaryOpT, language::not_eq_op> or std::is_same_v<BinaryOpT, language::lesser_op> or
    std::is_same_v<BinaryOpT, language::lesser_or_eq_op> or std::is_same_v<BinaryOpT, language::greater_op> or
    std::is_same_v<BinaryOpT, language::greater_or_eq_op>;

  // The cast is only meaningful for integer/integer pairs of different
  // signedness (bool excluded). A floating-point operand must not trigger it
  // (std::is_signed_v<double> is true): converting a uint64_t through int64_t
  // before mixing it with a double would wrap values >= 2^63 to negatives.
  static constexpr bool _needs_signed_cast =
    _is_sign_sensitive_op and std::is_integral_v<A_DataT> and std::is_integral_v<B_DataT> and
    not(std::is_same_v<A_DataT, bool> or std::is_same_v<B_DataT, bool>) and
    (std::is_signed_v<A_DataT> xor std::is_signed_v<B_DataT>);

  /// Applies BinOp<BinaryOpT> to the A_DataT / B_DataT alternatives held by
  /// `a` and `b` and stores the result into `value`. An `int` result (from
  /// integral promotion, e.g. bool + bool) is widened to int64_t.
  /// std::get throws std::bad_variant_access if `a`/`b` hold another type.
  PUGS_INLINE auto
  eval(const DataVariant& a, const DataVariant& b, DataVariant& value)
  {
    if constexpr (_needs_signed_cast) {
      // Add 'signed' to the unsigned operand to avoid signed/unsigned comparison
      // warnings (and the conversion of a negative signed operand to unsigned)
      if constexpr (std::is_unsigned_v<A_DataT>) {
        using signed_A_DataT          = std::make_signed_t<A_DataT>;
        const signed_A_DataT signed_a = static_cast<signed_A_DataT>(std::get<A_DataT>(a));
        value                         = BinOp<BinaryOpT>().eval(signed_a, std::get<B_DataT>(b));
      } else {
        using signed_B_DataT          = std::make_signed_t<B_DataT>;
        const signed_B_DataT signed_b = static_cast<signed_B_DataT>(std::get<B_DataT>(b));
        value                         = BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), signed_b);
      }
    } else {
      auto result = BinOp<BinaryOpT>().eval(std::get<A_DataT>(a), std::get<B_DataT>(b));
      if constexpr (std::is_same_v<decltype(result), int>) {
        value = static_cast<int64_t>(result);
      } else {
        value = result;
      }
    }
  }

  // Bitwise operators (including xor, which is bitwise in C++) are only
  // accepted on two operands of the same integral type; others are always defined.
  static constexpr bool _is_defined{[] {
    if constexpr (std::is_same_v<BinaryOpT, language::bitand_op> or std::is_same_v<BinaryOpT, language::xor_op> or
                  std::is_same_v<BinaryOpT, language::bitor_op>) {
      return std::is_same_v<std::decay_t<A_DataT>, std::decay_t<B_DataT>> and std::is_integral_v<std::decay_t<A_DataT>>;
    }
    return true;
  }()};

 public:
  /// @throws parse_error when the operand types are not supported by BinaryOpT
  explicit BinaryExpressionProcessor(Node& node) : m_node{node}
  {
    if constexpr (not _is_defined) {
      throw parse_error("invalid operands to binary expression", std::vector{m_node.begin()});
    }
  }

  /// Executes the left then the right child, then stores the operator result
  /// in m_node.m_value.
  void
  execute(ExecUntilBreakOrContinue& exec_policy)
  {
    if constexpr (_is_defined) {
      m_node.children[0]->execute(exec_policy);
      m_node.children[1]->execute(exec_policy);

      this->eval(m_node.children[0]->m_value, m_node.children[1]->m_value, m_node.m_value);
    }
  }
};

/// Selects and installs the node processor for the binary operator node `n`.
///
/// The operator is identified from the node type (multiply_op, plus_op, ...).
/// The operand C++ types come from the m_data_type of n.children[0] and
/// n.children[1]: bool_t -> bool, unsigned_int_t -> uint64_t, int_t -> int64_t,
/// double_t -> double. A parse_error is thrown for an unknown operator or an
/// unsupported operand data type; the left operand is checked before the right.
ASTNodeBinaryOperatorExpressionBuilder::ASTNodeBinaryOperatorExpressionBuilder(Node& n)
{
  auto set_binary_operator_processor = [](Node& node, const auto& operator_tag) {
    using OperatorT = std::decay_t<decltype(operator_tag)>;

    const DataType lhs_type = node.children[0]->m_data_type;
    const DataType rhs_type = node.children[1]->m_data_type;

    // Creates the processor once both operand types are known; the tag
    // arguments only carry their type.
    auto install_processor = [&node](const auto lhs_tag, const auto rhs_tag) {
      using LhsT = std::decay_t<decltype(lhs_tag)>;
      using RhsT = std::decay_t<decltype(rhs_tag)>;
      node.m_node_processor = std::make_unique<BinaryExpressionProcessor<OperatorT, LhsT, RhsT>>(node);
    };

    // Second dispatch level: the left operand type is fixed, resolve the right one.
    auto dispatch_on_rhs = [&](const auto lhs_tag) {
      switch (rhs_type) {
      case DataType::bool_t:
        install_processor(lhs_tag, bool{});
        break;
      case DataType::unsigned_int_t:
        install_processor(lhs_tag, uint64_t{});
        break;
      case DataType::int_t:
        install_processor(lhs_tag, int64_t{});
        break;
      case DataType::double_t:
        install_processor(lhs_tag, double{});
        break;
      default:
        throw parse_error("undefined operand type for binary operator", std::vector{node.children[1]->begin()});
      }
    };

    // First dispatch level: resolve the left operand type.
    switch (lhs_type) {
    case DataType::bool_t:
      dispatch_on_rhs(bool{});
      break;
    case DataType::unsigned_int_t:
      dispatch_on_rhs(uint64_t{});
      break;
    case DataType::int_t:
      dispatch_on_rhs(int64_t{});
      break;
    case DataType::double_t:
      dispatch_on_rhs(double{});
      break;
    default:
      throw parse_error("undefined operand type for binary operator", std::vector{node.children[0]->begin()});
    }
  };

  // Arithmetic operators
  if (n.is<language::multiply_op>()) {
    set_binary_operator_processor(n, language::multiply_op{});
  } else if (n.is<language::divide_op>()) {
    set_binary_operator_processor(n, language::divide_op{});
  } else if (n.is<language::plus_op>()) {
    set_binary_operator_processor(n, language::plus_op{});
  } else if (n.is<language::minus_op>()) {
    set_binary_operator_processor(n, language::minus_op{});
  }
  // Logical and bitwise operators
  else if (n.is<language::or_op>()) {
    set_binary_operator_processor(n, language::or_op{});
  } else if (n.is<language::and_op>()) {
    set_binary_operator_processor(n, language::and_op{});
  } else if (n.is<language::xor_op>()) {
    set_binary_operator_processor(n, language::xor_op{});
  } else if (n.is<language::bitand_op>()) {
    set_binary_operator_processor(n, language::bitand_op{});
  } else if (n.is<language::bitor_op>()) {
    set_binary_operator_processor(n, language::bitor_op{});
  }
  // Comparison operators
  else if (n.is<language::greater_op>()) {
    set_binary_operator_processor(n, language::greater_op{});
  } else if (n.is<language::greater_or_eq_op>()) {
    set_binary_operator_processor(n, language::greater_or_eq_op{});
  } else if (n.is<language::lesser_op>()) {
    set_binary_operator_processor(n, language::lesser_op{});
  } else if (n.is<language::lesser_or_eq_op>()) {
    set_binary_operator_processor(n, language::lesser_or_eq_op{});
  } else if (n.is<language::eqeq_op>()) {
    set_binary_operator_processor(n, language::eqeq_op{});
  } else if (n.is<language::not_eq_op>()) {
    set_binary_operator_processor(n, language::not_eq_op{});
  } else {
    throw parse_error("unexpected error: undefined binary operator", std::vector{n.begin()});
  }
}

}   // namespace language
