From a95943c774555673bcfb5fded41c1a5a131dad4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com> Date: Wed, 1 Sep 2021 17:50:54 +0200 Subject: [PATCH] Add min and max functions to the language Actually both of them is twice defined: - min : Z*Z-> Z - min : R*R-> R - max : Z*Z-> Z - max : R*R-> R --- src/language/modules/MathModule.cpp | 12 +++++ tests/test_BuiltinFunctionProcessor.cpp | 36 +++++++++++++++ tests/test_MathModule.cpp | 59 ++++++++++++++++++++++++- 3 files changed, 106 insertions(+), 1 deletion(-) diff --git a/src/language/modules/MathModule.cpp b/src/language/modules/MathModule.cpp index 69f6890fb..2217424e4 100644 --- a/src/language/modules/MathModule.cpp +++ b/src/language/modules/MathModule.cpp @@ -70,6 +70,18 @@ MathModule::MathModule() this->_addBuiltinFunction("round", std::make_shared<BuiltinFunctionEmbedder<int64_t(double)>>( [](double x) -> int64_t { return std::lround(x); })); + this->_addBuiltinFunction("min", std::make_shared<BuiltinFunctionEmbedder<double(double, double)>>( + [](double x, double y) -> double { return std::min(x, y); })); + + this->_addBuiltinFunction("min", std::make_shared<BuiltinFunctionEmbedder<int64_t(int64_t, int64_t)>>( + [](int64_t x, int64_t y) -> int64_t { return std::min(x, y); })); + + this->_addBuiltinFunction("max", std::make_shared<BuiltinFunctionEmbedder<double(double, double)>>( + [](double x, double y) -> double { return std::max(x, y); })); + + this->_addBuiltinFunction("max", std::make_shared<BuiltinFunctionEmbedder<int64_t(int64_t, int64_t)>>( + [](int64_t x, int64_t y) -> int64_t { return std::max(x, y); })); + this->_addBuiltinFunction("dot", std::make_shared<BuiltinFunctionEmbedder<double(const TinyVector<1>, const TinyVector<1>)>>( [](const TinyVector<1> x, const TinyVector<1> y) -> double { return dot(x, y); })); diff --git a/tests/test_BuiltinFunctionProcessor.cpp b/tests/test_BuiltinFunctionProcessor.cpp index a1fb3fe82..bbc5ac52e 100644 --- a/tests/test_BuiltinFunctionProcessor.cpp +++ b/tests/test_BuiltinFunctionProcessor.cpp @@ -278,6 +278,42 @@ let z:Z, z = round(-1.2); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "z", int64_t{-1}); } + { // min + tested_function_set.insert("min:R*R"); + std::string_view data = R"( +import math; +let x:R, x = min(-2,2.3); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "x", double{-2}); + } + + { // min + tested_function_set.insert("min:Z*Z"); + std::string_view data = R"( +import math; +let z:Z, z = min(-1,2); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "z", int64_t{-1}); + } + + { // max + tested_function_set.insert("max:R*R"); + std::string_view data = R"( +import math; +let x:R, x = max(-1,2.3); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "x", double{2.3}); + } + + { // max + tested_function_set.insert("max:Z*Z"); + std::string_view data = R"( +import math; +let z:Z, z = max(-1,2); +)"; + CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "z", int64_t{2}); + } + { // dot tested_function_set.insert("dot:R^1*R^1"); std::string_view data = R"( diff --git a/tests/test_MathModule.cpp b/tests/test_MathModule.cpp index 7c2a9bc54..4d71927f1 100644 --- a/tests/test_MathModule.cpp +++ b/tests/test_MathModule.cpp @@ -15,7 +15,7 @@ TEST_CASE("MathModule", "[language]") MathModule math_module; const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap(); - REQUIRE(name_builtin_function.size() == 25); + REQUIRE(name_builtin_function.size() == 29); SECTION("double -> double") { @@ -323,6 +323,63 @@ TEST_CASE("MathModule", "[language]") auto result = std::pow(arg0, arg1); REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result)); } + + SECTION("min") + { + auto i_function = name_builtin_function.find("min:R*R"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + auto result = std::min(arg0, arg1); + REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result)); + } + + SECTION("max") + { + auto i_function = name_builtin_function.find("max:R*R"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + auto result = std::max(arg0, arg1); + REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result)); + } + } + + SECTION("(uint64_t, uint64_t) -> uint64_t") + { + int64_t arg0 = 3; + int64_t arg1 = -2; + + DataVariant arg0_variant = arg0; + DataVariant arg1_variant = arg1; + + SECTION("min") + { + auto i_function = name_builtin_function.find("min:Z*Z"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + auto result = std::min(arg0, arg1); + REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result)); + } + + SECTION("max") + { + auto i_function = name_builtin_function.find("max:Z*Z"); + REQUIRE(i_function != name_builtin_function.end()); + + IBuiltinFunctionEmbedder& function_embedder = *i_function->second; + DataVariant result_variant = function_embedder.apply({arg0_variant, arg1_variant}); + + auto result = std::max(arg0, arg1); + REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result)); + } } SECTION("(R^d, R^d) -> double") -- GitLab