diff --git a/src/language/node_processor/CFunctionProcessor.hpp b/src/language/node_processor/CFunctionProcessor.hpp index ca2fb3b338dd6f68af4feaa75ba099a5694f81d6..139e54f7c58cd91799c997dd6d4f659e564444b1 100644 --- a/src/language/node_processor/CFunctionProcessor.hpp +++ b/src/language/node_processor/CFunctionProcessor.hpp @@ -7,6 +7,12 @@ #include <cmath> +#include <cmath> +#include <functional> +#include <iostream> +#include <tuple> +#include <vector> + template <typename ProvidedValueType, typename ExpectedValueType> class CFunctionArgumentProcessor final : public INodeProcessor { @@ -32,6 +38,73 @@ class CFunctionArgumentProcessor final : public INodeProcessor {} }; +struct IFunctionEmbedder +{ + virtual void apply(const std::vector<ASTNodeDataVariant>& x, double& f_x) = 0; +}; + +template <typename FX, typename... Args> +class FunctionEmbedder : public IFunctionEmbedder +{ + private: + std::function<FX(Args...)> m_f; + using ArgsTuple = std::tuple<Args...>; + + template <size_t I> + void + _copy_value(ArgsTuple& t, const std::vector<ASTNodeDataVariant>& v) const + { + std::visit( + [&](auto v_i) { + if constexpr (std::is_arithmetic_v<decltype(v_i)>) { + std::get<I>(t) = v_i; + } else { + std::cerr << __FILE__ << ':' << __LINE__ << ": unexpected argument type!\n"; + std::exit(1); + } + }, + v[I]); + } + + template <size_t... I> + void + _copy_from_vector(ArgsTuple& t, const std::vector<ASTNodeDataVariant>& v, std::index_sequence<I...>) const + { + (_copy_value<I>(t, v), ...); + } + + public: + // @warning This is written in a template fashion to ensure that function type + // is correct. If one uses simply FunctionEmbedder(std::function<FX(Args...)>&&), + // types seem unchecked + template <typename FX2, typename... Args2> + FunctionEmbedder(std::function<FX2(Args2...)> f) : m_f(f) + { + static_assert(std::is_same_v<FX, FX2>, "incorrect return type"); + static_assert(sizeof...(Args) == sizeof...(Args2), "invalid number of arguments"); + using Args2Tuple = std::tuple<Args2...>; + static_assert(std::is_same_v<ArgsTuple, Args2Tuple>, "invalid arguments type"); + } + + PUGS_INLINE + size_t + numberOfArguments() const + { + return sizeof...(Args); + } + + void + apply(const std::vector<ASTNodeDataVariant>& x, FX& f_x) final + { + constexpr size_t N = std::tuple_size_v<ArgsTuple>; + ArgsTuple t; + using IndexSequence = std::make_index_sequence<N>; + + this->_copy_from_vector(t, x, IndexSequence{}); + f_x = std::apply(m_f, t); + } +}; + template <typename ReturnType, typename ExpressionValueType> class CFunctionExpressionProcessor final : public INodeProcessor { @@ -40,36 +113,27 @@ class CFunctionExpressionProcessor final : public INodeProcessor std::vector<ASTNodeDataVariant>& m_argument_values; + std::unique_ptr<IFunctionEmbedder> m_embedded_function; + public: void - execute(ExecUntilBreakOrContinue& exec_policy) + execute(ExecUntilBreakOrContinue&) { + ReturnType result; + m_embedded_function->apply(m_argument_values, result); if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) { - std::visit( - [&](auto v) { - if constexpr (std::is_arithmetic_v<decltype(v)>) { - m_node.m_value = std::sin(v); - } else { - throw parse_error("invalid C function evaluation", m_node.begin()); - } - }, - m_argument_values[0]); + m_node.m_value = result; } else { - std::visit( - [&](auto v) { - if constexpr (std::is_arithmetic_v<decltype(v)>) { - m_node.m_value = static_cast<ReturnType>(std::sin(v)); - } else { - throw parse_error("invalid C function evaluation", m_node.begin()); - } - }, - m_argument_values[0]); + m_node.m_value = static_cast<ExpressionValueType>(result); } } CFunctionExpressionProcessor(ASTNode& node, std::vector<ASTNodeDataVariant>& argument_values) : m_node{node}, m_argument_values{argument_values} - {} + { + m_embedded_function = std::make_unique<FunctionEmbedder<double, double, double>>( + std::function{[](double x, double y) -> double { return std::sin(x) * std::cos(3 * x) + y; }}); + } }; class CFunctionProcessor : public INodeProcessor