#include <catch2/catch.hpp>

#include <CFunctionEmbedder.hpp>
#include <CFunctionEmbedderTable.hpp>

#include <CMathModule.hpp>

#include <cmath>
#include <cstdint>
#include <variant>

// Checks that every C math function registered by CMathModule is present in
// its name -> embedder map and that applying it yields exactly the value the
// corresponding <cmath> function returns for the same argument(s).
//
// Exact `==` comparisons are intentional: the module is expected to forward to
// the very same <cmath> functions, so results must be bit-identical.
//
// Note on Catch2 SECTIONs: the enclosing code is re-executed from the top for
// each leaf section, so a section that overrides `arg` (e.g. to move into the
// domain of acosh) does not affect its sibling sections.
TEST_CASE("CMathModule", "[language]")
{
  rang::setControlMode(rang::control::Off);

  CMathModule math_module;
  const auto& name_cfunction = math_module.getNameCFunctionsMap();

  // The 26 functions exercised below: 22 double -> double,
  // 2 double -> int64_t and 2 (double, double) -> double.
  REQUIRE(name_cfunction.size() == 26);

  // Returns the embedder registered under `name`; fails the test case if the
  // name is not registered (REQUIRE aborts before the dereference).
  auto get_embedder = [&name_cfunction](const char* name) -> ICFunctionEmbedder& {
    auto i_function = name_cfunction.find(name);
    REQUIRE(i_function != name_cfunction.end());
    return *i_function->second;
  };

  SECTION("double -> double")
  {
    double arg = 0.7;

    ASTNodeDataVariant arg_variant = arg;
    ASTNodeDataVariant result_variant;

    SECTION("sqrt")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("sqrt");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::sqrt(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("abs")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("abs");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::abs(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);

      // Also check a negative argument. Reassign (as the sinh/acosh sections
      // do) rather than redeclare, so the outer `arg`/`arg_variant` are not
      // shadowed.
      arg         = -3;
      arg_variant = arg;

      function_embedder.apply({arg_variant}, result_variant);

      result = std::abs(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("sin")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("sin");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::sin(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("cos")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("cos");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::cos(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("tan")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("tan");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::tan(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("asin")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("asin");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::asin(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("acos")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("acos");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::acos(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("atan")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("atan");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::atan(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("sinh")
    {
      arg         = 1.3;
      arg_variant = arg;

      ICFunctionEmbedder& function_embedder = get_embedder("sinh");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::sinh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("cosh")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("cosh");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::cosh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("tanh")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("tanh");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::tanh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("asinh")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("asinh");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::asinh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("acosh")
    {
      // acosh is only defined for arguments >= 1.
      arg         = 10;
      arg_variant = arg;

      ICFunctionEmbedder& function_embedder = get_embedder("acosh");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::acosh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("atanh")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("atanh");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::atanh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("exp")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("exp");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::exp(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("log")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("log");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::log(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("nearbyint")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("nearbyint");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::nearbyint(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("ceil")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("ceil");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::ceil(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("floor")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("floor");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::floor(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("trunc")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("trunc");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::trunc(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("round")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("round");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::round(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("rint")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("rint");

      function_embedder.apply({arg_variant}, result_variant);

      auto result = std::rint(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }
  }

  SECTION("double -> int64_t")
  {
    double arg = 1.3;

    ASTNodeDataVariant arg_variant = arg;
    ASTNodeDataVariant result_variant;

    // std::lround / std::lrint return `long`, whose width is platform dependent
    // (32 bits on LLP64 targets such as Windows). These functions are meant to
    // produce an int64_t, so extract that type explicitly instead of relying on
    // decltype(std::lround(...)) happening to be int64_t.

    SECTION("lround")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("lround");

      function_embedder.apply({arg_variant}, result_variant);

      const int64_t result = std::lround(arg);
      REQUIRE(std::get<int64_t>(result_variant) == result);
    }

    SECTION("lrint")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("lrint");

      function_embedder.apply({arg_variant}, result_variant);

      const int64_t result = std::lrint(arg);
      REQUIRE(std::get<int64_t>(result_variant) == result);
    }
  }

  SECTION("(double, double) -> double")
  {
    double arg0 = 3;
    double arg1 = 2;

    ASTNodeDataVariant arg0_variant = arg0;
    ASTNodeDataVariant arg1_variant = arg1;
    ASTNodeDataVariant result_variant;

    SECTION("atan2")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("atan2");

      function_embedder.apply({arg0_variant, arg1_variant}, result_variant);

      auto result = std::atan2(arg0, arg1);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }

    SECTION("pow")
    {
      ICFunctionEmbedder& function_embedder = get_embedder("pow");

      function_embedder.apply({arg0_variant, arg1_variant}, result_variant);

      auto result = std::pow(arg0, arg1);
      REQUIRE(std::get<decltype(result)>(result_variant) == result);
    }
  }
}
