Skip to content
Snippets Groups Projects
Commit 9532d901 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add tests for CFunctionEmbedder

parent ef148fef
No related branches found
No related tags found
1 merge request!37Feature/language
...@@ -18,12 +18,14 @@ ...@@ -18,12 +18,14 @@
class ICFunctionEmbedder class ICFunctionEmbedder
{ {
public: public:
virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) = 0; virtual size_t numberOfArguments() const = 0;
virtual ASTNodeDataType getReturnDataType() const = 0; virtual ASTNodeDataType getReturnDataType() const = 0;
virtual std::vector<ASTNodeDataType> getArgumentDataTypes() const = 0; virtual std::vector<ASTNodeDataType> getArgumentDataTypes() const = 0;
virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const = 0;
virtual ~ICFunctionEmbedder() = default; virtual ~ICFunctionEmbedder() = default;
}; };
...@@ -43,8 +45,7 @@ class CFunctionEmbedder : public ICFunctionEmbedder ...@@ -43,8 +45,7 @@ class CFunctionEmbedder : public ICFunctionEmbedder
if constexpr (std::is_arithmetic_v<decltype(v_i)>) { if constexpr (std::is_arithmetic_v<decltype(v_i)>) {
std::get<I>(t) = v_i; std::get<I>(t) = v_i;
} else { } else {
std::cerr << __FILE__ << ':' << __LINE__ << ": unexpected argument type!\n"; throw std::runtime_error("unexpected argument type!");
std::exit(1);
} }
}, },
v[I]); v[I]);
...@@ -75,13 +76,13 @@ class CFunctionEmbedder : public ICFunctionEmbedder ...@@ -75,13 +76,13 @@ class CFunctionEmbedder : public ICFunctionEmbedder
} }
public: public:
ASTNodeDataType PUGS_INLINE ASTNodeDataType
getReturnDataType() const getReturnDataType() const final
{ {
return ast_node_data_type_from_pod<FX>; return ast_node_data_type_from_pod<FX>;
} }
std::vector<ASTNodeDataType> PUGS_INLINE std::vector<ASTNodeDataType>
getArgumentDataTypes() const final getArgumentDataTypes() const final
{ {
constexpr size_t N = std::tuple_size_v<ArgsTuple>; constexpr size_t N = std::tuple_size_v<ArgsTuple>;
...@@ -91,15 +92,15 @@ class CFunctionEmbedder : public ICFunctionEmbedder ...@@ -91,15 +92,15 @@ class CFunctionEmbedder : public ICFunctionEmbedder
return this->_getArgumentDataTypes(t, IndexSequence{}); return this->_getArgumentDataTypes(t, IndexSequence{});
} }
PUGS_INLINE constexpr size_t PUGS_INLINE size_t
numberOfArguments() const numberOfArguments() const final
{ {
return sizeof...(Args); return sizeof...(Args);
} }
PUGS_INLINE PUGS_INLINE
void void
apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) final apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const final
{ {
constexpr size_t N = std::tuple_size_v<ArgsTuple>; constexpr size_t N = std::tuple_size_v<ArgsTuple>;
ArgsTuple t; ArgsTuple t;
......
...@@ -36,6 +36,7 @@ add_executable (unit_tests ...@@ -36,6 +36,7 @@ add_executable (unit_tests
test_BinaryExpressionProcessor_equality.cpp test_BinaryExpressionProcessor_equality.cpp
test_BinaryExpressionProcessor_logic.cpp test_BinaryExpressionProcessor_logic.cpp
test_BiCGStab.cpp test_BiCGStab.cpp
test_CFunctionEmbedder.cpp
test_ContinueProcessor.cpp test_ContinueProcessor.cpp
test_ConcatExpressionProcessor.cpp test_ConcatExpressionProcessor.cpp
test_CRSMatrix.cpp test_CRSMatrix.cpp
......
#include <catch2/catch.hpp>
#include <CFunctionEmbedder.hpp>
TEST_CASE("CFunctionEbedder", "[language]")
{
rang::setControlMode(rang::control::Off);
SECTION("math")
{
CFunctionEmbedder<double, double> embedded_sin{
std::function<double(double)>{[](double x) -> double { return std::sin(x); }}};
double arg = 2;
ASTNodeDataVariant arg_variant = arg;
ASTNodeDataVariant result;
embedded_sin.apply({arg_variant}, result);
REQUIRE(std::get<double>(result) == std::sin(arg));
REQUIRE(embedded_sin.numberOfArguments() == 1);
REQUIRE(embedded_sin.getReturnDataType() == ASTNodeDataType::double_t);
REQUIRE(embedded_sin.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
}
SECTION("multiple variant args")
{
std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };
CFunctionEmbedder<bool, double, uint64_t> embedded_c{c};
double d_arg = 2.3;
uint64_t i_arg = 3;
std::vector<ASTNodeDataVariant> args;
args.push_back(d_arg);
args.push_back(i_arg);
ASTNodeDataVariant result;
embedded_c.apply(args, result);
REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
REQUIRE(embedded_c.numberOfArguments() == 2);
REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::bool_t);
REQUIRE(embedded_c.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
REQUIRE(embedded_c.getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
}
SECTION("ICFunctionEmbedder")
{
std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };
std::unique_ptr<ICFunctionEmbedder> i_embedded_c = std::make_unique<CFunctionEmbedder<bool, double, uint64_t>>(c);
double d_arg = 2.3;
uint64_t i_arg = 3;
std::vector<ASTNodeDataVariant> args;
args.push_back(d_arg);
args.push_back(i_arg);
ASTNodeDataVariant result;
i_embedded_c->apply(args, result);
REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
REQUIRE(i_embedded_c->numberOfArguments() == 2);
REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::bool_t);
REQUIRE(i_embedded_c->getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
REQUIRE(i_embedded_c->getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
}
SECTION("error")
{
std::function<bool(double)> positive = [](double x) -> bool { return x >= 0; };
CFunctionEmbedder<bool, double> embedded_positive{positive};
std::string arg = std::string{"2.3"};
std::vector<ASTNodeDataVariant> args;
args.push_back(arg);
ASTNodeDataVariant result;
REQUIRE_THROWS(embedded_positive.apply(args, result));
// REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment