#include <catch2/catch.hpp>

#include <ASTNodeValueBuilder.hpp>

#include <ASTBuilder.hpp>
#include <ASTNodeDataTypeBuilder.hpp>

#include <ASTNodeDeclarationToAffectationConverter.hpp>
#include <ASTNodeTypeCleaner.hpp>

#include <ASTNodeExpressionBuilder.hpp>

#include <ASTNodeUnaryOperatorExpressionBuilder.hpp>

#include <ASTSymbolTableBuilder.hpp>

#include <ASTPrinter.hpp>

#include <Demangle.hpp>

#include <PEGGrammar.hpp>

#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <type_traits>

// CHECK_AST(data, expected_output)
//
// Parses `data` (must be a std::string_view) as a program named "test.pgs",
// then runs, in this order: symbol-table building, data-type building, value
// building, declaration-to-affectation conversion, removal of
// language::declaration nodes, and expression building. The resulting tree is
// printed in raw format annotated with each node's exec_type. The macro
// REQUIREs that this printout, preceded by a newline, equals `expected_output`
// (a std::string_view or std::string).
#define CHECK_AST(data, expected_output)                                                           \
  {                                                                                                \
    using source_view_t = std::decay_t<decltype(data)>;                                            \
    using reference_t   = std::decay_t<decltype(expected_output)>;                                 \
    static_assert(std::is_same_v<source_view_t, std::string_view>);                                \
    static_assert(std::is_same_v<reference_t, std::string_view> or                                 \
                  std::is_same_v<reference_t, std::string>);                                       \
                                                                                                   \
    string_input pgs_source{data, "test.pgs"};                                                     \
    auto root_owner = ASTBuilder::build(pgs_source);                                               \
    auto& root_node = *root_owner;                                                                 \
                                                                                                   \
    ASTSymbolTableBuilder{root_node};                                                              \
    ASTNodeDataTypeBuilder{root_node};                                                             \
    ASTNodeValueBuilder{root_node};                                                                \
    ASTNodeDeclarationToAffectationConverter{root_node};                                           \
    ASTNodeTypeCleaner<language::declaration>{root_node};                                          \
    ASTNodeExpressionBuilder{root_node};                                                           \
                                                                                                   \
    std::stringstream printed_tree;                                                                \
    printed_tree << '\n'                                                                           \
                 << ASTPrinter{root_node, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \
                                                                                                   \
    REQUIRE(printed_tree.str() == expected_output);                                                \
  }

TEST_CASE("ASTNodeUnaryOperatorExpressionBuilder", "[language]")
{
  // Unary minus: B, N and Z operands all yield a `long` result, R yields `double`
  // (see the UnaryExpressionProcessor template arguments in the expected trees).
  SECTION("unary minus")
  {
    SECTION("B")
    {
      std::string_view data = R"(
B b;
-b;
-true;
-false;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, bool>)
 |   `-(language::name:b:NameProcessor)
 +-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, bool>)
 |   `-(language::true_kw:FakeProcessor)
 `-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, bool>)
     `-(language::false_kw:FakeProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("N")
    {
      std::string_view data = R"(
N n;
-n;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 `-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, unsigned long>)
     `-(language::name:n:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("Z")
    {
      std::string_view data = R"(
Z i;
-i;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 `-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("R")
    {
      std::string_view data = R"(
R x;
-x;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 `-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, double, double>)
     `-(language::name:x:NameProcessor)
)";

      CHECK_AST(data, result);
    }
  }

  // Logical not: the result type is always `bool`, whatever the operand type.
  SECTION("not")
  {
    SECTION("B")
    {
      std::string_view data = R"(
B b;
not b;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 `-(language::unary_not:UnaryExpressionProcessor<language::unary_not, bool, bool>)
     `-(language::name:b:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("N")
    {
      std::string_view data = R"(
N n;
not n;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 `-(language::unary_not:UnaryExpressionProcessor<language::unary_not, bool, unsigned long>)
     `-(language::name:n:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    // Z and R operands are exercised through literals rather than variables.
    SECTION("Z")
    {
      std::string_view data = R"(
not 3;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 `-(language::unary_not:UnaryExpressionProcessor<language::unary_not, bool, long>)
     `-(language::integer:3:FakeProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("R")
    {
      std::string_view data = R"(
not 3.5;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 `-(language::unary_not:UnaryExpressionProcessor<language::unary_not, bool, double>)
     `-(language::real:3.5:FakeProcessor)
)";

      CHECK_AST(data, result);
    }
  }

  // Error paths, exercised on hand-built ASTNodes that bypass parsing and the
  // earlier AST passes.
  SECTION("Errors")
  {
    // The node is default-constructed: no unary operator type is set on it.
    SECTION("Invalid unary operator")
    {
      auto ast = std::make_unique<ASTNode>();

      REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "unexpected error: undefined unary operator");
    }

    // The operand child is default-constructed; the test assigns it no data type.
    SECTION("Invalid value type for unary minus")
    {
      auto ast = std::make_unique<ASTNode>();
      ast->set_type<language::unary_minus>();
      ast->children.emplace_back(std::make_unique<ASTNode>());

      REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "undefined value type for unary operator");
    }

    SECTION("Invalid value type for unary not")
    {
      auto ast = std::make_unique<ASTNode>();
      ast->set_type<language::unary_not>();
      ast->children.emplace_back(std::make_unique<ASTNode>());

      REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast}, "undefined value type for unary operator");
    }

    // The operator node's data type is set to int_t, but its operand child is a
    // default-constructed node with no data type assigned.
    SECTION("Invalid data type for unary operator")
    {
      auto ast = std::make_unique<ASTNode>();
      ast->set_type<language::unary_minus>();
      ast->m_data_type = ASTNodeDataType::int_t;
      ast->children.emplace_back(std::make_unique<ASTNode>());

      REQUIRE_THROWS_WITH(ASTNodeUnaryOperatorExpressionBuilder{*ast},
                          "unexpected error: invalid operand type for unary operator");
    }
  }
}
