#include <catch2/catch.hpp>

#include <CFunctionEmbedder.hpp>

#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Tests for CFunctionEmbedder: a std::function is wrapped so it can be invoked
// through apply() with ASTNodeDataVariant arguments, and the wrapper reports
// its argument count, argument data types and return data type. The same
// checks are also run through the ICFunctionEmbedder interface pointer.
TEST_CASE("CFunctionEmbedder", "[language]")
{
  // NOTE(review): presumably turns off rang terminal control/colour output so
  // any diagnostics are plain text — confirm against rang's documentation.
  rang::setControlMode(rang::control::Off);

  SECTION("math")
  {
    // Single double -> double function.
    CFunctionEmbedder<double, double> embedded_sin{
      std::function<double(double)>{[](double x) -> double { return std::sin(x); }}};

    double arg                     = 2;
    ASTNodeDataVariant arg_variant = arg;

    ASTNodeDataVariant result;

    embedded_sin.apply({arg_variant}, result);

    // Both sides evaluate std::sin on the same double, so an exact match is
    // expected: the embedder must forward the argument and result unchanged.
    REQUIRE(std::get<double>(result) == std::sin(arg));
    REQUIRE(embedded_sin.numberOfArguments() == 1);

    REQUIRE(embedded_sin.getReturnDataType() == ASTNodeDataType::double_t);
    REQUIRE(embedded_sin.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
  }

  SECTION("multiple variant args")
  {
    // Two arguments of different types; the result is a bool.
    std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };

    CFunctionEmbedder<bool, double, uint64_t> embedded_c{c};

    double d_arg   = 2.3;
    uint64_t i_arg = 3;

    std::vector<ASTNodeDataVariant> args;
    args.push_back(d_arg);
    args.push_back(i_arg);

    ASTNodeDataVariant result;

    embedded_c.apply(args, result);

    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
    REQUIRE(embedded_c.numberOfArguments() == 2);

    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::bool_t);
    REQUIRE(embedded_c.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(embedded_c.getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);

    // The call above yields false (2.3 > 3 is false). Also exercise an input
    // for which the predicate holds, so a result stuck at false is detected.
    double d_true_arg   = 4.5;
    uint64_t i_true_arg = 3;

    std::vector<ASTNodeDataVariant> true_args;
    true_args.push_back(d_true_arg);
    true_args.push_back(i_true_arg);

    ASTNodeDataVariant true_result;

    embedded_c.apply(true_args, true_result);

    REQUIRE(std::get<bool>(true_result) == c(d_true_arg, i_true_arg));
    REQUIRE(std::get<bool>(true_result));
  }

  SECTION("ICFunctionEmbedder")
  {
    // Same function as above, but used only through the abstract interface.
    std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };

    std::unique_ptr<ICFunctionEmbedder> i_embedded_c = std::make_unique<CFunctionEmbedder<bool, double, uint64_t>>(c);

    double d_arg   = 2.3;
    uint64_t i_arg = 3;

    std::vector<ASTNodeDataVariant> args;
    args.push_back(d_arg);
    args.push_back(i_arg);

    ASTNodeDataVariant result;

    i_embedded_c->apply(args, result);

    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
    REQUIRE(i_embedded_c->numberOfArguments() == 2);

    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::bool_t);
    REQUIRE(i_embedded_c->getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
    REQUIRE(i_embedded_c->getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
  }

  SECTION("error")
  {
    std::function<bool(double)> positive = [](double x) -> bool { return x >= 0; };

    CFunctionEmbedder<bool, double> embedded_positive{positive};

    // A std::string argument where a double is expected: apply() must throw
    // rather than silently convert.
    std::string arg = std::string{"2.3"};

    std::vector<ASTNodeDataVariant> args;
    args.push_back(arg);

    ASTNodeDataVariant result;

    REQUIRE_THROWS(embedded_positive.apply(args, result));
  }
}
