#include <catch2/catch.hpp>

#include <ASTNodeValueBuilder.hpp>

#include <ASTBuilder.hpp>
#include <ASTNodeDataTypeBuilder.hpp>

#include <ASTModulesImporter.hpp>

#include <ASTNodeExpressionBuilder.hpp>
#include <ASTNodeFunctionEvaluationExpressionBuilder.hpp>
#include <ASTNodeFunctionExpressionBuilder.hpp>
#include <ASTNodeTypeCleaner.hpp>

#include <ASTSymbolTableBuilder.hpp>

#include <ASTPrinter.hpp>

#include <PEGGrammar.hpp>

#include <Demangle.hpp>

/*
 * CHECK_AST(data, expected_output)
 *
 * `data` must be a std::string_view; `expected_output` must be a
 * std::string_view or a std::string (both enforced by static_assert).
 *
 * The macro builds an AST from `data` (registered under the source name
 * "test.pgs") and applies, in this order: ASTModulesImporter,
 * ASTNodeTypeCleaner<language::import_instruction>, ASTSymbolTableBuilder,
 * ASTNodeDataTypeBuilder, ASTNodeValueBuilder,
 * ASTNodeTypeCleaner<language::declaration>,
 * ASTNodeTypeCleaner<language::let_declaration> and ASTNodeExpressionBuilder.
 * The tree is then printed with ASTPrinter (raw format, exec_type info) after
 * a leading '\n', and the text is REQUIREd to equal `expected_output`.
 *
 * The body is wrapped in do { ... } while (false) so that `CHECK_AST(a, b);`
 * expands to exactly one statement and remains correct inside an unbraced
 * if/else. Callers must terminate the invocation with a semicolon.
 */
#define CHECK_AST(data, expected_output)                                                            \
  do {                                                                                              \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>);                  \
    static_assert((std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view>) or    \
                  (std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>));          \
                                                                                                    \
    string_input input{data, "test.pgs"};                                                           \
    auto ast = ASTBuilder::build(input);                                                            \
                                                                                                    \
    ASTModulesImporter{*ast};                                                                       \
    ASTNodeTypeCleaner<language::import_instruction>{*ast};                                         \
                                                                                                    \
    ASTSymbolTableBuilder{*ast};                                                                    \
    ASTNodeDataTypeBuilder{*ast};                                                                   \
    ASTNodeValueBuilder{*ast};                                                                      \
                                                                                                    \
    ASTNodeTypeCleaner<language::declaration>{*ast};                                                \
    ASTNodeTypeCleaner<language::let_declaration>{*ast};                                            \
    ASTNodeExpressionBuilder{*ast};                                                                 \
                                                                                                    \
    std::stringstream ast_output;                                                                   \
    ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \
                                                                                                    \
    REQUIRE(ast_output.str() == expected_output);                                                   \
  } while (false)

/*
 * CHECK_AST_THROWS(data)
 *
 * `data` must be a std::string_view (enforced by static_assert).
 *
 * The macro builds an AST from `data` (registered under the source name
 * "test.pgs") and applies, in this order: ASTModulesImporter,
 * ASTNodeTypeCleaner<language::import_instruction>, ASTSymbolTableBuilder,
 * ASTNodeDataTypeBuilder, ASTNodeValueBuilder,
 * ASTNodeTypeCleaner<language::declaration> and
 * ASTNodeTypeCleaner<language::let_declaration>. It then REQUIREs that
 * constructing ASTNodeExpressionBuilder on that tree throws parse_error.
 * Only this last step is guarded: an exception raised by an earlier pass
 * propagates out of the macro.
 *
 * The body is wrapped in do { ... } while (false) so that
 * `CHECK_AST_THROWS(d);` expands to exactly one statement and remains correct
 * inside an unbraced if/else. Callers must terminate the invocation with a
 * semicolon.
 */
#define CHECK_AST_THROWS(data)                                                     \
  do {                                                                             \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>); \
                                                                                   \
    string_input input{data, "test.pgs"};                                          \
    auto ast = ASTBuilder::build(input);                                           \
                                                                                   \
    ASTModulesImporter{*ast};                                                      \
    ASTNodeTypeCleaner<language::import_instruction>{*ast};                        \
                                                                                   \
    ASTSymbolTableBuilder{*ast};                                                   \
    ASTNodeDataTypeBuilder{*ast};                                                  \
    ASTNodeValueBuilder{*ast};                                                     \
                                                                                   \
    ASTNodeTypeCleaner<language::declaration>{*ast};                               \
    ASTNodeTypeCleaner<language::let_declaration>{*ast};                           \
    REQUIRE_THROWS_AS(ASTNodeExpressionBuilder{*ast}, parse_error);                \
  } while (false)

// Checks the processors attached to function-evaluation nodes for user
// functions declared with `let`. Each positive section feeds a short script to
// CHECK_AST and compares the printed tree with a fixture; the "errors" sections
// use CHECK_AST_THROWS to require a parse_error when the expression is built.
// In every expected tree only the function evaluation appears under the root.
TEST_CASE("ASTNodeFunctionExpressionBuilder", "[language]")
{
  // ---- functions declared with return type B ----
  SECTION("return a B")
  {
    // The function is declared B -> B; the inner sections vary what is passed.
    SECTION("B argument")
    {
      SECTION("B parameter")
      {
        // A boolean literal is passed.
        const std::string_view source = R"(
let not_v : B -> B, a -> not a;
not_v(true);
)";

        const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:not_v:NameProcessor)
     `-(language::true_kw:FakeProcessor)
)";

        CHECK_AST(source, expected);
      }

      SECTION("N parameter")
      {
        // A variable declared as N is passed.
        const std::string_view source = R"(
let not_v : B -> B, a -> not a;
N n = 1;
not_v(n);
)";

        const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:not_v:NameProcessor)
     `-(language::name:n:NameProcessor)
)";

        CHECK_AST(source, expected);
      }

      SECTION("Z parameter")
      {
        // A negated integer literal is passed.
        const std::string_view source = R"(
let not_v : B -> B, a -> not a;
not_v(-1);
)";

        const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:not_v:NameProcessor)
     `-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>)
         `-(language::integer:1:FakeProcessor)
)";

        CHECK_AST(source, expected);
      }

      SECTION("R parameter")
      {
        // A real literal is passed.
        const std::string_view source = R"(
let not_v : B -> B, a -> not a;
not_v(1.3);
)";

        const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:not_v:NameProcessor)
     `-(language::real:1.3:FakeProcessor)
)";

        CHECK_AST(source, expected);
      }
    }

    SECTION("N argument")
    {
      // Function declared N -> B, called with an integer literal.
      const std::string_view source = R"(
let test : N -> B, n -> n<10;
test(2);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:test:NameProcessor)
     `-(language::integer:2:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("Z argument")
    {
      // Function declared Z -> B, called with an integer literal.
      const std::string_view source = R"(
let test : Z -> B, z -> z>3;
test(2);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:test:NameProcessor)
     `-(language::integer:2:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("R argument")
    {
      // Function declared R -> B, called with a real literal.
      const std::string_view source = R"(
let test : R -> B, x -> x>2.3;
test(2.1);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:test:NameProcessor)
     `-(language::real:2.1:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }
  }

  // ---- functions declared with return type N ----
  SECTION("return a N")
  {
    SECTION("N argument")
    {
      // Function declared N -> N, called with a real literal.
      const std::string_view source = R"(
let test : N -> N, n -> n+2;
test(2.1);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:test:NameProcessor)
     `-(language::real:2.1:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("Z argument")
    {
      // Function declared Z -> N, called with a negated integer literal.
      const std::string_view source = R"(
let absolute : Z -> N, z -> (z>0)*z -(z<=0)*z;
absolute(-2);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:absolute:NameProcessor)
     `-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>)
         `-(language::integer:2:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }
  }

  // ---- functions declared with return type Z ----
  SECTION("return a Z")
  {
    SECTION("N argument")
    {
      // Function declared N -> Z, called with a boolean literal.
      const std::string_view source = R"(
let minus : N -> Z, n -> -n;
minus(true);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:minus:NameProcessor)
     `-(language::true_kw:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("Z argument")
    {
      // Function declared Z -> Z, called with a negated integer literal.
      const std::string_view source = R"(
let times_2_3 : Z -> Z, z -> z*2.3;
times_2_3(-2);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:FunctionProcessor)
     +-(language::name:times_2_3:NameProcessor)
     `-(language::unary_minus:UnaryExpressionProcessor<language::unary_minus, long, long>)
         `-(language::integer:2:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }
  }

  // ---- calls that must be rejected while building expressions ----
  SECTION("errors")
  {
    SECTION("wrong argument number")
    {
      // Two values are given to a function declared with a single parameter.
      const std::string_view source = R"(
let Id : Z -> Z, z -> z;
Id(2,3);
)";
      CHECK_AST_THROWS(source);
    }

    SECTION("wrong argument number 2")
    {
      // One value is given to a function declared with two parameters.
      const std::string_view source = R"(
let sum : R*R -> R, (x,y) -> x+y;
sum(2);
)";
      CHECK_AST_THROWS(source);
    }
  }
}
