#include <catch2/catch.hpp>

#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeDataTypeFlattener.hpp>
#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>

#include <pegtl/string_input.hpp>

// clazy:excludeall=non-pod-global-static

TEST_CASE("ASTNodeDataTypeFlattener", "[language]")
{
  // Front-end passes every scenario needs before flattening: build the symbol
  // table, then infer node data types. The builder objects are only used for
  // the side effects of their construction, so they are discarded at once.
  auto run_type_analysis = [](auto& ast) {
    ASTSymbolTableBuilder{ast};
    ASTNodeDataTypeBuilder{ast};
  };

  // Additional passes applied in the function-evaluation scenarios: turn
  // declarations into affectations, then strip var/fct declaration nodes.
  // Afterwards these scenarios expect the evaluation `f(2)` to be the root's
  // first child.
  auto run_optimizations = [](auto& ast) {
    ASTNodeDeclarationToAffectationConverter{ast};
    ASTNodeTypeCleaner<language::var_declaration>{ast};
    ASTNodeTypeCleaner<language::fct_declaration>{ast};
  };

  SECTION("simple types")
  {
    SECTION("B")
    {
      const std::string_view data = R"(
let b : B, b = true;
b;
)";

      string_input input{data, "test.pgs"};
      auto root_node = ASTBuilder::build(input);
      run_type_analysis(*root_node);

      // The second statement is the bare use of the variable `b`.
      REQUIRE(root_node->children[1]->is_type<language::name>());
      auto& name_node = *root_node->children[1];

      ASTNodeDataTypeFlattener::FlattenedDataTypeList flattened_datatype_list;
      ASTNodeDataTypeFlattener{name_node, flattened_datatype_list};

      // A scalar yields exactly one entry, attached to the node itself.
      REQUIRE(flattened_datatype_list.size() == 1);
      REQUIRE(flattened_datatype_list[0].m_data_type == ASTNodeDataType::bool_t);
      REQUIRE(&flattened_datatype_list[0].m_parent_node == root_node->children[1].get());
    }

    SECTION("N")
    {
      const std::string_view data = R"(
let n : N;
n;
)";

      string_input input{data, "test.pgs"};
      auto root_node = ASTBuilder::build(input);
      run_type_analysis(*root_node);

      // The second statement is the bare use of the variable `n`.
      REQUIRE(root_node->children[1]->is_type<language::name>());
      auto& name_node = *root_node->children[1];

      ASTNodeDataTypeFlattener::FlattenedDataTypeList flattened_datatype_list;
      ASTNodeDataTypeFlattener{name_node, flattened_datatype_list};

      REQUIRE(flattened_datatype_list.size() == 1);
      REQUIRE(flattened_datatype_list[0].m_data_type == ASTNodeDataType::unsigned_int_t);
      REQUIRE(&flattened_datatype_list[0].m_parent_node == root_node->children[1].get());
    }
  }

  SECTION("Compound types")
  {
    SECTION("function evaluation -> N")
    {
      const std::string_view data = R"(
let f: N -> N, n -> n;
f(2);
)";

      string_input input{data, "test.pgs"};
      auto root_node = ASTBuilder::build(input);
      run_type_analysis(*root_node);
      run_optimizations(*root_node);

      REQUIRE(root_node->children[0]->is_type<language::function_evaluation>());
      auto& evaluation_node = *root_node->children[0];

      ASTNodeDataTypeFlattener::FlattenedDataTypeList flattened_datatype_list;
      ASTNodeDataTypeFlattener{evaluation_node, flattened_datatype_list};

      // A single-valued image domain flattens to one entry.
      REQUIRE(flattened_datatype_list.size() == 1);
      REQUIRE(flattened_datatype_list[0].m_data_type == ASTNodeDataType::unsigned_int_t);
      REQUIRE(&flattened_datatype_list[0].m_parent_node == root_node->children[0].get());
    }

    SECTION("function evaluation -> N*R*B*string*Z")
    {
      const std::string_view data = R"(
let f: N -> N*R*B*string*Z, n -> (n, 0.5*n, n>2, n, 3-n);
f(2);
)";

      string_input input{data, "test.pgs"};
      auto root_node = ASTBuilder::build(input);
      run_type_analysis(*root_node);
      run_optimizations(*root_node);

      REQUIRE(root_node->children[0]->is_type<language::function_evaluation>());
      auto& evaluation_node = *root_node->children[0];

      ASTNodeDataTypeFlattener::FlattenedDataTypeList flattened_datatype_list;
      ASTNodeDataTypeFlattener{evaluation_node, flattened_datatype_list};

      // One entry per component of the product image domain, in order.
      REQUIRE(flattened_datatype_list.size() == 5);
      REQUIRE(flattened_datatype_list[0].m_data_type == ASTNodeDataType::unsigned_int_t);
      REQUIRE(flattened_datatype_list[1].m_data_type == ASTNodeDataType::double_t);
      REQUIRE(flattened_datatype_list[2].m_data_type == ASTNodeDataType::bool_t);
      REQUIRE(flattened_datatype_list[3].m_data_type == ASTNodeDataType::string_t);
      REQUIRE(flattened_datatype_list[4].m_data_type == ASTNodeDataType::int_t);
    }

    SECTION("function evaluation -> R*R^3")
    {
      const std::string_view data = R"(
let f: R -> R*R^3, x -> (0.5*x, (x, x+1, x-1));
f(2);
)";

      string_input input{data, "test.pgs"};
      auto root_node = ASTBuilder::build(input);
      run_type_analysis(*root_node);
      run_optimizations(*root_node);

      REQUIRE(root_node->children[0]->is_type<language::function_evaluation>());
      auto& evaluation_node = *root_node->children[0];

      ASTNodeDataTypeFlattener::FlattenedDataTypeList flattened_datatype_list;
      ASTNodeDataTypeFlattener{evaluation_node, flattened_datatype_list};

      // The R^3 component is kept as a single vector entry (not split into
      // three reals), carrying its dimension.
      REQUIRE(flattened_datatype_list.size() == 2);
      REQUIRE(flattened_datatype_list[0].m_data_type == ASTNodeDataType::double_t);
      REQUIRE(flattened_datatype_list[1].m_data_type == ASTNodeDataType::vector_t);
      REQUIRE(flattened_datatype_list[1].m_data_type.dimension() == 3);
    }
  }
}
