From a24be6df31fc1d58edafd45159b1cf4d5859de42 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Mon, 17 Feb 2020 11:56:43 +0100
Subject: [PATCH] Update math module so that provided functions satisfy natural
 casts

That is: `ceil`, `floor`, `trunc` and `round` are now `R -> Z` functions.

`nearbyint`, `lround`, `rint` and `lrint` functions have been removed since they
were redundant.
---
 src/language/CMathModule.cpp      |  28 ++-----
 tests/test_CFunctionProcessor.cpp | 117 +++++++++++-------------------
 tests/test_CMathModule.cpp        |  69 +++---------------
 3 files changed, 61 insertions(+), 153 deletions(-)

diff --git a/src/language/CMathModule.cpp b/src/language/CMathModule.cpp
index 1363d302a..26d8a8f13 100644
--- a/src/language/CMathModule.cpp
+++ b/src/language/CMathModule.cpp
@@ -74,27 +74,15 @@ CMathModule::CMathModule()
                      std::make_shared<CFunctionEmbedder<double, double, double>>(std::function<double(double, double)>{
                        [](double x, double y) -> double { return std::pow(x, y); }}));
 
-  this->_addFunction("nearbyint", std::make_shared<CFunctionEmbedder<double, double>>(std::function<double(double)>{
-                                    [](double x) -> double { return std::nearbyint(x); }}));
+  this->_addFunction("ceil", std::make_shared<CFunctionEmbedder<int64_t, double>>(
+                               std::function<int64_t(double)>{[](double x) -> int64_t { return std::ceil(x); }}));
 
-  this->_addFunction("ceil", std::make_shared<CFunctionEmbedder<double, double>>(
-                               std::function<double(double)>{[](double x) -> double { return std::ceil(x); }}));
+  this->_addFunction("floor", std::make_shared<CFunctionEmbedder<int64_t, double>>(
+                                std::function<int64_t(double)>{[](double x) -> int64_t { return std::floor(x); }}));
 
-  this->_addFunction("floor", std::make_shared<CFunctionEmbedder<double, double>>(
-                                std::function<double(double)>{[](double x) -> double { return std::floor(x); }}));
+  this->_addFunction("trunc", std::make_shared<CFunctionEmbedder<int64_t, double>>(
+                                std::function<int64_t(double)>{[](double x) -> int64_t { return std::trunc(x); }}));
 
-  this->_addFunction("trunc", std::make_shared<CFunctionEmbedder<double, double>>(
-                                std::function<double(double)>{[](double x) -> double { return std::trunc(x); }}));
-
-  this->_addFunction("round", std::make_shared<CFunctionEmbedder<double, double>>(
-                                std::function<double(double)>{[](double x) -> double { return std::round(x); }}));
-
-  this->_addFunction("lround", std::make_shared<CFunctionEmbedder<int64_t, double>>(
-                                 std::function<int64_t(double)>{[](double x) -> int64_t { return std::lround(x); }}));
-
-  this->_addFunction("rint", std::make_shared<CFunctionEmbedder<double, double>>(
-                               std::function<double(double)>{[](double x) -> double { return std::rint(x); }}));
-
-  this->_addFunction("lrint", std::make_shared<CFunctionEmbedder<int64_t, double>>(
-                                std::function<int64_t(double)>{[](double x) -> int64_t { return std::lrint(x); }}));
+  this->_addFunction("round", std::make_shared<CFunctionEmbedder<int64_t, double>>(
+                                std::function<int64_t(double)>{[](double x) -> int64_t { return std::lround(x); }}));
 }
diff --git a/tests/test_CFunctionProcessor.cpp b/tests/test_CFunctionProcessor.cpp
index 7f47ada23..983362b34 100644
--- a/tests/test_CFunctionProcessor.cpp
+++ b/tests/test_CFunctionProcessor.cpp
@@ -14,36 +14,41 @@
 
 #include <CMathModule.hpp>
 
-#define CHECK_CFUNCTION_EVALUATION_RESULT(data, variable_name, expected_value) \
-  {                                                                            \
-    string_input input{data, "test.pgs"};                                      \
-    auto ast = ASTBuilder::build(input);                                       \
-                                                                               \
-    ASTModulesImporter{*ast};                                                  \
-    ASTNodeTypeCleaner<language::import_instruction>{*ast};                    \
-                                                                               \
-    ASTSymbolTableBuilder{*ast};                                               \
-    ASTNodeDataTypeBuilder{*ast};                                              \
-                                                                               \
-    ASTNodeDeclarationToAffectationConverter{*ast};                            \
-    ASTNodeTypeCleaner<language::declaration>{*ast};                           \
-    ASTNodeTypeCleaner<language::let_declaration>{*ast};                       \
-                                                                               \
-    ASTNodeExpressionBuilder{*ast};                                            \
-    ExecutionPolicy exec_policy;                                               \
-    ast->execute(exec_policy);                                                 \
-                                                                               \
-    auto symbol_table = ast->m_symbol_table;                                   \
-                                                                               \
-    using namespace TAO_PEGTL_NAMESPACE;                                       \
-    position use_position{internal::iterator{"fixture"}, "fixture"};           \
-    use_position.byte    = 10000;                                              \
-    auto [symbol, found] = symbol_table->find(variable_name, use_position);    \
-                                                                               \
-    auto attributes = symbol->attributes();                                    \
-    auto value      = std::get<decltype(expected_value)>(attributes.value());  \
-                                                                               \
-    REQUIRE(value == expected_value);                                          \
+#define CHECK_CFUNCTION_EVALUATION_RESULT(data, variable_name, expected_value)                       \
+  {                                                                                                  \
+    string_input input{data, "test.pgs"};                                                            \
+    auto ast = ASTBuilder::build(input);                                                             \
+                                                                                                     \
+    ASTModulesImporter{*ast};                                                                        \
+    ASTNodeTypeCleaner<language::import_instruction>{*ast};                                          \
+                                                                                                     \
+    ASTSymbolTableBuilder{*ast};                                                                     \
+    ASTNodeDataTypeBuilder{*ast};                                                                    \
+                                                                                                     \
+    ASTNodeDeclarationToAffectationConverter{*ast};                                                  \
+    ASTNodeTypeCleaner<language::declaration>{*ast};                                                 \
+    ASTNodeTypeCleaner<language::let_declaration>{*ast};                                             \
+                                                                                                     \
+    ASTNodeExpressionBuilder{*ast};                                                                  \
+    ExecutionPolicy exec_policy;                                                                     \
+    ast->execute(exec_policy);                                                                       \
+                                                                                                     \
+    auto symbol_table = ast->m_symbol_table;                                                         \
+                                                                                                     \
+    using namespace TAO_PEGTL_NAMESPACE;                                                             \
+    position use_position{internal::iterator{"fixture"}, "fixture"};                                 \
+    use_position.byte    = 10000;                                                                    \
+    auto [symbol, found] = symbol_table->find(variable_name, use_position);                          \
+                                                                                                     \
+    using namespace Catch::Matchers;                                                                 \
+                                                                                                     \
+    REQUIRE_THAT(found, Predicate<bool>([](const bool& found) -> bool { return found; },             \
+                                        std::string{"Cannot find symbol '"} + variable_name + "'")); \
+                                                                                                     \
+    auto attributes = symbol->attributes();                                                          \
+    auto value      = std::get<decltype(expected_value)>(attributes.value());                        \
+                                                                                                     \
+    REQUIRE(value == expected_value);                                                                \
   }
 
 TEST_CASE("CFunctionProcessor", "[language]")
@@ -216,76 +221,40 @@ R x = pow(1.6, 2.3);
       CHECK_CFUNCTION_EVALUATION_RESULT(data, "x", double{std::pow(1.6, 2.3)});
     }
 
-    {   // exp
-      tested_function_set.insert("nearbyint");
-      std::string_view data = R"(
-import math;
-R x = nearbyint(1.7);
-)";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "x", double{2});
-    }
-
     {   // ceil
       tested_function_set.insert("ceil");
       std::string_view data = R"(
 import math;
-R x = ceil(-1.2);
+Z z = ceil(-1.2);
 )";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "x", double{-1});
+      CHECK_CFUNCTION_EVALUATION_RESULT(data, "z", int64_t{-1});
     }
 
     {   // floor
       tested_function_set.insert("floor");
       std::string_view data = R"(
 import math;
-R x = floor(-1.2);
+Z z = floor(-1.2);
 )";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "x", double{-2});
+      CHECK_CFUNCTION_EVALUATION_RESULT(data, "z", int64_t{-2});
     }
 
     {   // trunc
       tested_function_set.insert("trunc");
       std::string_view data = R"(
 import math;
-R x = trunc(-0.2) + trunc(0.7);
+Z z = trunc(-0.2) + trunc(0.7);
 )";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "x", double{0});
+      CHECK_CFUNCTION_EVALUATION_RESULT(data, "z", int64_t{0});
     }
 
     {   // round
       tested_function_set.insert("round");
       std::string_view data = R"(
 import math;
-R x = round(-1.2);
-)";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "x", double{-1});
-    }
-
-    {   // lround
-      tested_function_set.insert("lround");
-      std::string_view data = R"(
-import math;
-Z i = lround(-1.2);
-)";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "i", int64_t{-1});
-    }
-
-    {   // rint
-      tested_function_set.insert("rint");
-      std::string_view data = R"(
-import math;
-R x = rint(-1.2);
-)";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "x", double{-1});
-    }
-
-    {   // lrint
-      tested_function_set.insert("lrint");
-      std::string_view data = R"(
-import math;
-Z i = lrint(-1.2);
+Z z = round(-1.2);
 )";
-      CHECK_CFUNCTION_EVALUATION_RESULT(data, "i", int64_t{-1});
+      CHECK_CFUNCTION_EVALUATION_RESULT(data, "z", int64_t{-1});
     }
 
     CMathModule math_module;
diff --git a/tests/test_CMathModule.cpp b/tests/test_CMathModule.cpp
index 657a5bced..f6365bdd7 100644
--- a/tests/test_CMathModule.cpp
+++ b/tests/test_CMathModule.cpp
@@ -12,7 +12,7 @@ TEST_CASE("CMathModule", "[language]")
   CMathModule math_module;
   const auto& name_cfunction = math_module.getNameCFunctionsMap();
 
-  REQUIRE(name_cfunction.size() == 26);
+  REQUIRE(name_cfunction.size() == 22);
 
   SECTION("double -> double")
   {
@@ -233,18 +233,13 @@ TEST_CASE("CMathModule", "[language]")
       auto result = std::log(arg);
       REQUIRE(std::get<decltype(result)>(result_variant) == Approx(result));
     }
+  }
 
-    SECTION("nearbyint")
-    {
-      auto i_function = name_cfunction.find("nearbyint");
-      REQUIRE(i_function != name_cfunction.end());
-
-      ICFunctionEmbedder& function_embedder = *i_function->second;
-      DataVariant result_variant            = function_embedder.apply({arg_variant});
+  SECTION("double -> int64_t")
+  {
+    double arg = 1.3;
 
-      auto result = std::nearbyint(arg);
-      REQUIRE(std::get<decltype(result)>(result_variant) == result);
-    }
+    DataVariant arg_variant = arg;
 
     SECTION("ceil")
     {
@@ -254,7 +249,7 @@ TEST_CASE("CMathModule", "[language]")
       ICFunctionEmbedder& function_embedder = *i_function->second;
       DataVariant result_variant            = function_embedder.apply({arg_variant});
 
-      auto result = std::ceil(arg);
+      int64_t result = std::ceil(arg);
       REQUIRE(std::get<decltype(result)>(result_variant) == result);
     }
 
@@ -266,7 +261,7 @@ TEST_CASE("CMathModule", "[language]")
       ICFunctionEmbedder& function_embedder = *i_function->second;
       DataVariant result_variant            = function_embedder.apply({arg_variant});
 
-      auto result = std::floor(arg);
+      int64_t result = std::floor(arg);
       REQUIRE(std::get<decltype(result)>(result_variant) == result);
     }
 
@@ -278,7 +273,7 @@ TEST_CASE("CMathModule", "[language]")
       ICFunctionEmbedder& function_embedder = *i_function->second;
       DataVariant result_variant            = function_embedder.apply({arg_variant});
 
-      auto result = std::trunc(arg);
+      int64_t result = std::trunc(arg);
       REQUIRE(std::get<decltype(result)>(result_variant) == result);
     }
 
@@ -290,51 +285,7 @@ TEST_CASE("CMathModule", "[language]")
       ICFunctionEmbedder& function_embedder = *i_function->second;
       DataVariant result_variant            = function_embedder.apply({arg_variant});
 
-      auto result = std::round(arg);
-      REQUIRE(std::get<decltype(result)>(result_variant) == result);
-    }
-
-    SECTION("rint")
-    {
-      auto i_function = name_cfunction.find("rint");
-      REQUIRE(i_function != name_cfunction.end());
-
-      ICFunctionEmbedder& function_embedder = *i_function->second;
-      DataVariant result_variant            = function_embedder.apply({arg_variant});
-
-      auto result = std::rint(arg);
-      REQUIRE(std::get<decltype(result)>(result_variant) == result);
-    }
-  }
-
-  SECTION("double -> int64_t")
-  {
-    double arg = 1.3;
-
-    DataVariant arg_variant = arg;
-
-    SECTION("lround")
-    {
-      auto i_function = name_cfunction.find("lround");
-      REQUIRE(i_function != name_cfunction.end());
-
-      ICFunctionEmbedder& function_embedder = *i_function->second;
-      DataVariant result_variant            = function_embedder.apply({arg_variant});
-
-      auto result = std::lround(arg);
-      REQUIRE(std::get<decltype(result)>(result_variant) == result);
-    }
-
-    SECTION("lrint")
-    {
-      auto i_function = name_cfunction.find("lrint");
-      REQUIRE(i_function != name_cfunction.end());
-
-      ICFunctionEmbedder& function_embedder = *i_function->second;
-      DataVariant result_variant            = function_embedder.apply({arg_variant});
-
-      //      REQUIRE(std::get<int64_t>(result_variant) == std::lrint(arg));
-      auto result = std::lrint(arg);
+      int64_t result = std::lround(arg);
       REQUIRE(std::get<decltype(result)>(result_variant) == result);
     }
   }
-- 
GitLab