#include <catch2/catch.hpp>

#include <ASTNodeValueBuilder.hpp>

#include <ASTBuilder.hpp>
#include <ASTNodeDataTypeBuilder.hpp>

#include <ASTModulesImporter.hpp>

#include <ASTNodeExpressionBuilder.hpp>
#include <ASTNodeFunctionEvaluationExpressionBuilder.hpp>
#include <ASTNodeFunctionExpressionBuilder.hpp>
#include <ASTNodeTypeCleaner.hpp>

#include <CFunctionEmbedder.hpp>

#include <ASTSymbolTableBuilder.hpp>

#include <ASTPrinter.hpp>

#include <PEGGrammar.hpp>

#include <Demangle.hpp>

#include <unordered_map>

namespace test_only
{
class CFunctionRegister
{
 private:
  std::unordered_map<std::string, std::shared_ptr<ICFunctionEmbedder>> m_name_cfunction_map;

  void
  _populateNameCFunctionMap()
  {
    m_name_cfunction_map.insert(
      std::make_pair("RtoR", std::make_shared<CFunctionEmbedder<double, double>>(
                               std::function<double(double)>{[](double x) -> double { return x + 1; }})));

    m_name_cfunction_map.insert(
      std::make_pair("ZtoR", std::make_shared<CFunctionEmbedder<double, int64_t>>(
                               std::function<double(int64_t)>{[](int64_t z) -> double { return 0.5 * z; }})));

    m_name_cfunction_map.insert(
      std::make_pair("NtoR", std::make_shared<CFunctionEmbedder<double, uint64_t>>(
                               std::function<double(uint64_t)>{[](uint64_t n) -> double { return 0.5 * n; }})));

    m_name_cfunction_map.insert(
      std::make_pair("BtoR", std::make_shared<CFunctionEmbedder<double, bool>>(
                               std::function<double(bool)>{[](bool b) -> double { return b; }})));

    m_name_cfunction_map.insert(std::make_pair("R2toB", std::make_shared<CFunctionEmbedder<bool, double, double>>(
                                                          std::function<bool(double, double)>{
                                                            [](double x, double y) -> bool { return x > y; }})));

    m_name_cfunction_map.insert(
      std::make_pair("StoB_invalid",
                     std::make_shared<CFunctionEmbedder<bool, std::string>>(
                       std::function<bool(std::string)>{[](std::string s) -> bool { return s.size() > 0; }})));
  }

 public:
  CFunctionRegister(ASTNode& root_node)
  {
    SymbolTable& symbol_table = *root_node.m_symbol_table;

    CFunctionEmbedderTable& c_function_embedder_table = symbol_table.cFunctionEbedderTable();

    this->_populateNameCFunctionMap();

    for (auto [symbol_name, c_function] : m_name_cfunction_map) {
      auto [i_symbol, success] = symbol_table.add(symbol_name, root_node.begin());

      if (not success) {
        std::ostringstream error_message;
        error_message << "cannot add symbol '" << symbol_name << "' it is already defined";
        throw parse_error(error_message.str(), root_node.begin());
      }

      i_symbol->attributes().setDataType(ASTNodeDataType::c_function_t);
      i_symbol->attributes().setIsInitialized();
      i_symbol->attributes().value() = c_function_embedder_table.size();

      c_function_embedder_table.add(c_function);
    }
  }
};
}   // namespace test_only

// Parses `data` (a std::string_view) with ASTBuilder::build, registers the
// test-only C functions, runs the builder passes listed in the body, and
// REQUIREs that the raw ASTPrinter dump (with exec_type information, preceded
// by a newline) equals `expected_output` (std::string_view or std::string).
//
// The body is wrapped in do { ... } while (false) so that the expansion is a
// single statement that consumes the caller's trailing semicolon; a bare
// { ... } block would leave an empty statement behind and break when used as
// the body of an if/else.
#define CHECK_AST(data, expected_output)                                                            \
  do {                                                                                              \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>);                  \
    static_assert((std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view>) or    \
                  (std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>));          \
                                                                                                    \
    string_input input{data, "test.pgs"};                                                           \
    auto ast = ASTBuilder::build(input);                                                            \
                                                                                                    \
    test_only::CFunctionRegister{*ast};                                                             \
                                                                                                    \
    ASTSymbolTableBuilder{*ast};                                                                    \
    ASTNodeDataTypeBuilder{*ast};                                                                   \
    ASTNodeValueBuilder{*ast};                                                                      \
                                                                                                    \
    ASTNodeTypeCleaner<language::declaration>{*ast};                                                \
    ASTNodeExpressionBuilder{*ast};                                                                 \
                                                                                                    \
    std::stringstream ast_output;                                                                   \
    ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \
                                                                                                    \
    REQUIRE(ast_output.str() == expected_output);                                                   \
  } while (false)

// Parses `data` (a std::string_view) with ASTBuilder::build, registers the
// test-only C functions, runs the preparatory passes listed in the body, and
// REQUIREs that constructing ASTNodeExpressionBuilder throws with a message
// containing `expected_error` (std::string_view or std::string).
//
// `expected_error` is converted with std::string(...) because the Catch2
// string matcher takes a std::string: without the conversion a
// std::string_view argument, although accepted by the static_assert, would
// not compile.
//
// The body is wrapped in do { ... } while (false) so that the expansion is a
// single statement that consumes the caller's trailing semicolon.
#define CHECK_AST_THROWS_WITH(data, expected_error)                                                 \
  do {                                                                                              \
    static_assert(std::is_same_v<std::decay_t<decltype(data)>, std::string_view>);                  \
    static_assert((std::is_same_v<std::decay_t<decltype(expected_error)>, std::string_view>) or     \
                  (std::is_same_v<std::decay_t<decltype(expected_error)>, std::string>));           \
                                                                                                    \
    string_input input{data, "test.pgs"};                                                           \
    auto ast = ASTBuilder::build(input);                                                            \
                                                                                                    \
    test_only::CFunctionRegister{*ast};                                                             \
                                                                                                    \
    ASTSymbolTableBuilder{*ast};                                                                    \
    ASTNodeDataTypeBuilder{*ast};                                                                   \
    ASTNodeValueBuilder{*ast};                                                                      \
                                                                                                    \
    ASTNodeTypeCleaner<language::declaration>{*ast};                                                \
    REQUIRE_THROWS_WITH(ASTNodeExpressionBuilder{*ast},                                             \
                        Catch::Matchers::Contains(std::string(expected_error)));                    \
  } while (false)

// Checks the AST produced for calls to the embedded C functions registered by
// test_only::CFunctionRegister. Each successful case compares the printed AST
// with the expected dump; the "errors" cases check the exception message.
TEST_CASE("ASTNodeCFunctionExpressionBuilder", "[language]")
{
  // RtoR is called with a real, an integer, an N variable and a boolean.
  SECTION("R -> R")
  {
    SECTION("from R")
    {
      const std::string_view source = R"(
RtoR(1.);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:RtoR:NameProcessor)
     `-(language::real:1.:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from Z")
    {
      const std::string_view source = R"(
RtoR(1);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:RtoR:NameProcessor)
     `-(language::integer:1:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from N")
    {
      const std::string_view source = R"(
N n = 1;
RtoR(n);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:RtoR:NameProcessor)
     `-(language::name:n:NameProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from B")
    {
      const std::string_view source = R"(
RtoR(true);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:RtoR:NameProcessor)
     `-(language::true_kw:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }
  }

  // Same four argument kinds, passed to ZtoR.
  SECTION("Z -> R")
  {
    SECTION("from R")
    {
      const std::string_view source = R"(
ZtoR(1.);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:ZtoR:NameProcessor)
     `-(language::real:1.:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from Z")
    {
      const std::string_view source = R"(
ZtoR(1);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:ZtoR:NameProcessor)
     `-(language::integer:1:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from N")
    {
      const std::string_view source = R"(
N n = 1;
ZtoR(n);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:ZtoR:NameProcessor)
     `-(language::name:n:NameProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from B")
    {
      const std::string_view source = R"(
ZtoR(true);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:ZtoR:NameProcessor)
     `-(language::true_kw:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }
  }

  // Same four argument kinds, passed to NtoR.
  SECTION("N -> R")
  {
    SECTION("from R")
    {
      const std::string_view source = R"(
NtoR(1.);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:NtoR:NameProcessor)
     `-(language::real:1.:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from Z")
    {
      const std::string_view source = R"(
NtoR(1);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:NtoR:NameProcessor)
     `-(language::integer:1:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from N")
    {
      const std::string_view source = R"(
N n = 1;
NtoR(n);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:NtoR:NameProcessor)
     `-(language::name:n:NameProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from B")
    {
      const std::string_view source = R"(
NtoR(true);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:NtoR:NameProcessor)
     `-(language::true_kw:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }
  }

  // Same four argument kinds, passed to BtoR.
  SECTION("B -> R")
  {
    SECTION("from R")
    {
      const std::string_view source = R"(
BtoR(1.);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:BtoR:NameProcessor)
     `-(language::real:1.:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from Z")
    {
      const std::string_view source = R"(
BtoR(1);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:BtoR:NameProcessor)
     `-(language::integer:1:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from N")
    {
      const std::string_view source = R"(
N n = 1;
BtoR(n);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:BtoR:NameProcessor)
     `-(language::name:n:NameProcessor)
)";

      CHECK_AST(source, expected);
    }

    SECTION("from B")
    {
      const std::string_view source = R"(
BtoR(true);
)";

      const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:BtoR:NameProcessor)
     `-(language::true_kw:FakeProcessor)
)";

      CHECK_AST(source, expected);
    }
  }

  // Two-argument call: the arguments appear under a function_argument_list node.
  SECTION("R2 -> B")
  {
    const std::string_view source = R"(
R2toB(1., 0.);
)";

    const std::string_view expected = R"(
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:CFunctionProcessor)
     +-(language::name:R2toB:NameProcessor)
     `-(language::function_argument_list:FakeProcessor)
         +-(language::real:1.:FakeProcessor)
         `-(language::real:0.:FakeProcessor)
)";

    CHECK_AST(source, expected);
  }

  // Calls for which building the expression is expected to throw.
  SECTION("errors")
  {
    SECTION("bad number of arguments")
    {
      const std::string_view source = R"(
BtoR(true, false);
)";
      const std::string message{"bad number of arguments:"};

      CHECK_AST_THROWS_WITH(source, message);
    }

    SECTION("bad number of arguments 2")
    {
      const std::string_view source = R"(
R2toB(3);
)";
      const std::string message{"bad number of arguments:"};

      CHECK_AST_THROWS_WITH(source, message);
    }

    SECTION("invalid argument type")
    {
      const std::string_view source = R"(
RtoR("foo");
)";
      const std::string message{"invalid argument type for function"};

      CHECK_AST_THROWS_WITH(source, message);
    }

    SECTION("invalid function parameter type")
    {
      const std::string_view source = R"(
StoB_invalid(3);
)";
      const std::string message{"unexpected error: undefined parameter type"};

      CHECK_AST_THROWS_WITH(source, message);
    }
  }
}
