diff --git a/src/language/CFunctionEmbedder.hpp b/src/language/CFunctionEmbedder.hpp index 7ca5e10df4b3e0c5a814ca2e2f544e3e47da0dad..10a7209baa13cb9f296741f741624713670f878e 100644 --- a/src/language/CFunctionEmbedder.hpp +++ b/src/language/CFunctionEmbedder.hpp @@ -18,12 +18,14 @@ class ICFunctionEmbedder { public: - virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) = 0; + virtual size_t numberOfArguments() const = 0; virtual ASTNodeDataType getReturnDataType() const = 0; virtual std::vector<ASTNodeDataType> getArgumentDataTypes() const = 0; + virtual void apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const = 0; + virtual ~ICFunctionEmbedder() = default; }; @@ -43,8 +45,7 @@ class CFunctionEmbedder : public ICFunctionEmbedder if constexpr (std::is_arithmetic_v<decltype(v_i)>) { std::get<I>(t) = v_i; } else { - std::cerr << __FILE__ << ':' << __LINE__ << ": unexpected argument type!\n"; - std::exit(1); + throw std::runtime_error("unexpected argument type!"); } }, v[I]); @@ -75,13 +76,13 @@ class CFunctionEmbedder : public ICFunctionEmbedder } public: - ASTNodeDataType - getReturnDataType() const + PUGS_INLINE ASTNodeDataType + getReturnDataType() const final { return ast_node_data_type_from_pod<FX>; } - std::vector<ASTNodeDataType> + PUGS_INLINE std::vector<ASTNodeDataType> getArgumentDataTypes() const final { constexpr size_t N = std::tuple_size_v<ArgsTuple>; @@ -91,15 +92,15 @@ class CFunctionEmbedder : public ICFunctionEmbedder return this->_getArgumentDataTypes(t, IndexSequence{}); } - PUGS_INLINE constexpr size_t - numberOfArguments() const + PUGS_INLINE size_t + numberOfArguments() const final { return sizeof...(Args); } PUGS_INLINE void - apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) final + apply(const std::vector<ASTNodeDataVariant>& x, ASTNodeDataVariant& f_x) const final { constexpr size_t N = std::tuple_size_v<ArgsTuple>; ArgsTuple t; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 952e599ed03ff33588978ee9b2fb7d84df092b61..ed9f5564d2443b311ecc61b6380f4e4cde1d2830 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -36,6 +36,7 @@ add_executable (unit_tests test_BinaryExpressionProcessor_equality.cpp test_BinaryExpressionProcessor_logic.cpp test_BiCGStab.cpp + test_CFunctionEmbedder.cpp test_ContinueProcessor.cpp test_ConcatExpressionProcessor.cpp test_CRSMatrix.cpp diff --git a/tests/test_CFunctionEmbedder.cpp b/tests/test_CFunctionEmbedder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0fe8eea896dda0f020063f84915e6ad48ebf35bb --- /dev/null +++ b/tests/test_CFunctionEmbedder.cpp @@ -0,0 +1,95 @@ +#include <catch2/catch.hpp> + +#include <CFunctionEmbedder.hpp> + +TEST_CASE("CFunctionEbedder", "[language]") +{ + rang::setControlMode(rang::control::Off); + + SECTION("math") + { + CFunctionEmbedder<double, double> embedded_sin{ + std::function<double(double)>{[](double x) -> double { return std::sin(x); }}}; + + double arg = 2; + ASTNodeDataVariant arg_variant = arg; + + ASTNodeDataVariant result; + + embedded_sin.apply({arg_variant}, result); + + REQUIRE(std::get<double>(result) == std::sin(arg)); + REQUIRE(embedded_sin.numberOfArguments() == 1); + + REQUIRE(embedded_sin.getReturnDataType() == ASTNodeDataType::double_t); + REQUIRE(embedded_sin.getArgumentDataTypes()[0] == ASTNodeDataType::double_t); + } + + SECTION("multiple variant args") + { + std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; }; + + CFunctionEmbedder<bool, double, uint64_t> embedded_c{c}; + + double d_arg = 2.3; + uint64_t i_arg = 3; + + std::vector<ASTNodeDataVariant> args; + args.push_back(d_arg); + args.push_back(i_arg); + + ASTNodeDataVariant result; + + embedded_c.apply(args, result); + + REQUIRE(std::get<bool>(result) == c(d_arg, i_arg)); + REQUIRE(embedded_c.numberOfArguments() == 2); + + REQUIRE(embedded_c.getReturnDataType() == ASTNodeDataType::bool_t); + REQUIRE(embedded_c.getArgumentDataTypes()[0] == ASTNodeDataType::double_t); + REQUIRE(embedded_c.getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t); + } + + SECTION("ICFunctionEmbedder") + { + std::function<bool(double, uint64_t)> c = [](double x, uint64_t i) -> bool { return x > i; }; + + std::unique_ptr<ICFunctionEmbedder> i_embedded_c = std::make_unique<CFunctionEmbedder<bool, double, uint64_t>>(c); + + double d_arg = 2.3; + uint64_t i_arg = 3; + + std::vector<ASTNodeDataVariant> args; + args.push_back(d_arg); + args.push_back(i_arg); + + ASTNodeDataVariant result; + + i_embedded_c->apply(args, result); + + REQUIRE(std::get<bool>(result) == c(d_arg, i_arg)); + REQUIRE(i_embedded_c->numberOfArguments() == 2); + + REQUIRE(i_embedded_c->getReturnDataType() == ASTNodeDataType::bool_t); + REQUIRE(i_embedded_c->getArgumentDataTypes()[0] == ASTNodeDataType::double_t); + REQUIRE(i_embedded_c->getArgumentDataTypes()[1] == ASTNodeDataType::unsigned_int_t); + } + + SECTION("error") + { + std::function<bool(double)> positive = [](double x) -> bool { return x >= 0; }; + + CFunctionEmbedder<bool, double> embedded_positive{positive}; + + std::string arg = std::string{"2.3"}; + + std::vector<ASTNodeDataVariant> args; + args.push_back(arg); + + ASTNodeDataVariant result; + + REQUIRE_THROWS(embedded_positive.apply(args, result)); + + // REQUIRE(std::get<bool>(result) == c(d_arg, i_arg)); + } +}