Skip to content
Snippets Groups Projects
Select Git revision
  • beb384d876853caa5c77c10d3648254ff10e7135
  • develop default protected
  • feature/variational-hydro
  • origin/stage/bouguettaia
  • feature/gmsh-reader
  • feature/reconstruction
  • save_clemence
  • feature/kinetic-schemes
  • feature/local-dt-fsi
  • feature/composite-scheme-sources
  • feature/composite-scheme-other-fluxes
  • feature/serraille
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • master protected
  • v0.5.0 protected
  • v0.4.1 protected
  • v0.4.0 protected
  • v0.3.0 protected
  • v0.2.0 protected
  • v0.1.0 protected
  • Kidder
  • v0.0.4 protected
  • v0.0.3 protected
  • v0.0.2 protected
  • v0 protected
  • v0.0.1 protected
33 results

SchemeModule.cpp

Blame
  • PugsFunctionAdapter.hpp 11.54 KiB
    #ifndef PUGS_FUNCTION_ADAPTER_HPP
    #define PUGS_FUNCTION_ADAPTER_HPP
    
    #include <language/ast/ASTNode.hpp>
    #include <language/node_processor/ExecutionPolicy.hpp>
    #include <language/utils/ASTNodeDataType.hpp>
    #include <language/utils/ASTNodeDataTypeTraits.hpp>
    #include <language/utils/SymbolTable.hpp>
    #include <utils/Exceptions.hpp>
    #include <utils/PugsMacros.hpp>
    #include <utils/SmallArray.hpp>

    #include <Kokkos_Core.hpp>

    #include <array>
    #include <cstddef>
    #include <functional>
    #include <memory>
    #include <sstream>
    #include <string>
    #include <tuple>
    #include <type_traits>
    #include <utility>
    
    template <typename T>
    class PugsFunctionAdapter;
    template <typename OutputType, typename... InputType>
    class PugsFunctionAdapter<OutputType(InputType...)>
    {
     protected:
      using InputTuple              = std::tuple<std::decay_t<InputType>...>;
      constexpr static size_t NArgs = std::tuple_size_v<InputTuple>;
    
     private:
      template <typename T, typename... Args>
      PUGS_INLINE static void
      _convertArgs(ExecutionPolicy::Context& context, size_t i_context, const T& t, Args&&... args)
      {
        context[i_context++] = t;
        if constexpr (sizeof...(Args) > 0) {
          _convertArgs(context, i_context, std::forward<Args>(args)...);
        }
      }
    
      template <size_t I>
      [[nodiscard]] PUGS_INLINE static bool
      _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT)
      {
        using Arg = std::tuple_element_t<I, InputTuple>;
    
        constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>;
    
        Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t);
        const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType();
    
        return isNaturalConversion(expected_input_data_type, arg_data_type);
      }
    
      template <size_t... I>
      [[nodiscard]] PUGS_INLINE static bool
      _checkAllInputDataType(const ASTNode& input_expression, std::index_sequence<I...>)
      {
        Assert(NArgs == input_expression.children.size());
        return (_checkValidArgumentDataType<I>(*input_expression.children[I]) and ...);
      }
    
      [[nodiscard]] PUGS_INLINE static bool
      _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept
      {
        if constexpr (NArgs == 1) {
          return _checkValidArgumentDataType<0>(input_domain_expression);
        } else {
          if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or
              (input_domain_expression.children.size() != NArgs)) {
            return false;
          }
    
          using IndexSequence = std::make_index_sequence<NArgs>;
          return _checkAllInputDataType(input_domain_expression, IndexSequence{});
        }
      }
    
      [[nodiscard]] PUGS_INLINE static bool
      _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT)
      {
        constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>;
        const ASTNodeDataType& return_data_type                    = output_domain_expression.m_data_type.contentType();
    
        return isNaturalConversion(return_data_type, expected_return_data_type);
      }
    
      template <typename Arg, typename... RemainingArgs>
      [[nodiscard]] PUGS_INLINE static std::string
      _getCompoundTypeName()
      {
        if constexpr (sizeof...(RemainingArgs) > 0) {
          return dataTypeName(ast_node_data_type_from<Arg>) + '*' + _getCompoundTypeName<RemainingArgs...>();
        } else {
          return dataTypeName(ast_node_data_type_from<Arg>);
        }
      }
    
      [[nodiscard]] static std::string
      _getInputDataTypeName()
      {
        return _getCompoundTypeName<InputType...>();
      }
    
      PUGS_INLINE static void
      _checkFunction(const FunctionDescriptor& function)
      {
        bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]);
        bool has_valid_output       = _checkValidOutputDomain(*function.domainMappingNode().children[1]);
    
        if (not(has_valid_input_domain and has_valid_output)) {
          std::ostringstream error_message;
          error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow
                        << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>)
                        << rang::style::reset << '\n'
                        << "note: provided function " << rang::fgB::magenta << function.name() << ": "
                        << function.domainMappingNode().string() << rang::style::reset;
          throw NormalError(error_message.str());
        }
      }
    
     protected:
      [[nodiscard]] PUGS_INLINE static auto&
      getFunctionExpression(const FunctionSymbolId& function_symbol_id)
      {
        auto& function_descriptor = function_symbol_id.descriptor();
        _checkFunction(function_descriptor);
    
        return *function_descriptor.definitionNode().children[1];
      }
    
      [[nodiscard]] PUGS_INLINE static auto
      getContextList(const ASTNode& expression)
      {
        SmallArray<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
        auto& context = expression.m_symbol_table->context();
    
        for (size_t i = 0; i < context_list.size(); ++i) {
          context_list[i] =
            ExecutionPolicy(ExecutionPolicy{},
                            {context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())});
        }
    
        return context_list;
      }
    
      template <typename... Args>
      PUGS_INLINE static void
      convertArgs(ExecutionPolicy::Context& context, Args&&... args)
      {
        static_assert(std::is_same_v<std::tuple<std::decay_t<InputType>...>, std::tuple<std::decay_t<Args>...>>,
                      "unexpected input type");
        _convertArgs(context, 0, args...);
      }
    
      [[nodiscard]] PUGS_INLINE static std::function<OutputType(DataVariant&& result)>
      getResultConverter(const ASTNodeDataType& data_type)
      {
        if constexpr (is_tiny_vector_v<OutputType>) {
          switch (data_type) {
          case ASTNodeDataType::vector_t: {
            return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
          }
          case ASTNodeDataType::bool_t: {
            if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
              return
                [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; };
            } else {
              // LCOV_EXCL_START
              throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                    dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
              // LCOV_EXCL_STOP
            }
          }
          case ASTNodeDataType::unsigned_int_t: {
            if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
              return [](DataVariant&& result) -> OutputType {
                return OutputType(static_cast<double>(std::get<uint64_t>(result)));
              };
            } else {
              // LCOV_EXCL_START
              throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                    dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
              // LCOV_EXCL_STOP
            }
          }
          case ASTNodeDataType::int_t: {
            if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
              return [](DataVariant&& result) -> OutputType {
                return OutputType{static_cast<double>(std::get<int64_t>(result))};
              };
            } else {
              // If this point is reached must be a 0 vector
              return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; };
            }
          }
          case ASTNodeDataType::double_t: {
            if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
              return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
            } else {
              // LCOV_EXCL_START
              throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                    dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
              // LCOV_EXCL_STOP
            }
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                  dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          }
            // LCOV_EXCL_STOP
          }
        } else if constexpr (is_tiny_matrix_v<OutputType>) {
          switch (data_type) {
          case ASTNodeDataType::matrix_t: {
            return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
          }
          case ASTNodeDataType::bool_t: {
            if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
              return
                [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; };
            } else {
              // LCOV_EXCL_START
              throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                    dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
              // LCOV_EXCL_STOP
            }
          }
          case ASTNodeDataType::unsigned_int_t: {
            if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
              return [](DataVariant&& result) -> OutputType {
                return OutputType(static_cast<double>(std::get<uint64_t>(result)));
              };
            } else {
              // LCOV_EXCL_START
              throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                    dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
              // LCOV_EXCL_STOP
            }
          }
          case ASTNodeDataType::int_t: {
            if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
              return [](DataVariant&& result) -> OutputType {
                return OutputType{static_cast<double>(std::get<int64_t>(result))};
              };
            } else {
              // If this point is reached must be a 0 matrix
              return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; };
            }
          }
          case ASTNodeDataType::double_t: {
            if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
              return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
            } else {
              // LCOV_EXCL_START
              throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                    dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
              // LCOV_EXCL_STOP
            }
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                  dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          }
            // LCOV_EXCL_STOP
          }
        } else if constexpr (std::is_arithmetic_v<OutputType>) {
          switch (data_type) {
          case ASTNodeDataType::bool_t: {
            return [](DataVariant&& result) -> OutputType { return std::get<bool>(result); };
          }
          case ASTNodeDataType::unsigned_int_t: {
            return [](DataVariant&& result) -> OutputType { return std::get<uint64_t>(result); };
          }
          case ASTNodeDataType::int_t: {
            return [](DataVariant&& result) -> OutputType { return std::get<int64_t>(result); };
          }
          case ASTNodeDataType::double_t: {
            return [](DataVariant&& result) -> OutputType { return std::get<double>(result); };
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                  dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          }
            // LCOV_EXCL_STOP
          }
        } else {
          static_assert(std::is_arithmetic_v<OutputType>, "unexpected output type");
        }
      }
    
      PugsFunctionAdapter() = delete;
    };
    
    #endif   // PUGS_FUNCTION_ADAPTER_HPP