#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_all.hpp>

#include <language/utils/BuiltinFunctionEmbedder.hpp>

// clazy:excludeall=non-pod-global-static

// Give std::shared_ptr<const double> a language data type (an opaque type_id named
// "shared_const_double") so that the tests below can embed C++ functions taking or
// returning it; such values are exchanged through apply() as EmbeddedData. Types lacking
// such a specialization are rejected at binding time (see the "error" sections).
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const double>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_const_double");

// Give std::shared_ptr<const uint64_t> a language data type (an opaque type_id named
// "shared_const_uint64_t"); the tests below use it as the element type of tuples of
// embedded data, which apply() exchanges as std::vector<EmbeddedData>.
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const uint64_t>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_const_uint64_t");

// Unit tests for BuiltinFunctionEmbedder: wrapping C++ callables so that they can be
// invoked through apply() with DataVariant arguments, checking the language data types
// reported for their parameters and return values, and checking the errors raised for
// invalid arguments or for C++ signatures that have no language counterpart.
TEST_CASE("BuiltinFunctionEmbedder", "[language]")
{
  // Turn rang's terminal styling off -- presumably so that the exception messages compared
  // below are free of color escape sequences.
  rang::setControlMode(rang::control::Off);

  SECTION("math")
  {
    BuiltinFunctionEmbedder<double(double)> embedded_sin{[](double x) -> double { return std::sin(x); }};

    double arg              = 2;
    DataVariant arg_variant = arg;

    DataVariant result = embedded_sin.apply({arg_variant});

    REQUIRE(std::get<double>(result) == std::sin(arg));
    REQUIRE(embedded_sin.numberOfParameters() == 1);

    REQUIRE(embedded_sin.getReturnDataType() == ASTNodeDataType::double_t);
    REQUIRE(embedded_sin.getParameterDataTypes()[0] == ASTNodeDataType::double_t);
  }

  SECTION("multiple variant args")
  {
    std::function c = [](double x, uint64_t i) -> bool { return x > i; };

    BuiltinFunctionEmbedder<bool(double, uint64_t)> embedded_c{c};

    double d_arg   = 2.3;
    uint64_t i_arg = 3;

    std::vector<DataVariant> args;
    args.push_back(d_arg);
    args.push_back(i_arg);

    DataVariant result = embedded_c.apply(args);

    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
    REQUIRE(embedded_c.numberOfParameters() == 2);

    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::bool_t);
    REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
  }

  SECTION("R*R^2 -> R^2")
  {
    std::function c = [](double a, TinyVector<2> x) -> TinyVector<2> { return a * x; };

    BuiltinFunctionEmbedder<TinyVector<2>(double, TinyVector<2>)> embedded_c{c};

    double a_arg = 2.3;
    TinyVector<2> x_arg{3, 2};

    std::vector<DataVariant> args;
    args.push_back(a_arg);
    args.push_back(x_arg);

    DataVariant result = embedded_c.apply(args);

    REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg));
    REQUIRE(embedded_c.numberOfParameters() == 2);

    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t);
    REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t);
  }

  SECTION("R^2x2*R^2 -> R^2")
  {
    std::function c = [](TinyMatrix<2> A, TinyVector<2> x) -> TinyVector<2> { return A * x; };

    BuiltinFunctionEmbedder<TinyVector<2>(TinyMatrix<2>, TinyVector<2>)> embedded_c{c};

    TinyMatrix<2> a_arg{2.3, 1, -2, 3};
    TinyVector<2> x_arg{3, 2};

    std::vector<DataVariant> args;
    args.push_back(a_arg);
    args.push_back(x_arg);

    DataVariant result = embedded_c.apply(args);

    REQUIRE(std::get<TinyVector<2>>(result) == c(a_arg, x_arg));
    REQUIRE(embedded_c.numberOfParameters() == 2);

    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::vector_t);
    REQUIRE(embedded_c.getParameterDataTypes()[0] == ASTNodeDataType::matrix_t);
    REQUIRE(embedded_c.getParameterDataTypes()[1] == ASTNodeDataType::vector_t);
  }

  SECTION("POD BuiltinFunctionEmbedder")
  {
    std::function c = [](double x, uint64_t i) -> bool { return x > i; };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<bool(double, uint64_t)>>(c);

    double d_arg   = 2.3;
    uint64_t i_arg = 3;

    std::vector<DataVariant> args;
    args.push_back(d_arg);
    args.push_back(i_arg);

    DataVariant result = i_embedded_c->apply(args);

    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
    REQUIRE(i_embedded_c->numberOfParameters() == 2);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::bool_t);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(i_embedded_c->getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
  }

  SECTION("void(double) BuiltinFunctionEmbedder")
  {
    double y = 1;

    // Captures y by reference on purpose: the side effect is what is being tested.
    std::function add_to_y = [&](double x) -> void { y += x; };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<void(double)>>(add_to_y);

    double x = 0.5;
    i_embedded_c->apply(std::vector<DataVariant>{x});
    REQUIRE(y == 1.5);
    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::void_t);

    // A std::vector<EmbeddedData> argument cannot be converted to the double parameter.
    REQUIRE_THROWS_WITH(i_embedded_c->apply({std::vector<EmbeddedData>{}}),
                        "unexpected error: unexpected argument types while casting \"" +
                          demangle<std::vector<EmbeddedData>>() + "\" to \"" + demangle<double>() + '"');
  }

  SECTION("EmbeddedData(double, double) BuiltinFunctionEmbedder")
  {
    std::function sum = [](double x, double y) -> std::shared_ptr<const double> {
      return std::make_shared<const double>(x + y);
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<std::shared_ptr<const double>(double, double)>>(sum);

    REQUIRE(i_embedded_c->numberOfParameters() == 2);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 2);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(i_embedded_c->getParameterDataTypes()[1] == ASTNodeDataType::double_t);

    REQUIRE(i_embedded_c->getReturnDataType() == ast_node_data_type_from<std::shared_ptr<const double>>);

    // Passing an unsigned integer for the second (double) parameter enforces the argument
    // cast path of apply(). uint64_t is spelled out because 4ul is not uint64_t on every
    // platform (e.g. where uint64_t is unsigned long long).
    DataVariant result         = i_embedded_c->apply(std::vector<DataVariant>{2.3, uint64_t{4}});
    EmbeddedData embedded_data = std::get<EmbeddedData>(result);

    const IDataHandler& handled_data      = embedded_data.get();
    const DataHandler<const double>& data = dynamic_cast<const DataHandler<const double>&>(handled_data);
    REQUIRE(*data.data_ptr() == (2.3 + 4));
  }

  SECTION("double(std::shared_ptr<const double>) BuiltinFunctionEmbedder")
  {
    std::function abs = [](std::shared_ptr<const double> x) -> double { return std::abs(*x); };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<double(std::shared_ptr<const double>)>>(abs);

    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ast_node_data_type_from<std::shared_ptr<const double>>);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t);

    EmbeddedData data(std::make_shared<DataHandler<const double>>(std::make_shared<const double>(-2.3)));
    DataVariant result = i_embedded_c->apply({data});

    REQUIRE(std::get<double>(result) == 2.3);

    // A plain double cannot stand in for an EmbeddedData parameter.
    REQUIRE_THROWS_WITH(i_embedded_c->apply({2.3}),
                        "unexpected error: unexpected argument types while casting: expecting EmbeddedData");

    // An EmbeddedData holding another C++ type is rejected as well.
    EmbeddedData wrong_embedded_data_type(
      std::make_shared<DataHandler<const TinyVector<2>>>(std::make_shared<const TinyVector<2>>(-2, -2)));
    REQUIRE_THROWS_WITH(i_embedded_c->apply({wrong_embedded_data_type}),
                        "unexpected error: unexpected argument types while casting: "
                        "invalid EmbeddedData type, expecting " +
                          demangle<DataHandler<const double>>());
  }

  SECTION("uint64_t(std::vector<uint64_t>) BuiltinFunctionEmbedder")
  {
    std::function sum = [](const std::vector<uint64_t>& x) -> uint64_t {
      uint64_t total = 0;
      for (const uint64_t value : x) {
        total += value;
      }
      return total;
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<uint64_t(const std::vector<uint64_t>&)>>(sum);

    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::unsigned_int_t);

    // The element type is spelled out: std::vector{1ul, ...} would be std::vector<unsigned long>,
    // which is not std::vector<uint64_t> on every platform.
    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({std::vector<uint64_t>{1, 2, 3}})) == 6);
    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({std::vector<uint64_t>{1, 2, 3, 4}})) == 10);

    REQUIRE_THROWS_WITH(i_embedded_c->apply({std::vector{1.2, 2.3, 3.1, 4.4}}),
                        "unexpected error: unexpected argument types while casting \"" +
                          demangle<std::vector<double>>() + "\" to \"" + demangle<std::vector<uint64_t>>() + '"');

    REQUIRE_THROWS_WITH(i_embedded_c->apply({std::vector<EmbeddedData>{}}),
                        "unexpected error: unexpected argument types while casting \"" +
                          demangle<std::vector<EmbeddedData>>() + "\" to \"" + demangle<std::vector<uint64_t>>() + '"');
  }

  SECTION("uint64_t(std::vector<EmbeddedData>) BuiltinFunctionEmbedder")
  {
    std::function sum = [](const std::vector<std::shared_ptr<const uint64_t>>& x) -> uint64_t {
      uint64_t total = 0;
      for (const auto& value : x) {
        total += *value;
      }
      return total;
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<uint64_t(const std::vector<std::shared_ptr<const uint64_t>>&)>>(sum);

    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::unsigned_int_t);

    std::vector<EmbeddedData> embedded_data;
    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({embedded_data})) == 0);
    embedded_data.emplace_back(std::make_shared<DataHandler<const uint64_t>>(std::make_shared<const uint64_t>(1)));
    embedded_data.emplace_back(std::make_shared<DataHandler<const uint64_t>>(std::make_shared<const uint64_t>(2)));
    embedded_data.emplace_back(std::make_shared<DataHandler<const uint64_t>>(std::make_shared<const uint64_t>(3)));

    REQUIRE(std::get<uint64_t>(i_embedded_c->apply({embedded_data})) == 6);

    // A single element of the wrong embedded type makes the whole tuple argument invalid.
    embedded_data.emplace_back(std::make_shared<DataHandler<const double>>(std::make_shared<const double>(4)));
    REQUIRE_THROWS_WITH(i_embedded_c->apply({embedded_data}),
                        "unexpected error: unexpected argument types while casting: invalid"
                        " EmbeddedData type, expecting " +
                          demangle<DataHandler<const uint64_t>>());

    REQUIRE_THROWS_WITH(i_embedded_c->apply({TinyVector<1>{13}}),
                        "unexpected error: unexpected argument types while casting \"" + demangle<TinyVector<1>>() +
                          "\" to \"" + demangle<std::vector<std::shared_ptr<const uint64_t>>>() + '"');
  }

  SECTION("double(void) BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> double { return 1.5; };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<BuiltinFunctionEmbedder<double(void)>>(c);

    REQUIRE(1.5 == c());
    REQUIRE(std::get<double>(i_embedded_c->apply(std::vector<DataVariant>{})) == c());
    REQUIRE(i_embedded_c->numberOfParameters() == 0);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 0);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::double_t);
  }

  SECTION("R*R -> R*R^2*shared_const_double BuiltinFunctionEmbedder")
  {
    std::function c = [](double a, double b) -> std::tuple<double, TinyVector<2>, std::shared_ptr<const double>> {
      return std::make_tuple(a + b, TinyVector<2>{b, -a}, std::make_shared<const double>(a - b));
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<
      BuiltinFunctionEmbedder<std::tuple<double, TinyVector<2>, std::shared_ptr<const double>>(double, double)>>(c);

    const double a = 3.2;
    const double b = 1.5;

    REQUIRE(a + b == std::get<0>(c(a, b)));
    REQUIRE(TinyVector<2>{b, -a} == std::get<1>(c(a, b)));
    REQUIRE(a - b == *std::get<2>(c(a, b)));
    const AggregateDataVariant value_list =
      std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{a, b}));

    REQUIRE(std::get<double>(value_list[0]) == a + b);
    REQUIRE(std::get<TinyVector<2>>(value_list[1]) == TinyVector<2>{b, -a});
    auto data_type = i_embedded_c->getReturnDataType();
    REQUIRE(data_type == ASTNodeDataType::list_t);

    REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::double_t);
    REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::build<ASTNodeDataType::vector_t>(2));
    REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<const double>>);
  }

  SECTION("(R^2) -> (R) BuiltinFunctionEmbedder")
  {
    std::function c = [](const std::vector<TinyVector<2>>& x) -> std::vector<double> {
      std::vector<double> sum(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        sum[i] = x[i][0] + x[i][1];
      }
      return sum;
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<std::vector<double>(const std::vector<TinyVector<2>>&)>>(c);

    REQUIRE(i_embedded_c->numberOfParameters() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 1);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::tuple_t);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::tuple_t);

    std::vector<TinyVector<2>> x = {TinyVector<2>{1, 2}, TinyVector<2>{3, 4}, TinyVector<2>{-2, 4}};
    std::vector<double> result   = c(x);

    const std::vector value = std::get<std::vector<double>>(i_embedded_c->apply(std::vector<DataVariant>{x}));

    // Whole-vector comparison also checks the size, so a short result cannot be read out of bounds.
    REQUIRE(value == result);
  }

  SECTION("std::vector<EmbeddedData>(EmbeddedData, EmbeddedData) BuiltinFunctionEmbedder")
  {
    std::function sum = [](const uint64_t& i, const uint64_t& j) -> std::vector<std::shared_ptr<const uint64_t>> {
      std::vector<std::shared_ptr<const uint64_t>> x = {std::make_shared<const uint64_t>(i),
                                                        std::make_shared<const uint64_t>(j)};
      return x;
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<
      BuiltinFunctionEmbedder<std::vector<std::shared_ptr<const uint64_t>>(const uint64_t&, const uint64_t&)>>(sum);

    REQUIRE(i_embedded_c->numberOfParameters() == 2);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 2);
    REQUIRE(i_embedded_c->getParameterDataTypes()[0] == ASTNodeDataType::unsigned_int_t);
    REQUIRE(i_embedded_c->getParameterDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::tuple_t);

    std::vector<uint64_t> values{3, 4};

    std::vector<EmbeddedData> embedded_data_list =
      std::get<std::vector<EmbeddedData>>(i_embedded_c->apply(std::vector<DataVariant>{values[0], values[1]}));

    // Guard the indexed comparison below against a result of unexpected length.
    REQUIRE(embedded_data_list.size() == values.size());
    for (size_t i = 0; i < embedded_data_list.size(); ++i) {
      const EmbeddedData& embedded_data       = embedded_data_list[i];
      const IDataHandler& handled_data        = embedded_data.get();
      const DataHandler<const uint64_t>& data = dynamic_cast<const DataHandler<const uint64_t>&>(handled_data);
      REQUIRE(*data.data_ptr() == values[i]);
    }
  }

  SECTION("void -> N*R*shared_const_double BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> std::tuple<uint64_t, double, std::shared_ptr<const double>> {
      uint64_t a = 1;
      double b   = 3.5;
      return std::make_tuple(a, b, std::make_shared<const double>(a + b));
    };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<std::tuple<uint64_t, double, std::shared_ptr<const double>>(void)>>(c);

    REQUIRE(1ul == std::get<0>(c()));
    REQUIRE(3.5 == std::get<1>(c()));
    REQUIRE((1ul + 3.5) == *std::get<2>(c()));
    const AggregateDataVariant value_list =
      std::get<AggregateDataVariant>(i_embedded_c->apply(std::vector<DataVariant>{}));

    REQUIRE(std::get<uint64_t>(value_list[0]) == 1ul);
    REQUIRE(std::get<double>(value_list[1]) == 3.5);

    auto data_type = i_embedded_c->getReturnDataType();
    REQUIRE(data_type == ASTNodeDataType::list_t);

    REQUIRE(*data_type.contentTypeList()[0] == ASTNodeDataType::unsigned_int_t);
    REQUIRE(*data_type.contentTypeList()[1] == ASTNodeDataType::double_t);
    REQUIRE(*data_type.contentTypeList()[2] == ast_node_data_type_from<std::shared_ptr<const double>>);
  }

  SECTION("void(void) BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> void {};

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c = std::make_unique<BuiltinFunctionEmbedder<void(void)>>(c);

    REQUIRE_NOTHROW(std::get<std::monostate>(i_embedded_c->apply(std::vector<DataVariant>{})));
    REQUIRE(i_embedded_c->numberOfParameters() == 0);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 0);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::void_t);
  }

  SECTION("EmbeddedData(void) BuiltinFunctionEmbedder")
  {
    std::function c = [](void) -> std::shared_ptr<const double> { return std::make_shared<const double>(1.5); };

    std::unique_ptr<IBuiltinFunctionEmbedder> i_embedded_c =
      std::make_unique<BuiltinFunctionEmbedder<std::shared_ptr<const double>(void)>>(c);

    REQUIRE(i_embedded_c->numberOfParameters() == 0);
    REQUIRE(i_embedded_c->getParameterDataTypes().size() == 0);

    REQUIRE(i_embedded_c->getReturnDataType() == ast_node_data_type_from<std::shared_ptr<const double>>);

    const auto embedded_data              = std::get<EmbeddedData>(i_embedded_c->apply(std::vector<DataVariant>{}));
    const IDataHandler& handled_data      = embedded_data.get();
    const DataHandler<const double>& data = dynamic_cast<const DataHandler<const double>&>(handled_data);
    REQUIRE(*data.data_ptr() == 1.5);
  }

  SECTION("error")
  {
    SECTION("invalid cast")
    {
      std::function positive = [](double x) -> bool { return x >= 0; };

      BuiltinFunctionEmbedder<bool(double)> embedded_positive{positive};

      std::string arg = std::string{"2.3"};

      std::vector<DataVariant> args;
      args.push_back(arg);

      std::ostringstream error_msg;
      error_msg << "unexpected error: unexpected argument types while casting \"" << demangle<std::string>()
                << "\" to \"" << demangle<double>() << "\"";

      REQUIRE_THROWS_WITH(embedded_positive.apply(args), error_msg.str());
    }

    SECTION("invalid arg type")
    {
      std::ostringstream error_msg;
      error_msg << "cannot bind C++ to language.\n"
                << "note: argument number 1 has no associated language type: ";
      error_msg << demangle<int>();
      REQUIRE_THROWS_WITH(BuiltinFunctionEmbedder<bool(int)>{[](int) -> bool { return true; }}, error_msg.str());
    }

    // Distinct name from the previous section so reports and section filters can tell them apart.
    SECTION("invalid second arg type")
    {
      std::ostringstream error_msg;
      error_msg << "cannot bind C++ to language.\n"
                << "note: argument number 2 has no associated language type: ";
      error_msg << demangle<int>();
      // The lambda matches the declared signature exactly: only the second parameter (int)
      // lacks a language type.
      REQUIRE_THROWS_WITH(BuiltinFunctionEmbedder<bool(double, int)>{[](double, int) -> bool { return true; }},
                          error_msg.str());
    }

    SECTION("invalid return type")
    {
      std::ostringstream error_msg;
      error_msg << "cannot bind C++ to language.\n"
                << "note: return value has no associated language type: ";
      error_msg << demangle<std::shared_ptr<const int>>();

      std::function f = [](double) -> std::shared_ptr<const int> { return std::make_shared<const int>(1); };

      REQUIRE_THROWS_WITH(BuiltinFunctionEmbedder<std::shared_ptr<const int>(double)>(f), error_msg.str());
    }

    SECTION("invalid return type in compound")
    {
      std::ostringstream error_msg;
      error_msg << "cannot bind C++ to language.\n"
                << "note: return value number 2 has no associated language type: ";
      error_msg << demangle<std::shared_ptr<const int>>();

      std::function f = [](double) -> std::tuple<double, std::shared_ptr<const int>> {
        return std::make_tuple(double{1.3}, std::make_shared<const int>(1));
      };

      // Extra parentheses keep the comma inside the template argument list from splitting
      // the macro arguments.
      REQUIRE_THROWS_WITH((BuiltinFunctionEmbedder<std::tuple<double, std::shared_ptr<const int>>(double)>(f)),
                          error_msg.str());
    }
  }
}