#include <catch2/catch.hpp>

#include <ASTNodeValueBuilder.hpp>

#include <ASTBuilder.hpp>
#include <ASTNodeDataTypeBuilder.hpp>

#include <ASTNodeDeclarationToAffectationConverter.hpp>
#include <ASTNodeTypeCleaner.hpp>

#include <ASTNodeExpressionBuilder.hpp>

#include <ASTNodeIncDecExpressionBuilder.hpp>

#include <ASTSymbolTableBuilder.hpp>

#include <ASTPrinter.hpp>

#include <Demangle.hpp>

#include <PEGGrammar.hpp>

// Runs the front-end pipeline on `data` (parse, symbol table, data types,
// values, declaration-to-affectation conversion, declaration-node cleaning,
// expression building), prints the resulting AST in raw format annotated with
// each node's exec_type, and REQUIREs that the print equals `expected_output`.
//
// `data` must be a std::string_view; `expected_output` may be a
// std::string_view or a std::string (enforced by the static_asserts).
// The printed output starts with '\n', so expected strings start with a newline.
//
// The body is wrapped in do { } while (0) so that `CHECK_AST(a, b);` is a single
// statement and stays valid in an unbraced if/else.
#define CHECK_AST(data, expected_output)                                                            \
  do {                                                                                              \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>);                  \
    static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view> or      \
                  std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>);            \
                                                                                                    \
    string_input input{data, "test.pgs"};                                                           \
    auto ast = ASTBuilder::build(input);                                                            \
                                                                                                    \
    ASTSymbolTableBuilder{*ast};                                                                    \
    ASTNodeDataTypeBuilder{*ast};                                                                   \
    ASTNodeValueBuilder{*ast};                                                                      \
                                                                                                    \
    ASTNodeDeclarationToAffectationConverter{*ast};                                                 \
    ASTNodeTypeCleaner<language::declaration>{*ast};                                                \
                                                                                                    \
    ASTNodeExpressionBuilder{*ast};                                                                 \
                                                                                                    \
    std::stringstream ast_output;                                                                   \
    ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \
                                                                                                    \
    REQUIRE(ast_output.str() == expected_output);                                                   \
  } while (0)

// Runs the front-end pipeline on `data` up to, but not including, expression
// building, then REQUIREs that constructing ASTNodeExpressionBuilder throws an
// exception whose message equals `expected_error`.
//
// `data` must be a std::string_view; `expected_error` may be a
// std::string_view or a std::string (enforced by the static_asserts).
// `expected_error` is converted explicitly to std::string. Catch2's
// REQUIRE_THROWS_WITH compares against a std::string (or a string matcher), and
// std::string_view does not implicitly convert to std::string. Without the
// conversion, a string_view argument would pass the static_assert but fail to compile.
//
// Expands to a braced compound statement, so call sites may omit the trailing
// semicolon.
#define DISALLOWED_CHAINED_AST(data, expected_error)                                          \
  {                                                                                           \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>);            \
    static_assert(std::is_same_v<std::decay_t<decltype(expected_error)>, std::string_view> or \
                  std::is_same_v<std::decay_t<decltype(expected_error)>, std::string>);       \
                                                                                              \
    string_input input{data, "test.pgs"};                                                     \
    auto ast = ASTBuilder::build(input);                                                      \
                                                                                              \
    ASTSymbolTableBuilder{*ast};                                                              \
    ASTNodeDataTypeBuilder{*ast};                                                             \
    ASTNodeValueBuilder{*ast};                                                                \
                                                                                              \
    ASTNodeDeclarationToAffectationConverter{*ast};                                           \
    ASTNodeTypeCleaner<language::declaration>{*ast};                                          \
                                                                                              \
    REQUIRE_THROWS_WITH(ASTNodeExpressionBuilder{*ast}, std::string(expected_error));         \
  }

// Checks that prefix/postfix ++ and -- applied to N, Z and R variables are built
// into IncDecExpressionProcessor nodes of the matching operator and value type.
// Also checks that malformed nodes and chained ++/-- operators are rejected with
// the expected error messages.
TEST_CASE("ASTNodeIncDecExpressionBuilder", "[language]")
{
  SECTION("Pre-increment")
  {
    SECTION("N")
    {
      std::string_view data = R"(
N i=0;
++i;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::unary_plusplus:IncDecExpressionProcessor<language::unary_plusplus, unsigned long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("Z")
    {
      std::string_view data = R"(
Z i=0;
++i;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::unary_plusplus:IncDecExpressionProcessor<language::unary_plusplus, long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("R")
    {
      std::string_view data = R"(
R x=0;
++x;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, double, long>)
 |   +-(language::name:x:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::unary_plusplus:IncDecExpressionProcessor<language::unary_plusplus, double>)
     `-(language::name:x:NameProcessor)
)";

      CHECK_AST(data, result);
    }
  }

  SECTION("Pre-decrement")
  {
    SECTION("N")
    {
      std::string_view data = R"(
N i=1;
--i;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:1:FakeProcessor)
 `-(language::unary_minusminus:IncDecExpressionProcessor<language::unary_minusminus, unsigned long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("Z")
    {
      std::string_view data = R"(
Z i=0;
--i;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::unary_minusminus:IncDecExpressionProcessor<language::unary_minusminus, long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("R")
    {
      std::string_view data = R"(
R x=2.3;
--x;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, double, double>)
 |   +-(language::name:x:NameProcessor)
 |   `-(language::real:2.3:FakeProcessor)
 `-(language::unary_minusminus:IncDecExpressionProcessor<language::unary_minusminus, double>)
     `-(language::name:x:NameProcessor)
)";

      CHECK_AST(data, result);
    }
  }

  SECTION("Post-increment")
  {
    SECTION("N")
    {
      std::string_view data = R"(
N i=0;
i++;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::post_plusplus:IncDecExpressionProcessor<language::post_plusplus, unsigned long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("Z")
    {
      std::string_view data = R"(
Z i=0;
i++;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::post_plusplus:IncDecExpressionProcessor<language::post_plusplus, long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("R")
    {
      std::string_view data = R"(
R x=0;
x++;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, double, long>)
 |   +-(language::name:x:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::post_plusplus:IncDecExpressionProcessor<language::post_plusplus, double>)
     `-(language::name:x:NameProcessor)
)";

      CHECK_AST(data, result);
    }
  }

  SECTION("Post-decrement")
  {
    SECTION("N")
    {
      std::string_view data = R"(
N i=1;
i--;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, unsigned long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:1:FakeProcessor)
 `-(language::post_minusminus:IncDecExpressionProcessor<language::post_minusminus, unsigned long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("Z")
    {
      std::string_view data = R"(
Z i=0;
i--;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, long, long>)
 |   +-(language::name:i:NameProcessor)
 |   `-(language::integer:0:FakeProcessor)
 `-(language::post_minusminus:IncDecExpressionProcessor<language::post_minusminus, long>)
     `-(language::name:i:NameProcessor)
)";

      CHECK_AST(data, result);
    }

    SECTION("R")
    {
      std::string_view data = R"(
R x=2.3;
x--;
)";

      std::string_view result = R"(
(root:ASTNodeListProcessor)
 +-(language::eq_op:AffectationProcessor<language::eq_op, double, double>)
 |   +-(language::name:x:NameProcessor)
 |   `-(language::real:2.3:FakeProcessor)
 `-(language::post_minusminus:IncDecExpressionProcessor<language::post_minusminus, double>)
     `-(language::name:x:NameProcessor)
)";

      CHECK_AST(data, result);
    }
  }

  SECTION("Errors")
  {
    SECTION("Invalid operator type")
    {
      // A default-constructed node carries no ++/-- operator type.
      auto ast = std::make_unique<ASTNode>();
      REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast},
                          "unexpected error: undefined increment/decrement operator");
    }

    SECTION("Invalid operand type")
    {
      auto ast = std::make_unique<ASTNode>();
      ast->set_type<language::unary_plusplus>();
      ast->m_data_type = ASTNodeDataType::undefined_t;

      // The operand child is left untyped (not a name node).
      ast->children.emplace_back(std::make_unique<ASTNode>());

      REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast}, "invalid operand type for unary operator");
    }

    SECTION("Invalid data type")
    {
      auto ast = std::make_unique<ASTNode>();
      ast->set_type<language::unary_plusplus>();

      // The operand is a name node, but the operator node's data type is left unset.
      ast->children.emplace_back(std::make_unique<ASTNode>());
      ast->children[0]->set_type<language::name>();

      REQUIRE_THROWS_WITH(ASTNodeIncDecExpressionBuilder{*ast},
                          "unexpected error: undefined data type for unary operator");
    }

    SECTION("Not allowed chained ++/--")
    {
      // In the section names below, `a` stands for the operand. Its position
      // relative to the operators matches the tested snippet, e.g. "a ++ --"
      // tests `1 ++ --`.
      const std::string error_message = R"(chaining ++ or -- operators is not allowed)";

      SECTION("a ++ ++")
      {
        std::string_view data = R"(
1 ++ ++;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("a ++ --")
      {
        std::string_view data = R"(
1 ++ --;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("a -- ++")
      {
        std::string_view data = R"(
1 -- ++;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("a -- --")
      {
        std::string_view data = R"(
1 -- --;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("++ ++ a")
      {
        std::string_view data = R"(
++ ++ 1;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("++ -- a")
      {
        std::string_view data = R"(
++ -- 1;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("-- ++ a")
      {
        std::string_view data = R"(
-- ++ 1;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("-- -- a")
      {
        std::string_view data = R"(
-- -- 1;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("++ a ++")
      {
        std::string_view data = R"(
++ 1 ++;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("++ a --")
      {
        std::string_view data = R"(
++ 1 --;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("-- a ++")
      {
        std::string_view data = R"(
-- 1 ++;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }

      SECTION("-- a --")
      {
        std::string_view data = R"(
-- 1 --;
)";

        DISALLOWED_CHAINED_AST(data, error_message);
      }
    }
  }
}
