#include <catch2/catch.hpp>

#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
#include <language/ast/ASTNodeExpressionBuilder.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>
#include <utils/Demangle.hpp>

#include <pegtl/string_input.hpp>

#include <sstream>

// Parses the script `data`, runs the AST passes in order (symbol table, data
// types, declaration-to-affectation conversion, var_declaration cleaning,
// expression building), executes the resulting AST, and then requires that the
// symbol `variable_name` holds `expected_value`.
//
// - data:           script source text handed to string_input.
// - variable_name:  name looked up in the root node's symbol table.
// - expected_value: expression whose decltype selects the alternative read
//                   with std::get from the symbol's stored value; it is then
//                   compared with operator==.
//
// The lookup uses a synthetic position whose byte offset is 10000.
// NOTE(review): presumably chosen to lie after every declaration of the tested
// scripts so that the symbol is visible at that position -- confirm.
//
// Only /* */ comments may be used inside the macro body: a // comment would
// swallow the continued lines that follow it.
#define CHECK_AFFECTATION_RESULT(data, variable_name, expected_value)         \
  {                                                                           \
    string_input input{data, "test.pgs"};                                     \
    auto ast = ASTBuilder::build(input);                                      \
                                                                              \
    ASTSymbolTableBuilder{*ast};                                              \
    ASTNodeDataTypeBuilder{*ast};                                             \
                                                                              \
    ASTNodeDeclarationToAffectationConverter{*ast};                           \
    ASTNodeTypeCleaner<language::var_declaration>{*ast};                      \
                                                                              \
    ASTNodeExpressionBuilder{*ast};                                           \
    ExecutionPolicy exec_policy;                                              \
    ast->execute(exec_policy);                                                \
                                                                              \
    auto symbol_table = ast->m_symbol_table;                                  \
                                                                              \
    using namespace TAO_PEGTL_NAMESPACE;                                      \
    position use_position{internal::iterator{"fixture"}, "fixture"};          \
    use_position.byte    = 10000;                                             \
    auto [symbol, found] = symbol_table->find(variable_name, use_position);   \
    /* A failed lookup must fail the test, not dereference `symbol`. */       \
    REQUIRE(found);                                                           \
                                                                              \
    auto attributes = symbol->attributes();                                   \
    auto value      = std::get<decltype(expected_value)>(attributes.value()); \
                                                                              \
    REQUIRE(value == expected_value);                                         \
  }

// Parses the script `data`, runs the AST passes that precede expression
// building, and requires that constructing ASTNodeExpressionBuilder throws an
// exception whose message contains
// "invalid operands to affectation expression".
//
// REQUIRE_THROWS_WITH is used on purpose: REQUIRE_THROWS accepts a single
// expression, so a matcher written after a comma would only be the right-hand
// side of a comma expression and the message would never be checked.
//
// - data: script source text handed to string_input.
#define CHECK_AFFECTATION_THROWS(data)                                                       \
  {                                                                                          \
    string_input input{data, "test.pgs"};                                                    \
    auto ast = ASTBuilder::build(input);                                                     \
                                                                                             \
    ASTSymbolTableBuilder{*ast};                                                             \
    ASTNodeDataTypeBuilder{*ast};                                                            \
                                                                                             \
    ASTNodeDeclarationToAffectationConverter{*ast};                                          \
    ASTNodeTypeCleaner<language::var_declaration>{*ast};                                     \
                                                                                             \
    REQUIRE_THROWS_WITH(                                                                     \
      ASTNodeExpressionBuilder{*ast},                                                        \
      Catch::Matchers::Contains("invalid operands to affectation expression"));              \
  }

// Parses the script `data`, applies the symbol-table, data-type,
// declaration-conversion and var_declaration-cleaning passes, then requires
// that constructing ASTNodeExpressionBuilder throws. `error_message` is
// forwarded unchanged as the second argument of REQUIRE_THROWS_WITH.
//
// - data:          script source text handed to string_input.
// - error_message: expected message (or matcher) given to REQUIRE_THROWS_WITH.
#define CHECK_AFFECTATION_THROWS_WITH(data, error_message)              \
  {                                                                     \
    string_input script{data, "test.pgs"};                              \
    auto root = ASTBuilder::build(script);                              \
                                                                        \
    /* passes run before expression building */                         \
    ASTSymbolTableBuilder{*root};                                       \
    ASTNodeDataTypeBuilder{*root};                                      \
    ASTNodeDeclarationToAffectationConverter{*root};                    \
    ASTNodeTypeCleaner<language::var_declaration>{*root};               \
                                                                        \
    REQUIRE_THROWS_WITH(ASTNodeExpressionBuilder{*root},                \
                        error_message);                                 \
  }

// clazy:excludeall=non-pod-global-static

TEST_CASE("ListAffectationProcessor", "[language]")
{
  SECTION("ListAffectations")
  {
    SECTION("R*R^2*string")
    {
      // One script, checked once per component of the assigned tuple.
      const std::string program = R"(let (x,u,s): R*R^2*string, (x,u,s) = (1.2, (2,3), "foo");)";

      CHECK_AFFECTATION_RESULT(program, "x", double{1.2});
      CHECK_AFFECTATION_RESULT(program, "u", (TinyVector<2>{2, 3}));
      CHECK_AFFECTATION_RESULT(program, "s", std::string{"foo"});
    }

    SECTION("compound with string conversion")
    {
      // Builds the expected text for a value: the value is streamed into an
      // ostringstream followed by std::ends, so the terminating '\0' is part
      // of the expected string.
      const auto as_text = [](const auto& item) {
        std::ostringstream stream;
        stream << item << std::ends;
        return stream.str();
      };

      // Scalar source: the expected text comes from std::to_string.
      CHECK_AFFECTATION_RESULT(R"(let z:R, z = 3; let (x,u,s):R*R^2*string, (x,u,s) = (1.2, (2,3), z);)", "s",
                               std::to_string(double{3}));

      // Vector sources of dimension 1, 2 and 3.
      CHECK_AFFECTATION_RESULT(R"(let v:R^1, v = 7; let  (x,u,s):R*R^2*string, (x,u,s) = (1.2, (2,3), v);)", "s",
                               as_text(TinyVector<1>{7}));
      CHECK_AFFECTATION_RESULT(R"(let v: R^2, v = (6,3); let (x,u,s):R*R^2*string, (x,u,s) = (1.2, (2,3), v);)", "s",
                               as_text(TinyVector<2>{6, 3}));
      CHECK_AFFECTATION_RESULT(R"(let v:R^3, v = (1,2,3); let (x,u,s):R*R^2*string, (x,u,s) = (1.2, (2,3), v);)", "s",
                               as_text(TinyVector<3>{1, 2, 3}));
    }

    SECTION("compound R^d from '0'")
    {
      // Every component is assigned the literal 0 and must compare equal to a
      // TinyVector built from `zero`.
      const std::string program = R"(let (x,y,z):R^3*R^2*R^1, (x,y,z) = (0,0,0);)";

      CHECK_AFFECTATION_RESULT(program, "x", (TinyVector<3>{zero}));
      CHECK_AFFECTATION_RESULT(program, "y", (TinyVector<2>{zero}));
      CHECK_AFFECTATION_RESULT(program, "z", (TinyVector<1>{zero}));
    }

    SECTION("compound with subscript values")
    {
      // Left-hand sides are vector components, listed out of index order in
      // the first two scripts.
      CHECK_AFFECTATION_RESULT(R"(let x:R^3; (x[0], x[2], x[1]) = (4, 6, 5);)", "x", (TinyVector<3>{4, 5, 6}));
      CHECK_AFFECTATION_RESULT(R"(let x:R^2; (x[1], x[0]) = (3, 6);)", "x", (TinyVector<2>{6, 3}));
      CHECK_AFFECTATION_RESULT(R"(let x:R^1; let y:R; (y, x[0]) = (4, 2.3);)", "x", (TinyVector<1>{2.3}));
    }
  }
}
