#include <catch2/catch.hpp>

#include <ASTBuilder.hpp>
#include <ASTModulesImporter.hpp>
#include <ASTNodeExpressionBuilder.hpp>
#include <ASTNodeTypeCleaner.hpp>

#include <SymbolTable.hpp>

#include <PEGGrammar.hpp>

#include <ASTPrinter.hpp>

#include <sstream>
#include <string_view>
#include <type_traits>

// CHECK_AST(data, expected_output)
//
// Runs the import pipeline on `data` and requires the printed AST to equal
// `expected_output`:
//   1. builds the AST from `data` (input named "test.pgs"),
//   2. runs ASTModulesImporter on the tree,
//   3. cleans language::import_instruction nodes (ASTNodeTypeCleaner),
//   4. builds the expression tree and executes it,
//   5. prints the tree (raw format, data types shown) prefixed by '\n'.
//
// Both arguments must be std::string_view (checked at compile time). The
// leading '\n' lets expectations be written as raw string literals that
// start on their own line.
//
// The arguments are bound to uniquely named locals before anything else is
// declared, so an argument can never be captured by one of the macro's own
// locals. The `do { ... } while (false)` wrapper makes `CHECK_AST(...);` a
// single statement (no stray empty statement, safe in an unbraced if/else).
#define CHECK_AST(data, expected_output)                                                      \
  do {                                                                                        \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>);            \
    static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view>); \
                                                                                              \
    const std::string_view check_ast_data_{data};                                             \
    const std::string_view check_ast_expected_{expected_output};                              \
                                                                                              \
    string_input check_ast_input_{check_ast_data_, "test.pgs"};                               \
    auto check_ast_root_ = ASTBuilder::build(check_ast_input_);                               \
                                                                                              \
    ASTModulesImporter{*check_ast_root_};                                                     \
    ASTNodeTypeCleaner<language::import_instruction>{*check_ast_root_};                       \
                                                                                              \
    ASTNodeExpressionBuilder{*check_ast_root_};                                               \
    ExecUntilBreakOrContinue check_ast_exec_policy_;                                          \
    check_ast_root_->execute(check_ast_exec_policy_);                                         \
                                                                                              \
    std::stringstream check_ast_output_;                                                      \
    check_ast_output_ << '\n'                                                                 \
                      << ASTPrinter{*check_ast_root_,                                         \
                                    ASTPrinter::Format::raw,                                  \
                                    {ASTPrinter::Info::data_type}};                           \
                                                                                              \
    REQUIRE(check_ast_output_.str() == check_ast_expected_);                                  \
  } while (false)

// Tests for ASTModulesImporter.
// - Successful imports (none, one, or a repeated one) must leave no trace in
//   the final AST.
// - Importing an unknown module must raise a parse_error.
// - Importing a module whose symbol name is already taken must raise a
//   parse_error.
TEST_CASE("ASTModulesImporter", "[language]")
{
  // Expected printed AST for every successful case below: a bare root node,
  // since the import instructions do not survive into the final tree.
  const std::string_view expected_empty_root = R"(
(root:undefined)
)";

  SECTION("no module")
  {
    const std::string_view source = R"(
)";

    CHECK_AST(source, expected_empty_root);
  }

  SECTION("module instruction removal")
  {
    const std::string_view source = R"(
import math;
)";

    CHECK_AST(source, expected_empty_root);
  }

  SECTION("module multiple import")
  {
    // Importing the same module twice must not be reported as a clash.
    const std::string_view source = R"(
import math;
import math;
)";

    CHECK_AST(source, expected_empty_root);
  }

  SECTION("error")
  {
    SECTION("unknown module")
    {
      const std::string_view source = R"(
import unknown_module;
)";

      string_input test_input{source, "test.pgs"};
      auto root = ASTBuilder::build(test_input);

      REQUIRE_THROWS_AS(ASTModulesImporter{*root}, parse_error);
    }

    SECTION("symbol already defined")
    {
      const std::string_view source = R"(
import math;
)";

      string_input test_input{source, "test.pgs"};
      auto root = ASTBuilder::build(test_input);

      // Pre-register "sin" so that importing math collides with it
      // (assumes the math module exports a symbol named "sin").
      root->m_symbol_table->add("sin", root->begin());

      REQUIRE_THROWS_AS(ASTModulesImporter{*root}, parse_error);
    }
  }
}
