Select Git revision
BuildInfo.cpp
PugsFunctionAdapter.hpp 11.54 KiB
#ifndef PUGS_FUNCTION_ADAPTER_HPP
#define PUGS_FUNCTION_ADAPTER_HPP
#include <language/ast/ASTNode.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <language/utils/ASTNodeDataType.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsMacros.hpp>
#include <utils/SmallArray.hpp>
#include <Kokkos_Core.hpp>
#include <array>
#include <cstddef>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
template <typename T>
class PugsFunctionAdapter;
template <typename OutputType, typename... InputType>
class PugsFunctionAdapter<OutputType(InputType...)>
{
protected:
using InputTuple = std::tuple<std::decay_t<InputType>...>;
constexpr static size_t NArgs = std::tuple_size_v<InputTuple>;
private:
template <typename T, typename... Args>
PUGS_INLINE static void
_convertArgs(ExecutionPolicy::Context& context, size_t i_context, const T& t, Args&&... args)
{
context[i_context++] = t;
if constexpr (sizeof...(Args) > 0) {
_convertArgs(context, i_context, std::forward<Args>(args)...);
}
}
template <size_t I>
[[nodiscard]] PUGS_INLINE static bool
_checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT)
{
using Arg = std::tuple_element_t<I, InputTuple>;
constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>;
Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t);
const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType();
return isNaturalConversion(expected_input_data_type, arg_data_type);
}
template <size_t... I>
[[nodiscard]] PUGS_INLINE static bool
_checkAllInputDataType(const ASTNode& input_expression, std::index_sequence<I...>)
{
Assert(NArgs == input_expression.children.size());
return (_checkValidArgumentDataType<I>(*input_expression.children[I]) and ...);
}
[[nodiscard]] PUGS_INLINE static bool
_checkValidInputDomain(const ASTNode& input_domain_expression) noexcept
{
if constexpr (NArgs == 1) {
return _checkValidArgumentDataType<0>(input_domain_expression);
} else {
if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or
(input_domain_expression.children.size() != NArgs)) {
return false;
}
using IndexSequence = std::make_index_sequence<NArgs>;
return _checkAllInputDataType(input_domain_expression, IndexSequence{});
}
}
[[nodiscard]] PUGS_INLINE static bool
_checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT)
{
constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>;
const ASTNodeDataType& return_data_type = output_domain_expression.m_data_type.contentType();
return isNaturalConversion(return_data_type, expected_return_data_type);
}
template <typename Arg, typename... RemainingArgs>
[[nodiscard]] PUGS_INLINE static std::string
_getCompoundTypeName()
{
if constexpr (sizeof...(RemainingArgs) > 0) {
return dataTypeName(ast_node_data_type_from<Arg>) + '*' + _getCompoundTypeName<RemainingArgs...>();
} else {
return dataTypeName(ast_node_data_type_from<Arg>);
}
}
[[nodiscard]] static std::string
_getInputDataTypeName()
{
return _getCompoundTypeName<InputType...>();
}
PUGS_INLINE static void
_checkFunction(const FunctionDescriptor& function)
{
bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]);
bool has_valid_output = _checkValidOutputDomain(*function.domainMappingNode().children[1]);
if (not(has_valid_input_domain and has_valid_output)) {
std::ostringstream error_message;
error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow
<< _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>)
<< rang::style::reset << '\n'
<< "note: provided function " << rang::fgB::magenta << function.name() << ": "
<< function.domainMappingNode().string() << rang::style::reset;
throw NormalError(error_message.str());
}
}
protected:
[[nodiscard]] PUGS_INLINE static auto&
getFunctionExpression(const FunctionSymbolId& function_symbol_id)
{
auto& function_descriptor = function_symbol_id.descriptor();
_checkFunction(function_descriptor);
return *function_descriptor.definitionNode().children[1];
}
[[nodiscard]] PUGS_INLINE static auto
getContextList(const ASTNode& expression)
{
SmallArray<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
auto& context = expression.m_symbol_table->context();
for (size_t i = 0; i < context_list.size(); ++i) {
context_list[i] =
ExecutionPolicy(ExecutionPolicy{},
{context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())});
}
return context_list;
}
template <typename... Args>
PUGS_INLINE static void
convertArgs(ExecutionPolicy::Context& context, Args&&... args)
{
static_assert(std::is_same_v<std::tuple<std::decay_t<InputType>...>, std::tuple<std::decay_t<Args>...>>,
"unexpected input type");
_convertArgs(context, 0, args...);
}
[[nodiscard]] PUGS_INLINE static std::function<OutputType(DataVariant&& result)>
getResultConverter(const ASTNodeDataType& data_type)
{
if constexpr (is_tiny_vector_v<OutputType>) {
switch (data_type) {
case ASTNodeDataType::vector_t: {
return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
}
case ASTNodeDataType::bool_t: {
if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
return
[](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; };
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::unsigned_int_t: {
if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
return [](DataVariant&& result) -> OutputType {
return OutputType(static_cast<double>(std::get<uint64_t>(result)));
};
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::int_t: {
if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
return [](DataVariant&& result) -> OutputType {
return OutputType{static_cast<double>(std::get<int64_t>(result))};
};
} else {
// If this point is reached must be a 0 vector
return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; };
}
}
case ASTNodeDataType::double_t: {
if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
// LCOV_EXCL_STOP
}
}
// LCOV_EXCL_START
default: {
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
}
// LCOV_EXCL_STOP
}
} else if constexpr (is_tiny_matrix_v<OutputType>) {
switch (data_type) {
case ASTNodeDataType::matrix_t: {
return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
}
case ASTNodeDataType::bool_t: {
if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
return
[](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; };
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::unsigned_int_t: {
if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
return [](DataVariant&& result) -> OutputType {
return OutputType(static_cast<double>(std::get<uint64_t>(result)));
};
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
// LCOV_EXCL_STOP
}
}
case ASTNodeDataType::int_t: {
if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
return [](DataVariant&& result) -> OutputType {
return OutputType{static_cast<double>(std::get<int64_t>(result))};
};
} else {
// If this point is reached must be a 0 matrix
return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; };
}
}
case ASTNodeDataType::double_t: {
if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
// LCOV_EXCL_STOP
}
}
// LCOV_EXCL_START
default: {
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
}
// LCOV_EXCL_STOP
}
} else if constexpr (std::is_arithmetic_v<OutputType>) {
switch (data_type) {
case ASTNodeDataType::bool_t: {
return [](DataVariant&& result) -> OutputType { return std::get<bool>(result); };
}
case ASTNodeDataType::unsigned_int_t: {
return [](DataVariant&& result) -> OutputType { return std::get<uint64_t>(result); };
}
case ASTNodeDataType::int_t: {
return [](DataVariant&& result) -> OutputType { return std::get<int64_t>(result); };
}
case ASTNodeDataType::double_t: {
return [](DataVariant&& result) -> OutputType { return std::get<double>(result); };
}
// LCOV_EXCL_START
default: {
throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
}
// LCOV_EXCL_STOP
}
} else {
static_assert(std::is_arithmetic_v<OutputType>, "unexpected output type");
}
}
PugsFunctionAdapter() = delete;
};
#endif // PUGS_FUNCTION_ADAPTER_HPP