#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <language/utils/BuiltinFunctionEmbedder.hpp>

// clazy:excludeall=non-pod-global-static

// Test-local AST data types for the shared-pointer types embedded below.
// Each specialization builds a type_id_t data type with a distinct name. The
// tests compare these values against getReturnDataType() and
// getParameterDataTypes() of embedders whose callables take or return the
// corresponding shared pointer.

// std::shared_ptr<const double>: used as a return type and as a parameter type.
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_const_double");

// std::shared_ptr<double>: used as a plain return type and as a tuple element.
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<double>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_double");

// std::shared_ptr<uint64_t>: element type of the std::vector parameter tested below.
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<uint64_t>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_uint64_t");

// Unit tests for BuiltinFunctionEmbedder. Each SECTION embeds a C++ callable,
// applies it to a vector of DataVariant arguments, and checks:
//  - the returned value,
//  - the number of parameters,
//  - the AST data types reported for the parameters and the return value,
//  - the error messages produced for ill-typed arguments.
TEST_CASE("BuiltinFunctionEmbedder", "[language]")
{
  // Turn terminal colouring off. NOTE(review): presumably this keeps the
  // exception messages compared by REQUIRE_THROWS_WITH below free of escape
  // sequences.
  rang::setControlMode(rang::control::Off);

  SECTION("math")
  {
    BuiltinFunctionEmbedder<double(double)> embedded_sin{[](double x) -> double { return std::sin(x); }};

    double arg              = 2;
    DataVariant arg_variant = arg;

    DataVariant result = embedded_sin.apply({arg_variant});

    REQUIRE(std::get<double>(result) == std::sin(arg));
    REQUIRE(embedded_sin.numberOfParameters() == 1);

    REQUIRE(embedded_sin.getReturnDataType() == ASTNodeDataType::double_t);
    REQUIRE(embedded_sin.getParameterDataTypes()[0] == ASTNodeDataType::double_t);
  }

  SECTION("multiple variant args")
  {
    std::function c = [](double x, uint64_t i) -> bool { return x > i; };

    BuiltinFunctionEmbedder<bool(double, uint64_t)> embedded_c{c};

    double d_arg   = 2.3;
    uint64_t i_arg = 3;

    std::vector<DataVariant> args;
    args.push_back(d_arg);
    args.push_back(i_arg);

    DataVariant result = embedded_c.apply(args);

    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
    REQUIRE(embedded_c.numberOfParameters() == 2);

    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::bool_t);
    REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
  }

  SECTION("R*R^2 -> R^2")
  {
    std::function c = [](double a, TinyVector<2> x) -> TinyVector<2> { return a * x; };

    BuiltinFunctionEmbedder<TinyVector<2>(double, TinyVector<2>)> embedded_c{c};

    double a_arg = 2.3;
    TinyVector<2> x_arg{3, 2};

    std::vector<DataVariant> args;
    args.push_back(a_arg);
    args.push_back(x_arg);

    DataVariant result = embedded_c.apply(args);

    REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg));
    REQUIRE(embedded_c.numberOfParameters() == 2);

    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t);
    REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t);
  }

  SECTION("R^2x2*R^2 -> R^2")
  {
    std::function c = [](TinyMatrix<2> A, TinyVector<2> x) -> TinyVector<2> { return A * x; };

    BuiltinFunctionEmbedder<TinyVector<2>(TinyMatrix<2>, TinyVector<2>)> embedded_c{c};

    TinyMatrix<2> a_arg = {2.3, 1, -2, 3};
    TinyVector<2> x_arg{3, 2};

    std::vector<DataVariant> args;
    args.push_back(a_arg);
    args.push_back(x_arg);

    DataVariant result = embedded_c.apply(args);

    REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg));
    REQUIRE(embedded_c.numberOfParameters() == 2);

    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t);
    REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::matrix_t);
    REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t);
  }

  SECTION("POD BuiltinFunctionEmbedder")
  {
    std::function c = [](double x, uint64_t i) -> bool { return x > i; };

    // Same embedder as above, but used through the IBuiltinFunctionEmbedder interface.
    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<bool(double, uint64_t)>>(c);

    double d_arg   = 2.3;
    uint64_t i_arg = 3;

    std::vector<DataVariant> args;
    args.push_back(d_arg);
    args.push_back(i_arg);

    DataVariant result = i_embedded_c->apply(args);

    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
    REQUIRE(i_embedded_c->numberOfParameters() == 2);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::bool_t);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(i_embedded_c->getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
  }

  SECTION("void(double) BuiltinFunctionEmbedder")
  {
    double y = 1;

    // Captures y by reference: the only observable effect of the call is the update of y.
    std::function add_to_y = [&](double x) -> void { y += x; };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<void(double)>>(add_to_y);

    double x = 0.5;
    i_embedded_c->apply(std::vector<DataVariant>{x});
    REQUIRE(y == 1.5);
    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::void_t);

    REQUIRE_THROWS_WITH(i_embedded_c->apply({std::vector<EmbeddedData>{}}),
                        "unexpected error: unexpected argument types while casting \"" +
                          demangle<std::vector<EmbeddedData>>() + "\" to \"" + demangle<double>() + '"');
  }

  SECTION("EmbeddedData(double, double) BuiltinFunctionEmbedder")
  {
    std::function sum = [](double x, double y) -> std::shared_ptr<const double> {
      return std::make_shared<const double>(x + y);
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<std::shared_ptr<const double>(double, double)>>(sum);

    REQUIRE(i_embedded_c->numberOfParameters() == 2);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 2);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(i_embedded_c->getParameterDataTypes()[1] == ASTNodeDataType::double_t);

    REQUIRE(i_embedded_c->getReturnDataType() == ast_node_data_type_from<std::shared_ptr<const double>>);

    // The second argument is an unsigned integer although the parameter is a
    // double: this enforces the argument cast test. The value is spelled
    // uint64_t{4} rather than 4ul so that it is exactly the uint64_t DataVariant
    // alternative even where unsigned long is not uint64_t.
    DataVariant result         = i_embedded_c->apply(std::vector<DataVariant>{2.3, uint64_t{4}});
    EmbeddedData embedded_data = std::get<EmbeddedData>(result);

    const IDataHandler& handled_data      = embedded_data.get();
    const DataHandler<const double>& data = dynamic_cast<const DataHandler<const double>&>(handled_data);
    REQUIRE(*data.data_ptr() == (2.3 + 4.0));
  }

  SECTION("double(std::shared_ptr<const double>) BuiltinFunctionEmbedder")
  {
    std::function abs = [](std::shared_ptr<const double> x) -> double { return std::abs(*x); };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<double(std::shared_ptr<const double>)>>(abs);

    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ast_node_data_type_from<std::shared_ptr<const double>>);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t);

    EmbeddedData data(std::make_shared<DataHandler<const double>>(std::make_shared<const double>(-2.3)));
    DataVariant result = i_embedded_c->apply({data});

    REQUIRE(std::get<double>(result) == 2.3);

    // A plain value where EmbeddedData is expected must be rejected.
    REQUIRE_THROWS_WITH(i_embedded_c->apply({2.3}),
                        "unexpected error: unexpected argument types while casting: expecting EmbeddedData");

    // EmbeddedData wrapping the wrong handler type must be rejected too.
    EmbeddedData wrong_embedded_data_type(
      std::make_shared<DataHandler<const TinyVector<2>>>(std::make_shared<const TinyVector<2>>(-2, -2)));
    REQUIRE_THROWS_WITH(i_embedded_c->apply({wrong_embedded_data_type}),
                        "unexpected error: unexpected argument types while casting: "
                        "invalid EmbeddedData type, expecting " +
                          demangle<DataHandler<const double>>());
  }

  SECTION("uint64_t(std::vector<uint64_t>) BuiltinFunctionEmbedder")
  {
    std::function sum = [](const std::vector<uint64_t>& x) -> uint64_t {
      uint64_t total = 0;
      for (const uint64_t value : x) {
        total += value;
      }
      return total;
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<uint64_t(const std::vector<uint64_t>&)>>(sum);

    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::unsigned_int_t);

    // The element type is spelled explicitly: CTAD on 1ul would yield
    // std::vector<unsigned long>, which differs from std::vector<uint64_t>
    // on platforms where uint64_t is unsigned long long.
    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({std::vector<uint64_t>{1, 2, 3}})) == 6);
    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({std::vector<uint64_t>{1, 2, 3, 4}})) == 10);

    REQUIRE_THROWS_WITH(i_embedded_c->apply({std::vector{1.2, 2.3, 3.1, 4.4}}),
                        "unexpected error: unexpected argument types while casting \"" +
                          demangle<std::vector<double>>() + "\" to \"" + demangle<std::vector<uint64_t>>() + '"');

    REQUIRE_THROWS_WITH(i_embedded_c->apply({std::vector<EmbeddedData>{}}),
                        "unexpected error: unexpected argument types while casting \"" +
                          demangle<std::vector<EmbeddedData>>() + "\" to \"" + demangle<std::vector<uint64_t>>() + '"');
  }

  SECTION("uint64_t(std::vector<EmbeddedData>) BuiltinFunctionEmbedder")
  {
    std::function sum = [](const std::vector<std::shared_ptr<uint64_t>>& x) -> uint64_t {
      uint64_t total = 0;
      for (const auto& value : x) {
        total += *value;
      }
      return total;
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<uint64_t(const std::vector<std::shared_ptr<uint64_t>>&)>>(sum);

    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::unsigned_int_t);

    std::vector<EmbeddedData> embedded_data;
    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({embedded_data})) == 0);
    embedded_data.emplace_back(std::make_shared<DataHandler<uint64_t>>(std::make_shared<uint64_t>(1)));
    embedded_data.emplace_back(std::make_shared<DataHandler<uint64_t>>(std::make_shared<uint64_t>(2)));
    embedded_data.emplace_back(std::make_shared<DataHandler<uint64_t>>(std::make_shared<uint64_t>(3)));

    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({embedded_data})) == 6);

    // One element of the wrong handler type makes the whole conversion fail.
    embedded_data.emplace_back(std::make_shared<DataHandler<double>>(std::make_shared<double>(4)));
    REQUIRE_THROWS_WITH(i_embedded_c->apply({embedded_data}),
                        "unexpected error: unexpected argument types while casting: invalid"
                        " EmbeddedData type, expecting " +
                          demangle<DataHandler<uint64_t>>());

    REQUIRE_THROWS_WITH(i_embedded_c->apply({TinyVector<1>{13}}),
                        "unexpected error: unexpected argument types while casting \"" + demangle<TinyVector<1>>() +
                          "\" to \"" + demangle<std::vector<std::shared_ptr<uint64_t>>>() + '"');
  }

  SECTION("double(void) BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> double { return 1.5; };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<BuiltinFunctionEmbedder<double(void)>>(c);

    REQUIRE(1.5 == c());
    REQUIRE(std::get<double>(i_embedded_c->apply(std::vector<DataVariant>{})) == c());
    REQUIRE(i_embedded_c->numberOfParameters() == 0);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 0);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t);
  }

  SECTION("R*R -> R*R^2*shared_double BuiltinFunctionEmbedder")
  {
    std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>, std::shared_ptr<double>> {
      return std::make_tuple(a + b, TinyVector<2>{b, -a}, std::make_shared<double>(a - b));
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<
      BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>, std::shared_ptr<double>>(double, double)>>(c);

    const double a = 3.2;
    const double b = 1.5;

    REQUIRE(a + b == std::get<0>(c(a, b)));
    REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b)));
    REQUIRE(a - b == *std::get<2>(c(a, b)));
    const AggregateDataVariant value_list =
      std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b}));

    REQUIRE(std::get<double>(value_list[0]) == a + b);
    REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a});

    // The shared_ptr<double> component is expected to come back as EmbeddedData.
    const EmbeddedData shared_result       = std::get<EmbeddedData>(value_list[2]);
    const DataHandler<double>& shared_data = dynamic_cast<const DataHandler<double>&>(shared_result.get());
    REQUIRE(*shared_data.data_ptr() == a - b);

    auto data_type = i_embedded_c->getReturnDataType();
    REQUIRE(data_type == ASTNodeDataType::list_t);

    REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t);
    REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2));
    REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>);
  }

  SECTION("void -> N*R*shared_double BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<double>> {
      uint64_t a = 1;
      double b   = 3.5;
      return std::make_tuple(a, b, std::make_shared<double>(a + b));
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double, std::shared_ptr<double>>(void)>>(c);

    REQUIRE(1ul == std::get<0>(c()));
    REQUIRE(3.5 == std::get<1>(c()));
    REQUIRE((1ul + 3.5) == *std::get<2>(c()));
    const AggregateDataVariant value_list =
      std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{}));

    REQUIRE(std::get<uint64_t>(value_list[0]) == 1ul);
    REQUIRE(std::get<double>(value_list[1]) == 3.5);

    // The shared_ptr<double> component is expected to come back as EmbeddedData.
    const EmbeddedData shared_result       = std::get<EmbeddedData>(value_list[2]);
    const DataHandler<double>& shared_data = dynamic_cast<const DataHandler<double>&>(shared_result.get());
    REQUIRE(*shared_data.data_ptr() == 4.5);

    auto data_type = i_embedded_c->getReturnDataType();
    REQUIRE(data_type == ASTNodeDataType::list_t);

    REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t);
    REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t);
    REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<double>>);
  }

  SECTION("void(void) BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> void {};

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<BuiltinFunctionEmbedder<void(void)>>(c);

    // A void callable yields an empty (std::monostate) DataVariant.
    REQUIRE_NOTHROW(std::get<std::monostate>(i_embedded_c->apply(std::vector<DataVariant>{})));
    REQUIRE(i_embedded_c->numberOfParameters() == 0);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 0);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::void_t);
  }

  SECTION("EmbeddedData(void) BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> std::shared_ptr<double> { return std::make_shared<double>(1.5); };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<std::shared_ptr<double>(void)>>(c);

    REQUIRE(i_embedded_c->numberOfParameters() == 0);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 0);

    REQUIRE(i_embedded_c->getReturnDataType() == ast_node_data_type_from<std::shared_ptr<double>>);

    const auto embedded_data         = std::get<EmbeddedData>(i_embedded_c->apply(std::vector<DataVariant>{}));
    const IDataHandler& handled_data = embedded_data.get();
    const DataHandler<double>& data  = dynamic_cast<const DataHandler<double>&>(handled_data);
    REQUIRE(*data.data_ptr() == 1.5);
  }

  SECTION("error")
  {
    std::function positive = [](double x) -> bool { return x >= 0; };

    BuiltinFunctionEmbedder<bool(double)> embedded_positive{positive};

    // A string argument cannot be converted to the expected double.
    std::string arg = std::string{"2.3"};

    std::vector<DataVariant> args;
    args.push_back(arg);

    REQUIRE_THROWS(embedded_positive.apply(args));
  }
}