From 9532d901f4662a694c34719609decbf0258de2b8 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Wed, 30 Oct 2019 18:47:11 +0100
Subject: [PATCH] Add tests for CFunctionEmbedder

---
 src/language/CFunctionEmbedder.hpp | 19 +++---
 tests/CMakeLists.txt               |  1 +
 tests/test_CFunctionEmbedder.cpp   | 95 ++++++++++++++++++++++++++++++
 3 files changed, 106 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_CFunctionEmbedder.cpp

diff --git a/src/language/CFunctionEmbedder.hpp b/src/language/CFunctionEmbedder.hpp
index 7ca5e10df..10a7209ba 100644
--- a/src/language/CFunctionEmbedder.hpp
+++ b/src/language/CFunctionEmbedder.hpp
@@ -18,12 +18,14 @@
 class ICFunctionEmbedder
 {
  public:
-  virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) = 0;
+  virtual size_t numberOfArguments() const = 0;
 
   virtual ASTNodeDataType getReturnDataType() const = 0;
 
   virtual std::vector<ASTNodeDataType> getArgumentDataTypes() const = 0;
 
+  virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const = 0;
+
   virtual ~ICFunctionEmbedder() = default;
 };
 
@@ -43,8 +45,7 @@ class CFunctionEmbedder : public ICFunctionEmbedder
         if constexpr (std::is_arithmetic_v<decltype(v_i)>) {
           std::get<I>(t) = v_i;
         } else {
-          std::cerr << __FILE__ << ':' << __LINE__ << ": unexpected argument type!\n";
-          std::exit(1);
+          throw std::runtime_error("unexpected argument type!");
         }
       },
       v[I]);
@@ -75,13 +76,13 @@ class CFunctionEmbedder : public ICFunctionEmbedder
   }
 
  public:
-  ASTNodeDataType
-  getReturnDataType() const
+  PUGS_INLINE ASTNodeDataType
+  getReturnDataType() const final
   {
     return ast_node_data_type_from_pod<FX>;
   }
 
-  std::vector<ASTNodeDataType>
+  PUGS_INLINE std::vector<ASTNodeDataType>
   getArgumentDataTypes() const final
   {
     constexpr size_t N = std::tuple_size_v<ArgsTuple>;
@@ -91,15 +92,15 @@ class CFunctionEmbedder : public ICFunctionEmbedder
     return this->_getArgumentDataTypes(t, IndexSequence{});
   }
 
-  PUGS_INLINE constexpr size_t
-  numberOfArguments() const
+  PUGS_INLINE size_t
+  numberOfArguments() const final
   {
     return sizeof...(Args);
   }
 
   PUGS_INLINE
   void
-  apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) final
+  apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const final
   {
     constexpr size_t N = std::tuple_size_v<ArgsTuple>;
     ArgsTuple t;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 952e599ed..ed9f5564d 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -36,6 +36,7 @@ add_executable (unit_tests
   test_BinaryExpressionProcessor_equality.cpp
   test_BinaryExpressionProcessor_logic.cpp
   test_BiCGStab.cpp
+  test_CFunctionEmbedder.cpp
   test_ContinueProcessor.cpp
   test_ConcatExpressionProcessor.cpp
   test_CRSMatrix.cpp
diff --git a/tests/test_CFunctionEmbedder.cpp b/tests/test_CFunctionEmbedder.cpp
new file mode 100644
index 000000000..0fe8eea89
--- /dev/null
+++ b/tests/test_CFunctionEmbedder.cpp
@@ -0,0 +1,95 @@
+#include <catch2/catch.hpp>
+
+#include <CFunctionEmbedder.hpp>
+
+TEST_CASE("CFunctionEbedder", "[language]")
+{
+  rang::setControlMode(rang::control::Off);
+
+  SECTION("math")
+  {
+    CFunctionEmbedder<double, double> embedded_sin{
+      std::function<double(double)>{[](double x) -> double { return std::sin(x); }}};
+
+    double arg                     = 2;
+    ASTNodeDataVariant arg_variant = arg;
+
+    ASTNodeDataVariant result;
+
+    embedded_sin.apply({arg_variant}, result);
+
+    REQUIRE(std::get<double>(result) == std::sin(arg));
+    REQUIRE(embedded_sin.numberOfArguments() == 1);
+
+    REQUIRE(embedded_sin.getReturnDataType() == ASTNodeDataType::double_t);
+    REQUIRE(embedded_sin.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
+  }
+
+  SECTION("multiple variant args")
+  {
+    std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };
+
+    CFunctionEmbedder<bool, double, uint64_t> embedded_c{c};
+
+    double d_arg   = 2.3;
+    uint64_t i_arg = 3;
+
+    std::vector<ASTNodeDataVariant> args;
+    args.push_back(d_arg);
+    args.push_back(i_arg);
+
+    ASTNodeDataVariant result;
+
+    embedded_c.apply(args, result);
+
+    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
+    REQUIRE(embedded_c.numberOfArguments() == 2);
+
+    REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::bool_t);
+    REQUIRE(embedded_c.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
+    REQUIRE(embedded_c.getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
+  }
+
+  SECTION("ICFunctionEmbedder")
+  {
+    std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };
+
+    std::unique_ptr<ICFunctionEmbedder> i_embedded_c = std::make_unique<CFunctionEmbedder<bool, double, uint64_t>>(c);
+
+    double d_arg   = 2.3;
+    uint64_t i_arg = 3;
+
+    std::vector<ASTNodeDataVariant> args;
+    args.push_back(d_arg);
+    args.push_back(i_arg);
+
+    ASTNodeDataVariant result;
+
+    i_embedded_c->apply(args, result);
+
+    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
+    REQUIRE(i_embedded_c->numberOfArguments() == 2);
+
+    REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::bool_t);
+    REQUIRE(i_embedded_c->getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
+    REQUIRE(i_embedded_c->getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
+  }
+
+  SECTION("error")
+  {
+    std::function<bool(double)> positive = [](double x) -> bool { return x >= 0; };
+
+    CFunctionEmbedder<bool, double> embedded_positive{positive};
+
+    std::string arg = std::string{"2.3"};
+
+    std::vector<ASTNodeDataVariant> args;
+    args.push_back(arg);
+
+    ASTNodeDataVariant result;
+
+    REQUIRE_THROWS(embedded_positive.apply(args, result));
+
+    //    REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
+  }
+}
-- 
GitLab