Skip to content
Snippets Groups Projects
Select Git revision
  • 4709617039d002f4c5b96763e744662177926b55
  • develop default protected
  • feature/variational-hydro
  • origin/stage/bouguettaia
  • feature/gmsh-reader
  • feature/reconstruction
  • save_clemence
  • feature/kinetic-schemes
  • feature/local-dt-fsi
  • feature/composite-scheme-sources
  • feature/composite-scheme-other-fluxes
  • feature/serraille
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • master protected
  • v0.5.0 protected
  • v0.4.1 protected
  • v0.4.0 protected
  • v0.3.0 protected
  • v0.2.0 protected
  • v0.1.0 protected
  • Kidder
  • v0.0.4 protected
  • v0.0.3 protected
  • v0.0.2 protected
  • v0 protected
  • v0.0.1 protected
33 results

FunctionProcessor.hpp

Blame
  • Stéphane Del Pino's avatar
    Stéphane Del Pino authored
    It avoids some spurious automatic conversion (especially from
    arithmetic types to TinyVector<1> or TinyMatrix<1>) and could reduce
    stress on the compiler.
    47096170
    History
    FunctionProcessor.hpp 5.15 KiB
    #ifndef FUNCTION_PROCESSOR_HPP
    #define FUNCTION_PROCESSOR_HPP
    
    #include <language/PEGGrammar.hpp>
    #include <language/node_processor/ASTNodeExpressionListProcessor.hpp>
    #include <language/node_processor/FunctionArgumentConverter.hpp>
    #include <language/node_processor/INodeProcessor.hpp>
    #include <language/utils/FunctionTable.hpp>
    #include <language/utils/SymbolTable.hpp>
    
    /**
     * Node processor that evaluates one expression of a function body and
     * converts the produced value to the requested return type.
     *
     * @tparam ReturnType           type the value is converted to
     * @tparam ExpressionValueType  alternative extracted (through std::get) from
     *                              the DataVariant produced by the expression node
     *
     * The conversion is selected at compile time from the pair
     * (ReturnType, ExpressionValueType); see execute() for the supported cases.
     */
    template <typename ReturnType, typename ExpressionValueType>
    class FunctionExpressionProcessor final : public INodeProcessor
    {
     private:
      // Expression node evaluated at each call (not owned)
      ASTNode& m_function_expression;

     public:
      /**
       * Evaluates the expression node and converts its value to ReturnType.
       *
       * @param exec_policy execution policy forwarded to the expression node
       * @return the converted value
       * @throws UnexpectedError if no conversion exists for the type pair, or
       *         if a component of an aggregate is not convertible to double
       *
       * NOTE(review): std::get throws std::bad_variant_access if the node does
       * not produce an ExpressionValueType; this is presumably prevented
       * upstream when the processor is built -- confirm.
       */
      DataVariant
      execute(ExecutionPolicy& exec_policy) override
      {
        if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) {
          // Same type: the node result is forwarded untouched
          return m_function_expression.execute(exec_policy);
        } else if constexpr (std::is_same_v<AggregateDataVariant, ExpressionValueType>) {
          // A list of values is used to fill a small vector or a small matrix
          static_assert(is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>, "unexpected return type");
          ReturnType return_value{};
          auto value = std::get<ExpressionValueType>(m_function_expression.execute(exec_policy));
          if constexpr (is_tiny_vector_v<ReturnType>) {
            for (size_t i = 0; i < ReturnType::Dimension; ++i) {
              std::visit(
                [&](auto&& vi) {
                  using Vi_T = std::decay_t<decltype(vi)>;
                  if constexpr (std::is_convertible_v<Vi_T, double>) {
                    return_value[i] = vi;
                  } else {
                    // A component that is not a scalar cannot fill the entry:
                    // report it instead of silently keeping the initial value
                    throw UnexpectedError("unexpected component type in aggregate conversion");
                  }
                },
                value[i]);
            }
          } else {
            static_assert(is_tiny_matrix_v<ReturnType>);

            // The list is read in row-major order: l runs over the list while (i, j) runs over the matrix
            for (size_t i = 0, l = 0; i < return_value.numberOfRows(); ++i) {
              for (size_t j = 0; j < return_value.numberOfColumns(); ++j, ++l) {
                std::visit(
                  [&](auto&& Aij) {
                    using Aij_T = std::decay_t<decltype(Aij)>;
                    if constexpr (std::is_convertible_v<Aij_T, double>) {
                      return_value(i, j) = Aij;
                    } else {
                      // Same policy as for vectors: never leave an entry silently unset
                      throw UnexpectedError("unexpected component type in aggregate conversion");
                    }
                  },
                  value[l]);
              }
            }
          }
          return return_value;
        } else if constexpr (std::is_same_v<ReturnType, std::string>) {
          // Conversion to string relies on std::to_string
          return std::to_string(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
        } else if constexpr (std::is_same_v<ExpressionValueType, ZeroType>) {
          // The zero value does not depend on the expression: the node is not evaluated
          return ReturnType{ZeroType::zero};
        } else if constexpr (std::is_convertible_v<ExpressionValueType, ReturnType>) {
          return static_cast<ReturnType>(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
        } else if constexpr (std::is_arithmetic_v<ExpressionValueType> and
                             (is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>)) {
          // Arithmetic values are only accepted for return types whose Dimension is 1; explicit construction is used
          static_assert(ReturnType::Dimension == 1, "invalid conversion");
          return ReturnType(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
        } else {
          throw UnexpectedError("invalid conversion");
        }
      }

      /**
       * @param function_component_expression expression node to evaluate; only
       *        a reference is stored, so the node must outlive this processor
       */
      FunctionExpressionProcessor(ASTNode& function_component_expression)
        : m_function_expression{function_component_expression}
      {}
    };
    
    /**
     * Node processor performing a function call.
     *
     * At each execution it creates a fresh context, evaluates the argument
     * node, hands the resulting value(s) to the registered argument converters
     * and then runs the registered function expression processors.
     */
    class FunctionProcessor : public INodeProcessor
    {
     private:
      // Node whose evaluation gives the argument value(s) of the call (not owned)
      ASTNode& m_argument_node;

      // Size and id used to create the context of each call
      const size_t m_context_size;
      const int32_t m_context_id;

      // One converter per argument, in registration order
      std::vector<std::unique_ptr<IFunctionArgumentConverter>> m_argument_converters;
      // One processor per function expression, in registration order
      std::vector<std::unique_ptr<INodeProcessor>> m_function_expression_processors;

     public:
      /**
       * Appends a converter; converters are applied to the arguments in the
       * order they were added.
       *
       * @param argument_converter converter whose ownership is taken
       */
      void
      addArgumentConverter(std::unique_ptr<IFunctionArgumentConverter>&& argument_converter)
      {
        m_argument_converters.emplace_back(std::move(argument_converter));
      }

      /**
       * Appends a processor evaluating one expression of the function;
       * processors are executed in the order they were added.
       *
       * @param function_processor processor whose ownership is taken
       */
      void
      addFunctionExpressionProcessor(std::unique_ptr<INodeProcessor>&& function_processor)
      {
        m_function_expression_processors.emplace_back(std::move(function_processor));
      }

      /**
       * Executes the function call.
       *
       * @param exec_policy execution policy of the caller; a policy built from
       *        it and from the new context is used for the whole call
       * @return the value of the unique function expression processor when
       *         exactly one is registered, otherwise a DataVariant built from
       *         the list of all the processors' values
       */
      DataVariant
      execute(ExecutionPolicy& exec_policy) override
      {
        // The context is built at each execution for thread safety: multiple threads can call a function at once
        ExecutionPolicy::Context context{m_context_id, std::make_shared<ExecutionPolicy::Context::Values>(m_context_size)};

        ExecutionPolicy context_exec_policy{exec_policy, context};

        if (m_argument_converters.size() == 1) {
          // Single argument: the value of the argument node is converted directly
          m_argument_converters[0]->convert(context_exec_policy, m_argument_node.execute(context_exec_policy));
        } else {
          // Otherwise the argument node is expected to give an aggregate whose
          // i-th component is handed to the i-th converter
          AggregateDataVariant argument_values{
            std::get<AggregateDataVariant>(m_argument_node.execute(context_exec_policy))};

          for (size_t i = 0; i < m_argument_converters.size(); ++i) {
            m_argument_converters[i]->convert(context_exec_policy, std::move(argument_values[i]));
          }
        }

        if (m_function_expression_processors.size() == 1) {
          return m_function_expression_processors[0]->execute(context_exec_policy);
        }

        // Several (or no) expressions: their values are gathered in registration order
        std::vector<DataVariant> list_values;
        list_values.reserve(m_function_expression_processors.size());

        for (auto& function_expression_processor : m_function_expression_processors) {
          list_values.emplace_back(function_expression_processor->execute(context_exec_policy));
        }
        return DataVariant{std::move(list_values)};
      }

      /**
       * @param argument_node node evaluated to get the call arguments; only a
       *        reference is stored, so the node must outlive this processor
       * @param context symbol table context whose size and id are kept to
       *        create the context of each call
       */
      FunctionProcessor(ASTNode& argument_node, SymbolTable::Context context)
        : m_argument_node{argument_node}, m_context_size{context.size()}, m_context_id{context.id()}
      {}
    };
    
    #endif   // FUNCTION_PROCESSOR_HPP