#include <catch2/catch.hpp>

#include <ASTBuilder.hpp>
#include <ASTPrinter.hpp>
#include <sstream>

// Parses `data` (a std::string_view of source code) with ASTBuilder, prints the
// resulting tree with ASTPrinter in raw format (ASTPrinter::Info::none), and
// REQUIREs that the printed text equals `expected_output` (also a
// std::string_view). A '\n' is emitted before the tree so that expectations
// written as raw strings starting with a newline compare equal.
//
// The body is wrapped in do { ... } while (false) so that `CHECK_AST(a, b);`
// expands to exactly one statement and composes safely with unbraced if/else.
#define CHECK_AST(data, expected_output)                                                       \
  do {                                                                                         \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>);             \
    static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view>);  \
                                                                                               \
    string_input input{data, "test.pgs"};                                                      \
    auto ast = ASTBuilder::build(input);                                                       \
                                                                                               \
    std::ostringstream ast_output;                                                             \
    ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::none}}; \
                                                                                               \
    REQUIRE(ast_output.str() == expected_output);                                              \
  } while (false)

// Feeds small programs to ASTBuilder and compares the raw printed AST against the
// expected tree. Every expected string starts with a newline because CHECK_AST
// writes '\n' before the printed tree.
TEST_CASE("ASTBuilder", "[language]")
{
  // Turn off rang terminal styling so the printed AST is compared as plain text.
  rang::setControlMode(rang::control::Off);
  SECTION("AST parsing")
  {
    SECTION("declarations with init")
    {
      std::string_view data = R"(
N n = 2;
Z z = 3;
R r = 2.3e-5;
B b = false;
string s = "foo";
)";

      std::string_view result = R"(
(root)
 +-(language::declaration)
 |   +-(language::N_set)
 |   +-(language::name:n)
 |   `-(language::integer:2)
 +-(language::declaration)
 |   +-(language::Z_set)
 |   +-(language::name:z)
 |   `-(language::integer:3)
 +-(language::declaration)
 |   +-(language::R_set)
 |   +-(language::name:r)
 |   `-(language::real:2.3e-5)
 +-(language::declaration)
 |   +-(language::B_set)
 |   +-(language::name:b)
 |   `-(language::false_kw)
 `-(language::declaration)
     +-(language::string_type)
     +-(language::name:s)
     `-(language::literal:"foo")
)";
      CHECK_AST(data, result);
    }

    // Declaration without initializer followed by a separate assignment (eq_op).
    SECTION("affectations")
    {
      std::string_view data = R"(
N n; n = 2;
Z z; z = 3;
R r; r = 2.3e-5;
B b; b = false;
string s; s = "foo";
)";

      std::string_view result = R"(
(root)
 +-(language::declaration)
 |   +-(language::N_set)
 |   `-(language::name:n)
 +-(language::eq_op)
 |   +-(language::name:n)
 |   `-(language::integer:2)
 +-(language::declaration)
 |   +-(language::Z_set)
 |   `-(language::name:z)
 +-(language::eq_op)
 |   +-(language::name:z)
 |   `-(language::integer:3)
 +-(language::declaration)
 |   +-(language::R_set)
 |   `-(language::name:r)
 +-(language::eq_op)
 |   +-(language::name:r)
 |   `-(language::real:2.3e-5)
 +-(language::declaration)
 |   +-(language::B_set)
 |   `-(language::name:b)
 +-(language::eq_op)
 |   +-(language::name:b)
 |   `-(language::false_kw)
 +-(language::declaration)
 |   +-(language::string_type)
 |   `-(language::name:s)
 `-(language::eq_op)
     +-(language::name:s)
     `-(language::literal:"foo")
)";
      CHECK_AST(data, result);
    }

    // Nested blocs holding only comments and empty statements leave nothing under the root.
    SECTION("empty blocs simplification")
    {
      std::string_view data = R"(
{
  /* nothing but a bloc */
  {
    ; // nothing
  }
}
)";

      std::string_view result = R"(
(root)
)";
      CHECK_AST(data, result);
    }

    SECTION("operators precedence")
    {
      SECTION("basic operations")
      {
        std::string_view data = R"(
2+3.2*6 - 3.2/4;
)";

        std::string_view result = R"(
(root)
 `-(language::minus_op)
     +-(language::plus_op)
     |   +-(language::integer:2)
     |   `-(language::multiply_op)
     |       +-(language::real:3.2)
     |       `-(language::integer:6)
     `-(language::divide_op)
         +-(language::real:3.2)
         `-(language::integer:4)
)";
        CHECK_AST(data, result);
      }

      SECTION("parented expression")
      {
        std::string_view data = R"(
(2+3)*6;
)";

        std::string_view result = R"(
(root)
 `-(language::multiply_op)
     +-(language::plus_op)
     |   +-(language::integer:2)
     |   `-(language::integer:3)
     `-(language::integer:6)
)";
        CHECK_AST(data, result);
      }

      // NOTE(review): the expected tree below groups every binary operator other
      // than * and / strictly left to right at one precedence level (e.g.
      // `1+2 and 3` becomes and_op(plus_op(1,2),3) and `<=` sits above `and`).
      // Confirm this is the intended grammar rather than a pinned precedence bug.
      SECTION("all operators mix")
      {
        std::string_view data = R"(
1+2 and 3<= 2 * 4 - 1 == 2 or 2>=1 / 5 xor 7 and 2 or 2 <3 >7 xor -2 + true - not false;
)";

        std::string_view result = R"(
(root)
 `-(language::minus_op)
     +-(language::plus_op)
     |   +-(language::xor_op)
     |   |   +-(language::greater_op)
     |   |   |   +-(language::lesser_op)
     |   |   |   |   +-(language::or_op)
     |   |   |   |   |   +-(language::and_op)
     |   |   |   |   |   |   +-(language::xor_op)
     |   |   |   |   |   |   |   +-(language::greater_or_eq_op)
     |   |   |   |   |   |   |   |   +-(language::or_op)
     |   |   |   |   |   |   |   |   |   +-(language::eqeq_op)
     |   |   |   |   |   |   |   |   |   |   +-(language::minus_op)
     |   |   |   |   |   |   |   |   |   |   |   +-(language::lesser_or_eq_op)
     |   |   |   |   |   |   |   |   |   |   |   |   +-(language::and_op)
     |   |   |   |   |   |   |   |   |   |   |   |   |   +-(language::plus_op)
     |   |   |   |   |   |   |   |   |   |   |   |   |   |   +-(language::integer:1)
     |   |   |   |   |   |   |   |   |   |   |   |   |   |   `-(language::integer:2)
     |   |   |   |   |   |   |   |   |   |   |   |   |   `-(language::integer:3)
     |   |   |   |   |   |   |   |   |   |   |   |   `-(language::multiply_op)
     |   |   |   |   |   |   |   |   |   |   |   |       +-(language::integer:2)
     |   |   |   |   |   |   |   |   |   |   |   |       `-(language::integer:4)
     |   |   |   |   |   |   |   |   |   |   |   `-(language::integer:1)
     |   |   |   |   |   |   |   |   |   |   `-(language::integer:2)
     |   |   |   |   |   |   |   |   |   `-(language::integer:2)
     |   |   |   |   |   |   |   |   `-(language::divide_op)
     |   |   |   |   |   |   |   |       +-(language::integer:1)
     |   |   |   |   |   |   |   |       `-(language::integer:5)
     |   |   |   |   |   |   |   `-(language::integer:7)
     |   |   |   |   |   |   `-(language::integer:2)
     |   |   |   |   |   `-(language::integer:2)
     |   |   |   |   `-(language::integer:3)
     |   |   |   `-(language::integer:7)
     |   |   `-(language::unary_minus)
     |   |       `-(language::integer:2)
     |   `-(language::true_kw)
     `-(language::unary_not)
         `-(language::false_kw)
)";
        CHECK_AST(data, result);
      }
    }

    SECTION("unary operator simplification")
    {
      // Pairs of `not` cancel: an even count disappears, an odd count leaves one unary_not.
      SECTION("multiple not")
      {
        std::string_view data = R"(
not not not not true;
not not not false;
)";

        std::string_view result = R"(
(root)
 +-(language::true_kw)
 `-(language::unary_not)
     `-(language::false_kw)
)";
        CHECK_AST(data, result);
      }

      // Unary plus produces no node at all.
      SECTION("multiple unary plus")
      {
        std::string_view data = R"(
+ + + 3;
)";

        std::string_view result = R"(
(root)
 `-(language::integer:3)
)";
        CHECK_AST(data, result);
      }

      // Unary plus is dropped and pairs of unary minus cancel.
      SECTION("multiple unary minus")
      {
        std::string_view data = R"(
- - + - - 3;
- + - - 2;
)";

        std::string_view result = R"(
(root)
 +-(language::integer:3)
 `-(language::unary_minus)
     `-(language::integer:2)
)";
        CHECK_AST(data, result);
      }

      // A sign on the right operand is folded into the binary operator:
      // `a - -b` -> plus_op, `a + -b` -> minus_op, `a - +b` -> minus_op.
      SECTION("sums and unary plus/minus")
      {
        std::string_view data = R"(
1 - - 3;
1 + - 2;
4 - + 3;
)";

        std::string_view result = R"(
(root)
 +-(language::plus_op)
 |   +-(language::integer:1)
 |   `-(language::integer:3)
 +-(language::minus_op)
 |   +-(language::integer:1)
 |   `-(language::integer:2)
 `-(language::minus_op)
     +-(language::integer:4)
     `-(language::integer:3)
)";
        CHECK_AST(data, result);
      }
    }

    // Postfix ++/-- become a parent node over their operand; only tree shape is
    // checked here, so applying them to literals is accepted by the parser.
    SECTION("post incr/decr rearrangements")
    {
      std::string_view data = R"(
1++;
2--;
)";

      std::string_view result = R"(
(root)
 +-(language::post_plusplus)
 |   `-(language::integer:1)
 `-(language::post_minusminus)
     `-(language::integer:2)
)";
      CHECK_AST(data, result);
    }

    // A bloc holding a single non-declaration statement is replaced by that statement.
    SECTION("statement bloc simplification (one instruction per bloc)")
    {
      std::string_view data = R"(
if (a > 0) {
  a = 0;
} else {
  a =-1;
}
)";

      std::string_view result = R"(
(root)
 `-(language::if_statement)
     +-(language::greater_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     +-(language::eq_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     `-(language::eq_op)
         +-(language::name:a)
         `-(language::unary_minus)
             `-(language::integer:1)
)";
      CHECK_AST(data, result);
    }

    SECTION("statement bloc simplification (one instruction in first bloc)")
    {
      std::string_view data = R"(
if (a > 0) {
  a = 0;
} else {
  a = 3;
  a = a++;
}
)";

      std::string_view result = R"(
(root)
 `-(language::if_statement)
     +-(language::greater_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     +-(language::eq_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     `-(language::bloc)
         +-(language::eq_op)
         |   +-(language::name:a)
         |   `-(language::integer:3)
         `-(language::eq_op)
             +-(language::name:a)
             `-(language::post_plusplus)
                 `-(language::name:a)
)";
      CHECK_AST(data, result);
    }

    SECTION("statement bloc simplification (one instruction in second bloc)")
    {
      std::string_view data = R"(
if (a > 0) {
  a = 0;
  a++;
} else {
  a = 3;
}
)";

      std::string_view result = R"(
(root)
 `-(language::if_statement)
     +-(language::greater_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     +-(language::bloc)
     |   +-(language::eq_op)
     |   |   +-(language::name:a)
     |   |   `-(language::integer:0)
     |   `-(language::post_plusplus)
     |       `-(language::name:a)
     `-(language::eq_op)
         +-(language::name:a)
         `-(language::integer:3)
)";
      CHECK_AST(data, result);
    }

    // A bloc whose only statement is a declaration is kept (presumably to
    // preserve the declared name's scope), so none of the blocs below collapse.
    SECTION("statement bloc non-simplification (one declaration in each bloc)")
    {
      std::string_view data = R"(
if (a > 0) {
  R b = a;
} else {
  R c = 2*a;
}
)";

      std::string_view result = R"(
(root)
 `-(language::if_statement)
     +-(language::greater_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     +-(language::bloc)
     |   `-(language::declaration)
     |       +-(language::R_set)
     |       +-(language::name:b)
     |       `-(language::name:a)
     `-(language::bloc)
         `-(language::declaration)
             +-(language::R_set)
             +-(language::name:c)
             `-(language::multiply_op)
                 +-(language::integer:2)
                 `-(language::name:a)
)";
      CHECK_AST(data, result);
    }

    SECTION("statement bloc non-simplification (one declaration in first bloc)")
    {
      std::string_view data = R"(
if (a > 0) {
  R b = a;
} else {
  R c = 2*a;
  ++a;
}
)";

      std::string_view result = R"(
(root)
 `-(language::if_statement)
     +-(language::greater_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     +-(language::bloc)
     |   `-(language::declaration)
     |       +-(language::R_set)
     |       +-(language::name:b)
     |       `-(language::name:a)
     `-(language::bloc)
         +-(language::declaration)
         |   +-(language::R_set)
         |   +-(language::name:c)
         |   `-(language::multiply_op)
         |       +-(language::integer:2)
         |       `-(language::name:a)
         `-(language::unary_plusplus)
             `-(language::name:a)
)";
      CHECK_AST(data, result);
    }

    SECTION("statement bloc non-simplification (one declaration in second bloc)")
    {
      std::string_view data = R"(
if (a > 0) {
  R b = a;
  ++b;
} else {
  R c = 2*a;
}
)";

      std::string_view result = R"(
(root)
 `-(language::if_statement)
     +-(language::greater_op)
     |   +-(language::name:a)
     |   `-(language::integer:0)
     +-(language::bloc)
     |   +-(language::declaration)
     |   |   +-(language::R_set)
     |   |   +-(language::name:b)
     |   |   `-(language::name:a)
     |   `-(language::unary_plusplus)
     |       `-(language::name:b)
     `-(language::bloc)
         `-(language::declaration)
             +-(language::R_set)
             +-(language::name:c)
             `-(language::multiply_op)
                 +-(language::integer:2)
                 `-(language::name:a)
)";
      CHECK_AST(data, result);
    }

    // A single-statement for body is inlined as the last child of for_statement.
    SECTION("for-statements simplification")
    {
      std::string_view data = R"(
for(N i=0; i<10; ++i) {
  i += 3;
}
)";

      std::string_view result = R"(
(root)
 `-(language::for_statement)
     +-(language::declaration)
     |   +-(language::N_set)
     |   +-(language::name:i)
     |   `-(language::integer:0)
     +-(language::lesser_op)
     |   +-(language::name:i)
     |   `-(language::integer:10)
     +-(language::unary_plusplus)
     |   `-(language::name:i)
     `-(language::pluseq_op)
         +-(language::name:i)
         `-(language::integer:3)
)";
      CHECK_AST(data, result);
    }

    // A multi-statement for body is kept as a for_statement_bloc node.
    SECTION("for-statements simplification (complex bloc)")
    {
      std::string_view data = R"(
for(N i=0; i<10; ++i) {
  i += 3;
  R j=i/5.;
}
)";

      std::string_view result = R"(
(root)
 `-(language::for_statement)
     +-(language::declaration)
     |   +-(language::N_set)
     |   +-(language::name:i)
     |   `-(language::integer:0)
     +-(language::lesser_op)
     |   +-(language::name:i)
     |   `-(language::integer:10)
     +-(language::unary_plusplus)
     |   `-(language::name:i)
     `-(language::for_statement_bloc)
         +-(language::pluseq_op)
         |   +-(language::name:i)
         |   `-(language::integer:3)
         `-(language::declaration)
             +-(language::R_set)
             +-(language::name:j)
             `-(language::divide_op)
                 +-(language::name:i)
                 `-(language::real:5.)
)";
      CHECK_AST(data, result);
    }

    // A chain `stream << a << b ...` is flattened into one stream-keyword node
    // whose children are the streamed expressions, in order.
    SECTION("ostream simplifications")
    {
      std::string_view data = R"(
cout << 1+2 << "\n";
cerr << "error?\n";
clog << "log " << l << "\n";
)";

      std::string_view result = R"(
(root)
 +-(language::cout_kw)
 |   +-(language::plus_op)
 |   |   +-(language::integer:1)
 |   |   `-(language::integer:2)
 |   `-(language::literal:"\n")
 +-(language::cerr_kw)
 |   `-(language::literal:"error?\n")
 `-(language::clog_kw)
     +-(language::literal:"log ")
     +-(language::name:l)
     `-(language::literal:"\n")
)";
      CHECK_AST(data, result);
    }

    // Malformed input must make ASTBuilder::build throw parse_error.
    SECTION("syntax error")
    {
      std::string_view data = R"(
1+; // syntax error
)";

      string_input input{data, "test.pgs"};

      REQUIRE_THROWS_AS(ASTBuilder::build(input), parse_error);
    }
  }
}