Skip to content
Snippets Groups Projects
Commit 9532d901 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add tests for CFunctionEmbedder

parent ef148fef
No related branches found
No related tags found
1 merge request!37Feature/language
......@@ -18,12 +18,14 @@
class ICFunctionEmbedder
{
public:
virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) = 0;
virtual size_t numberOfArguments() const = 0;
virtual ASTNodeDataType getReturnDataType() const = 0;
virtual std::vector<ASTNodeDataType> getArgumentDataTypes() const = 0;
virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const = 0;
virtual ~ICFunctionEmbedder() = default;
};
......@@ -43,8 +45,7 @@ class CFunctionEmbedder : public ICFunctionEmbedder
if constexpr (std::is_arithmetic_v<decltype(v_i)>) {
std::get<I>(t) = v_i;
} else {
std::cerr << __FILE__ << ':' << __LINE__ << ": unexpected argument type!\n";
std::exit(1);
throw std::runtime_error("unexpected argument type!");
}
},
v[I]);
......@@ -75,13 +76,13 @@ class CFunctionEmbedder : public ICFunctionEmbedder
}
public:
ASTNodeDataType
getReturnDataType() const
PUGS_INLINE ASTNodeDataType
getReturnDataType() const final
{
return ast_node_data_type_from_pod<FX>;
}
std::vector<ASTNodeDataType>
PUGS_INLINE std::vector<ASTNodeDataType>
getArgumentDataTypes() const final
{
constexpr size_t N = std::tuple_size_v<ArgsTuple>;
......@@ -91,15 +92,15 @@ class CFunctionEmbedder : public ICFunctionEmbedder
return this->_getArgumentDataTypes(t, IndexSequence{});
}
PUGS_INLINE constexpr size_t
numberOfArguments() const
PUGS_INLINE size_t
numberOfArguments() const final
{
return sizeof...(Args);
}
PUGS_INLINE
void
apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) final
apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const final
{
constexpr size_t N = std::tuple_size_v<ArgsTuple>;
ArgsTuple t;
......
......@@ -36,6 +36,7 @@ add_executable (unit_tests
test_BinaryExpressionProcessor_equality.cpp
test_BinaryExpressionProcessor_logic.cpp
test_BiCGStab.cpp
test_CFunctionEmbedder.cpp
test_ContinueProcessor.cpp
test_ConcatExpressionProcessor.cpp
test_CRSMatrix.cpp
......
#include <catch2/catch.hpp>
#include <CFunctionEmbedder.hpp>
TEST_CASE("CFunctionEbedder", "[language]")
{
rang::setControlMode(rang::control::Off);
SECTION("math")
{
CFunctionEmbedder<double, double> embedded_sin{
std::function<double(double)>{[](double x) -> double { return std::sin(x); }}};
double arg = 2;
ASTNodeDataVariant arg_variant = arg;
ASTNodeDataVariant result;
embedded_sin.apply({arg_variant}, result);
REQUIRE(std::get<double>(result) == std::sin(arg));
REQUIRE(embedded_sin.numberOfArguments() == 1);
REQUIRE(embedded_sin.getReturnDataType() == ASTNodeDataType::double_t);
REQUIRE(embedded_sin.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
}
SECTION("multiple variant args")
{
std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };
CFunctionEmbedder<bool, double, uint64_t> embedded_c{c};
double d_arg = 2.3;
uint64_t i_arg = 3;
std::vector<ASTNodeDataVariant> args;
args.push_back(d_arg);
args.push_back(i_arg);
ASTNodeDataVariant result;
embedded_c.apply(args, result);
REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
REQUIRE(embedded_c.numberOfArguments() == 2);
REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::bool_t);
REQUIRE(embedded_c.getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
REQUIRE(embedded_c.getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
}
SECTION("ICFunctionEmbedder")
{
std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; };
std::unique_ptr<ICFunctionEmbedder> i_embedded_c = std::make_unique<CFunctionEmbedder<bool, double, uint64_t>>(c);
double d_arg = 2.3;
uint64_t i_arg = 3;
std::vector<ASTNodeDataVariant> args;
args.push_back(d_arg);
args.push_back(i_arg);
ASTNodeDataVariant result;
i_embedded_c->apply(args, result);
REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
REQUIRE(i_embedded_c->numberOfArguments() == 2);
REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::bool_t);
REQUIRE(i_embedded_c->getArgumentDataTypes()[0] == ASTNodeDataType::double_t);
REQUIRE(i_embedded_c->getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t);
}
SECTION("error")
{
std::function<bool(double)> positive = [](double x) -> bool { return x >= 0; };
CFunctionEmbedder<bool, double> embedded_positive{positive};
std::string arg = std::string{"2.3"};
std::vector<ASTNodeDataVariant> args;
args.push_back(arg);
ASTNodeDataVariant result;
REQUIRE_THROWS(embedded_positive.apply(args, result));
// REQUIRE(std::get<bool>(result) == c(d_arg, i_arg));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment