#include <catch2/catch.hpp>

#include <BuiltinFunctionEmbedder.hpp>
#include <EmbedderTable.hpp>

#include <MathModule.hpp>

#include <cmath>
#include <cstdint>

// Checks that every builtin registered by MathModule is present and that
// applying it through the type-erased IBuiltinFunctionEmbedder interface
// yields the same value as calling the corresponding <cmath> function directly.
//
// Note: Catch2 re-runs the enclosing TEST_CASE from the top for each leaf
// SECTION, so a section that reassigns `arg` / `arg_variant` does not affect
// its sibling sections.
TEST_CASE("MathModule", "[language]")
{
  rang::setControlMode(rang::control::Off);

  MathModule math_module;
  const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap();

  // 16 double -> double, 4 double -> int64_t and 2 (double, double) -> double
  REQUIRE(name_builtin_function.size() == 22);

  SECTION("double -> double")
  {
    double arg = 0.7;

    DataVariant arg_variant = arg;

    SECTION("sqrt")
    {
      auto i_function = name_builtin_function.find("sqrt");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::sqrt(arg);

      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("abs")
    {
      auto i_function = name_builtin_function.find("abs");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;

      // positive argument
      {
        DataVariant result_variant = function_embedder.apply({arg_variant});

        auto result = std::abs(arg);

        REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
      }

      // negative argument: use dedicated variables instead of shadowing the
      // outer arg_variant
      {
        const double negative_arg = -3;

        DataVariant negative_arg_variant = negative_arg;

        DataVariant result_variant = function_embedder.apply({negative_arg_variant});

        auto result = std::abs(negative_arg);

        REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
      }
    }

    SECTION("sin")
    {
      auto i_function = name_builtin_function.find("sin");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::sin(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("cos")
    {
      auto i_function = name_builtin_function.find("cos");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::cos(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("tan")
    {
      auto i_function = name_builtin_function.find("tan");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::tan(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("asin")
    {
      auto i_function = name_builtin_function.find("asin");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::asin(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("acos")
    {
      auto i_function = name_builtin_function.find("acos");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::acos(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("atan")
    {
      auto i_function = name_builtin_function.find("atan");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::atan(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("sinh")
    {
      // use a different input than the shared default for this section only
      arg         = 1.3;
      arg_variant = arg;

      auto i_function = name_builtin_function.find("sinh");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::sinh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("cosh")
    {
      auto i_function = name_builtin_function.find("cosh");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::cosh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("tanh")
    {
      auto i_function = name_builtin_function.find("tanh");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::tanh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("asinh")
    {
      auto i_function = name_builtin_function.find("asinh");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::asinh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("acosh")
    {
      // acosh is only defined for arguments >= 1
      arg         = 10;
      arg_variant = arg;

      auto i_function = name_builtin_function.find("acosh");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::acosh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("atanh")
    {
      auto i_function = name_builtin_function.find("atanh");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::atanh(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("exp")
    {
      auto i_function = name_builtin_function.find("exp");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::exp(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("log")
    {
      auto i_function = name_builtin_function.find("log");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      auto result = std::log(arg);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }
  }

  SECTION("double -> int64_t")
  {
    double arg = 1.3;

    DataVariant arg_variant = arg;

    // Expected values are converted explicitly so the test states the
    // double -> int64_t conversion instead of relying on implicit narrowing.

    SECTION("ceil")
    {
      auto i_function = name_builtin_function.find("ceil");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      const auto result = static_cast<int64_t>(std::ceil(arg));
      REQUIRE(std::get<int64_t>(result_variant) == result);
    }

    SECTION("floor")
    {
      auto i_function = name_builtin_function.find("floor");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      const auto result = static_cast<int64_t>(std::floor(arg));
      REQUIRE(std::get<int64_t>(result_variant) == result);
    }

    SECTION("trunc")
    {
      auto i_function = name_builtin_function.find("trunc");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      const auto result = static_cast<int64_t>(std::trunc(arg));
      REQUIRE(std::get<int64_t>(result_variant) == result);
    }

    SECTION("round")
    {
      auto i_function = name_builtin_function.find("round");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg_variant});

      // std::lround returns long; widen it explicitly to int64_t
      const auto result = static_cast<int64_t>(std::lround(arg));
      REQUIRE(std::get<int64_t>(result_variant) == result);
    }
  }

  SECTION("(double, double) -> double")
  {
    double arg0 = 3;
    double arg1 = 2;

    DataVariant arg0_variant = arg0;
    DataVariant arg1_variant = arg1;

    SECTION("atan2")
    {
      auto i_function = name_builtin_function.find("atan2");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});

      auto result = std::atan2(arg0, arg1);
      // floating-point result: compare with a tolerance like every other
      // double-valued check in this test
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }

    SECTION("pow")
    {
      auto i_function = name_builtin_function.find("pow");
      REQUIRE(i_function != name_builtin_function.end());

      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});

      auto result = std::pow(arg0, arg1);
      REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
    }
  }
}