#ifndef PUGS_FUNCTION_ADAPTER_HPP #define PUGS_FUNCTION_ADAPTER_HPP #include <language/ast/ASTNode.hpp> #include <language/ast/ASTNodeDataType.hpp> #include <language/node_processor/ExecutionPolicy.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/SymbolTable.hpp> #include <utils/Array.hpp> #include <utils/Exceptions.hpp> #include <utils/PugsMacros.hpp> #include <Kokkos_Core.hpp> #include <array> template <typename T> class PugsFunctionAdapter; template <typename OutputType, typename... InputType> class PugsFunctionAdapter<OutputType(InputType...)> { protected: using InputTuple = std::tuple<std::decay_t<InputType>...>; constexpr static size_t NArgs = std::tuple_size_v<InputTuple>; private: template <typename T, typename... Args> PUGS_INLINE static void _convertArgs(ExecutionPolicy::Context& context, size_t i_context, const T& t, Args&&... args) { context[i_context++] = t; if constexpr (sizeof...(Args) > 0) { _convertArgs(context, i_context, std::forward<Args>(args)...); } } template <size_t I> [[nodiscard]] PUGS_INLINE static bool _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT) { using Arg = std::tuple_element_t<I, InputTuple>; constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>; Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t); const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType(); return isNaturalConversion(expected_input_data_type, arg_data_type); } template <size_t... I> [[nodiscard]] PUGS_INLINE static bool _checkAllInputDataType(const ASTNode& input_expression, std::index_sequence<I...>) { Assert(NArgs == input_expression.children.size()); return (_checkValidArgumentDataType<I>(*input_expression.children[I]) and ...); } [[nodiscard]] PUGS_INLINE static bool _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept { if constexpr (NArgs == 1) { return _checkValidArgumentDataType<0>(input_domain_expression); } else { if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or (input_domain_expression.children.size() != NArgs)) { return false; } using IndexSequence = std::make_index_sequence<NArgs>; return _checkAllInputDataType(input_domain_expression, IndexSequence{}); } } [[nodiscard]] PUGS_INLINE static bool _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT) { constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>; const ASTNodeDataType& return_data_type = output_domain_expression.m_data_type.contentType(); return isNaturalConversion(return_data_type, expected_return_data_type); } template <typename Arg, typename... RemainingArgs> [[nodiscard]] PUGS_INLINE static std::string _getCompoundTypeName() { if constexpr (sizeof...(RemainingArgs) > 0) { return dataTypeName(ast_node_data_type_from<Arg>) + '*' + _getCompoundTypeName<RemainingArgs...>(); } else { return dataTypeName(ast_node_data_type_from<Arg>); } } [[nodiscard]] static std::string _getInputDataTypeName() { return _getCompoundTypeName<InputType...>(); } PUGS_INLINE static void _checkFunction(const FunctionDescriptor& function) { bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]); bool has_valid_output = _checkValidOutputDomain(*function.domainMappingNode().children[1]); if (not(has_valid_input_domain and has_valid_output)) { std::ostringstream error_message; error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>) << rang::style::reset << '\n' << "note: provided function " << rang::fgB::magenta << function.name() << ": " << function.domainMappingNode().string() << rang::style::reset << std::ends; throw NormalError(error_message.str()); } } protected: [[nodiscard]] PUGS_INLINE static auto& getFunctionExpression(const FunctionSymbolId& function_symbol_id) { auto& function = function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()]; _checkFunction(function); return *function.definitionNode().children[1]; } [[nodiscard]] PUGS_INLINE static auto getContextList(const ASTNode& expression) { Array<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size()); auto& context = expression.m_symbol_table->context(); for (size_t i = 0; i < context_list.size(); ++i) { context_list[i] = ExecutionPolicy(ExecutionPolicy{}, {context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())}); } return context_list; } template <typename... Args> PUGS_INLINE static void convertArgs(ExecutionPolicy::Context& context, Args&&... args) { static_assert(std::is_same_v<std::tuple<std::decay_t<InputType>...>, std::tuple<std::decay_t<Args>...>>, "unexpected input type"); _convertArgs(context, 0, args...); } [[nodiscard]] PUGS_INLINE static std::function<OutputType(DataVariant&& result)> getResultConverter(const ASTNodeDataType& data_type) { if constexpr (is_tiny_vector_v<OutputType>) { switch (data_type) { case ASTNodeDataType::list_t: { return [](DataVariant&& result) -> OutputType { AggregateDataVariant& v = std::get<AggregateDataVariant>(result); OutputType x; for (size_t i = 0; i < x.dimension(); ++i) { std::visit( [&](auto&& vi) { using Vi_T = std::decay_t<decltype(vi)>; if constexpr (std::is_arithmetic_v<Vi_T>) { x[i] = vi; } else { // LCOV_EXCL_START throw UnexpectedError("expecting arithmetic value"); // LCOV_EXCL_STOP } }, v[i]); } return x; }; } case ASTNodeDataType::vector_t: { return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; } case ASTNodeDataType::bool_t: { if constexpr (std::is_same_v<OutputType, TinyVector<1>>) { return [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; }; } else { // LCOV_EXCL_START throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); // LCOV_EXCL_STOP } } case ASTNodeDataType::unsigned_int_t: { if constexpr (std::is_same_v<OutputType, TinyVector<1>>) { return [](DataVariant&& result) -> OutputType { return OutputType(static_cast<double>(std::get<uint64_t>(result))); }; } else { // LCOV_EXCL_START throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); // LCOV_EXCL_STOP } } case ASTNodeDataType::int_t: { if constexpr (std::is_same_v<OutputType, TinyVector<1>>) { return [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<int64_t>(result))}; }; } else { // If this point is reached must be a 0 vector return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; }; } } case ASTNodeDataType::double_t: { if constexpr (std::is_same_v<OutputType, TinyVector<1>>) { return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; } else { // LCOV_EXCL_START throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); // LCOV_EXCL_STOP } } // LCOV_EXCL_START default: { throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); } // LCOV_EXCL_STOP } } else if constexpr (std::is_arithmetic_v<OutputType>) { switch (data_type) { case ASTNodeDataType::bool_t: { return [](DataVariant&& result) -> OutputType { return std::get<bool>(result); }; } case ASTNodeDataType::unsigned_int_t: { return [](DataVariant&& result) -> OutputType { return std::get<uint64_t>(result); }; } case ASTNodeDataType::int_t: { return [](DataVariant&& result) -> OutputType { return std::get<int64_t>(result); }; } case ASTNodeDataType::double_t: { return [](DataVariant&& result) -> OutputType { return std::get<double>(result); }; } // LCOV_EXCL_START default: { throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" + dataTypeName(ast_node_data_type_from<OutputType>) + "\""); } // LCOV_EXCL_STOP } } else { static_assert(std::is_arithmetic_v<OutputType>, "unexpected output type"); } } PugsFunctionAdapter() = delete; }; #endif // PUGS_FUNCTION_ADAPTER_HPP