From a95943c774555673bcfb5fded41c1a5a131dad4e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Wed, 1 Sep 2021 17:50:54 +0200
Subject: [PATCH] Add min and max functions to the language

Actually both of them is twice defined:
- min : Z*Z-> Z
- min : R*R-> R
- max : Z*Z-> Z
- max : R*R-> R
---
 src/language/modules/MathModule.cpp     | 12 +++++
 tests/test_BuiltinFunctionProcessor.cpp | 36 +++++++++++++++
 tests/test_MathModule.cpp               | 59 ++++++++++++++++++++++++-
 3 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/src/language/modules/MathModule.cpp b/src/language/modules/MathModule.cpp
index 69f6890fb..2217424e4 100644
--- a/src/language/modules/MathModule.cpp
+++ b/src/language/modules/MathModule.cpp
@@ -70,6 +70,18 @@ MathModule::MathModule()
   this->_addBuiltinFunction("round", std::make_shared<BuiltinFunctionEmbedder<int64_t(double)>>(
                                        [](double x) -> int64_t { return std::lround(x); }));
 
+  this->_addBuiltinFunction("min", std::make_shared<BuiltinFunctionEmbedder<double(double, double)>>(
+                                     [](double x, double y) -> double { return std::min(x, y); }));
+
+  this->_addBuiltinFunction("min", std::make_shared<BuiltinFunctionEmbedder<int64_t(int64_t, int64_t)>>(
+                                     [](int64_t x, int64_t y) -> int64_t { return std::min(x, y); }));
+
+  this->_addBuiltinFunction("max", std::make_shared<BuiltinFunctionEmbedder<double(double, double)>>(
+                                     [](double x, double y) -> double { return std::max(x, y); }));
+
+  this->_addBuiltinFunction("max", std::make_shared<BuiltinFunctionEmbedder<int64_t(int64_t, int64_t)>>(
+                                     [](int64_t x, int64_t y) -> int64_t { return std::max(x, y); }));
+
   this->_addBuiltinFunction("dot",
                             std::make_shared<BuiltinFunctionEmbedder<double(const TinyVector<1>, const TinyVector<1>)>>(
                               [](const TinyVector<1> x, const TinyVector<1> y) -> double { return dot(x, y); }));
diff --git a/tests/test_BuiltinFunctionProcessor.cpp b/tests/test_BuiltinFunctionProcessor.cpp
index a1fb3fe82..bbc5ac52e 100644
--- a/tests/test_BuiltinFunctionProcessor.cpp
+++ b/tests/test_BuiltinFunctionProcessor.cpp
@@ -278,6 +278,42 @@ let z:Z, z = round(-1.2);
       CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "z", int64_t{-1});
     }
 
+    {   // min
+      tested_function_set.insert("min:R*R");
+      std::string_view data = R"(
+import math;
+let x:R, x = min(-2,2.3);
+)";
+      CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "x", double{-2});
+    }
+
+    {   // min
+      tested_function_set.insert("min:Z*Z");
+      std::string_view data = R"(
+import math;
+let z:Z, z = min(-1,2);
+)";
+      CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "z", int64_t{-1});
+    }
+
+    {   // max
+      tested_function_set.insert("max:R*R");
+      std::string_view data = R"(
+import math;
+let x:R, x = max(-1,2.3);
+)";
+      CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "x", double{2.3});
+    }
+
+    {   // max
+      tested_function_set.insert("max:Z*Z");
+      std::string_view data = R"(
+import math;
+let z:Z, z = max(-1,2);
+)";
+      CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "z", int64_t{2});
+    }
+
     {   // dot
       tested_function_set.insert("dot:R^1*R^1");
       std::string_view data = R"(
diff --git a/tests/test_MathModule.cpp b/tests/test_MathModule.cpp
index 7c2a9bc54..4d71927f1 100644
--- a/tests/test_MathModule.cpp
+++ b/tests/test_MathModule.cpp
@@ -15,7 +15,7 @@ TEST_CASE("MathModule", "[language]")
   MathModule math_module;
   const auto& name_builtin_function = math_module.getNameBuiltinFunctionMap();
 
-  REQUIRE(name_builtin_function.size() == 25);
+  REQUIRE(name_builtin_function.size() == 29);
 
   SECTION("double -> double")
   {
@@ -323,6 +323,63 @@ TEST_CASE("MathModule", "[language]")
       auto result = std::pow(arg0, arg1);
       REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result));
     }
+
+    SECTION("min")
+    {
+      auto i_function = name_builtin_function.find("min:R*R");
+      REQUIRE(i_function != name_builtin_function.end());
+
+      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
+      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});
+
+      auto result = std::min(arg0, arg1);
+      REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result));
+    }
+
+    SECTION("max")
+    {
+      auto i_function = name_builtin_function.find("max:R*R");
+      REQUIRE(i_function != name_builtin_function.end());
+
+      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
+      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});
+
+      auto result = std::max(arg0, arg1);
+      REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result));
+    }
+  }
+
+  SECTION("(uint64_t, uint64_t) -> uint64_t")
+  {
+    int64_t arg0 = 3;
+    int64_t arg1 = -2;
+
+    DataVariant arg0_variant = arg0;
+    DataVariant arg1_variant = arg1;
+
+    SECTION("min")
+    {
+      auto i_function = name_builtin_function.find("min:Z*Z");
+      REQUIRE(i_function != name_builtin_function.end());
+
+      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
+      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});
+
+      auto result = std::min(arg0, arg1);
+      REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result));
+    }
+
+    SECTION("max")
+    {
+      auto i_function = name_builtin_function.find("max:Z*Z");
+      REQUIRE(i_function != name_builtin_function.end());
+
+      IBuiltinFunctionEmbedder& function_embedder = *i_function->second;
+      DataVariant result_variant                  = function_embedder.apply({arg0_variant, arg1_variant});
+
+      auto result = std::max(arg0, arg1);
+      REQUIRE(std::get<decltype(result)>(result_variant) == Catch::Approx(result));
+    }
   }
 
   SECTION("(R^d, R^d) -> double")
-- 
GitLab