#include <catch2/catch.hpp>

#include <language/ast/ASTNode.hpp>
#include <language/ast/ASTNodeDataType.hpp>

// Grammar tag types used below only as template arguments to
// ASTNode::set_type<>(). Only their names are needed here, so they are merely
// forward-declared rather than pulling in the grammar definitions.
namespace language
{
struct vector_type;
struct real;
struct integer;
}   // namespace language

// clazy:excludeall=non-pod-global-static

// Unit tests for ASTNodeDataType: printable type names, binary type promotion,
// building a vector type from a `vector_type` AST node, and the table of
// "natural" (implicit) conversions between data types.
TEST_CASE("ASTNodeDataType", "[language]")
{
  // dataTypeName() must map every data type to its user-facing spelling.
  SECTION("dataTypeName")
  {
    REQUIRE(dataTypeName(ASTNodeDataType::undefined_t) == "undefined");
    REQUIRE(dataTypeName(ASTNodeDataType::bool_t) == "B");
    REQUIRE(dataTypeName(ASTNodeDataType::unsigned_int_t) == "N");
    REQUIRE(dataTypeName(ASTNodeDataType::int_t) == "Z");
    REQUIRE(dataTypeName(ASTNodeDataType::double_t) == "R");
    REQUIRE(dataTypeName(ASTNodeDataType::string_t) == "string");
    REQUIRE(dataTypeName(ASTNodeDataType::typename_t) == "typename");
    REQUIRE(dataTypeName(ASTNodeDataType::void_t) == "void");
    REQUIRE(dataTypeName(ASTNodeDataType::function_t) == "function");
    REQUIRE(dataTypeName(ASTNodeDataType::builtin_function_t) == "builtin_function");
    REQUIRE(dataTypeName(ASTNodeDataType::list_t) == "list");

    // Tuple names embed the name of their content type.
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}) ==
            "tuple(B)");
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}) ==
            "tuple(N)");
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}) ==
            "tuple(Z)");
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}) ==
            "tuple(R)");

    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::type_name_id_t, 1}) == "type_name_id");

    // User-defined types are named after their identifier.
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::type_id_t, "user_type"}) == "user_type");

    // Vector names embed their dimension.
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::vector_t, 1}) == "R^1");
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::vector_t, 2}) == "R^2");
    REQUIRE(dataTypeName(ASTNodeDataType{ASTNodeDataType::vector_t, 3}) == "R^3");
  }

  // dataTypePromotion(lhs, rhs): result type of combining lhs with rhs.
  // Note the asymmetry checked below: string absorbs scalars on the left, but a
  // scalar combined with a string on the right yields undefined.
  SECTION("promotion")
  {
    REQUIRE(dataTypePromotion(ASTNodeDataType::undefined_t, ASTNodeDataType::undefined_t) ==
            ASTNodeDataType::undefined_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::void_t, ASTNodeDataType::double_t) == ASTNodeDataType::undefined_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::undefined_t) == ASTNodeDataType::undefined_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::bool_t) == ASTNodeDataType::double_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::unsigned_int_t) == ASTNodeDataType::double_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::int_t) == ASTNodeDataType::double_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::unsigned_int_t) ==
            ASTNodeDataType::unsigned_int_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::bool_t) == ASTNodeDataType::int_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::bool_t) ==
            ASTNodeDataType::unsigned_int_t);

    REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::bool_t) == ASTNodeDataType::string_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::int_t) == ASTNodeDataType::string_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::unsigned_int_t) == ASTNodeDataType::string_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::string_t, ASTNodeDataType::double_t) == ASTNodeDataType::string_t);

    REQUIRE(dataTypePromotion(ASTNodeDataType::bool_t, ASTNodeDataType::string_t) == ASTNodeDataType::undefined_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::int_t, ASTNodeDataType::string_t) == ASTNodeDataType::undefined_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::string_t) ==
            ASTNodeDataType::undefined_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::string_t) == ASTNodeDataType::undefined_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::bool_t, ASTNodeDataType::vector_t) == ASTNodeDataType::vector_t);
    REQUIRE(dataTypePromotion(ASTNodeDataType::double_t, ASTNodeDataType::vector_t) == ASTNodeDataType::vector_t);
  }

  SECTION("getVectorDataType")
  {
    // Fixture: a `vector_type` node with two children, a default-constructed
    // placeholder followed by the integer literal "17" used as the dimension.
    // Catch2 re-executes this code for every leaf SECTION below, so each
    // sub-section that mutates the fixture starts from a fresh copy.
    std::unique_ptr type_node = std::make_unique<ASTNode>();
    type_node->set_type<language::vector_type>();

    type_node->emplace_back(std::make_unique<ASTNode>());

    {
      std::unique_ptr dimension_node = std::make_unique<ASTNode>();
      dimension_node->set_type<language::integer>();
      dimension_node->source  = "17";
      auto& source            = dimension_node->source;
      // Make the node's [m_begin, m_end) range span its own `source` text,
      // presumably so the dimension is read back as "17" from the node content.
      dimension_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source[0]};
      dimension_node->m_end   = TAO_PEGTL_NAMESPACE::internal::iterator{&source[source.size()]};
      type_node->emplace_back(std::move(dimension_node));
    }

    SECTION("good node")
    {
      REQUIRE(getVectorDataType(*type_node) == ASTNodeDataType::vector_t);
      REQUIRE(getVectorDataType(*type_node).dimension() == 17);
    }

    SECTION("bad node type")
    {
      type_node->set_type<language::integer>();
      REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "unexpected node type");
    }

    // Wrong number of children: no children at all.
    SECTION("bad children size 1")
    {
      type_node->children.clear();
      REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "unexpected node type");
    }

    // Wrong number of children: one extra (null) child, three in total.
    SECTION("bad children size 2")
    {
      type_node->children.emplace_back(std::unique_ptr<ASTNode>());
      REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "unexpected node type");
    }

    // children[1] is the dimension node; a non-integer dimension is rejected.
    SECTION("bad dimension type")
    {
      type_node->children[1]->set_type<language::real>();
      REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "unexpected non integer constant dimension");
    }
  }

  // isNaturalConversion(from, to): whether `from` converts implicitly to `to`.
  // Each sub-section fixes the target type and enumerates the source types.
  SECTION("isNaturalConversion")
  {
    SECTION("-> B")
    {
      REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::bool_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::bool_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::bool_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::bool_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::bool_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::bool_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::bool_t));
    }

    SECTION("-> N")
    {
      REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::unsigned_int_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::unsigned_int_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::unsigned_int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::unsigned_int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::unsigned_int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::unsigned_int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::unsigned_int_t));
    }

    SECTION("-> Z")
    {
      REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::int_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::int_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::int_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::int_t));
    }

    SECTION("-> R")
    {
      REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::double_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::double_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::double_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::double_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::double_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::double_t));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::double_t));
    }

    SECTION("-> string")
    {
      REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType::string_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType::string_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType::string_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType::string_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType::string_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType::string_t));
      REQUIRE(isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType::string_t));
    }

    // A scalar converts to a tuple when it converts naturally to the tuple's
    // content type.
    SECTION("-> tuple")
    {
      REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t,
                                  ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}));
      REQUIRE(isNaturalConversion(ASTNodeDataType::bool_t,
                                  ASTNodeDataType{ASTNodeDataType::tuple_t,
                                                  ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}));
      REQUIRE(
        not isNaturalConversion(ASTNodeDataType::int_t,
                                ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::bool_t}}));

      REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t,
                                  ASTNodeDataType{ASTNodeDataType::tuple_t,
                                                  ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}));
      REQUIRE(isNaturalConversion(ASTNodeDataType::unsigned_int_t,
                                  ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::int_t}}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t,
                                      ASTNodeDataType{ASTNodeDataType::tuple_t,
                                                      ASTNodeDataType{ASTNodeDataType::unsigned_int_t}}));

      REQUIRE(
        isNaturalConversion(ASTNodeDataType::bool_t,
                            ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}));
      REQUIRE(isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType{ASTNodeDataType::tuple_t,
                                                                          ASTNodeDataType{ASTNodeDataType::double_t}}));
      REQUIRE(
        isNaturalConversion(ASTNodeDataType::unsigned_int_t,
                            ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}));
      REQUIRE(
        isNaturalConversion(ASTNodeDataType::double_t,
                            ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}));
      REQUIRE(
        not isNaturalConversion(ASTNodeDataType::string_t,
                                ASTNodeDataType{ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::double_t}}));
    }

    // Only a vector of the same dimension converts naturally to a vector.
    SECTION("-> vector")
    {
      REQUIRE(not isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType{ASTNodeDataType::vector_t, 1}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType{ASTNodeDataType::vector_t, 3}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType{ASTNodeDataType::vector_t, 2}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType{ASTNodeDataType::vector_t, 2}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType{ASTNodeDataType::vector_t, 3}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::vector_t, 1}));

      REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 1},
                                  ASTNodeDataType{ASTNodeDataType::vector_t, 1}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 2},
                                      ASTNodeDataType{ASTNodeDataType::vector_t, 1}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 3},
                                      ASTNodeDataType{ASTNodeDataType::vector_t, 1}));

      REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 1},
                                      ASTNodeDataType{ASTNodeDataType::vector_t, 2}));
      REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 2},
                                  ASTNodeDataType{ASTNodeDataType::vector_t, 2}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 3},
                                      ASTNodeDataType{ASTNodeDataType::vector_t, 2}));

      REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 1},
                                      ASTNodeDataType{ASTNodeDataType::vector_t, 3}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 2},
                                      ASTNodeDataType{ASTNodeDataType::vector_t, 3}));
      REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::vector_t, 3},
                                  ASTNodeDataType{ASTNodeDataType::vector_t, 3}));
    }

    // Only the very same user type (same identifier) converts to a user type.
    SECTION("-> type_id")
    {
      REQUIRE(not isNaturalConversion(ASTNodeDataType::bool_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));
      REQUIRE(
        not isNaturalConversion(ASTNodeDataType::unsigned_int_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::int_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::double_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::string_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::vector_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));
      REQUIRE(not isNaturalConversion(ASTNodeDataType::tuple_t, ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));

      REQUIRE(isNaturalConversion(ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"},
                                  ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"}));

      REQUIRE(not isNaturalConversion(ASTNodeDataType{ASTNodeDataType::type_id_t, "foo"},
                                      ASTNodeDataType{ASTNodeDataType::type_id_t, "bar"}));
    }
  }
}
