Skip to content
Snippets Groups Projects
Select Git revision
  • b7bf9f92c61c313c6b29936339eba058bfebcd79
  • develop default protected
  • feature/advection
  • feature/composite-scheme-other-fluxes
  • origin/stage/bouguettaia
  • save_clemence
  • feature/local-dt-fsi
  • feature/variational-hydro
  • feature/gmsh-reader
  • feature/reconstruction
  • feature/kinetic-schemes
  • feature/composite-scheme-sources
  • feature/serraille
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • v0.5.0 protected
  • v0.4.1 protected
  • v0.4.0 protected
  • v0.3.0 protected
  • v0.2.0 protected
  • v0.1.0 protected
  • Kidder
  • v0.0.4 protected
  • v0.0.3 protected
  • v0.0.2 protected
  • v0 protected
  • v0.0.1 protected
33 results

ASTNodeFunctionExpressionBuilder.cpp

Blame
  • ASTNodeFunctionExpressionBuilder.cpp 10.95 KiB
    #include <ASTNodeFunctionExpressionBuilder.hpp>
    #include <PEGGrammar.hpp>
    
    #include <FunctionTable.hpp>
    #include <SymbolTable.hpp>
    
    #include <ASTNodeDataTypeFlattener.hpp>
    #include <ASTNodeNaturalConversionChecker.hpp>
    
    #include <node_processor/FunctionProcessor.hpp>
    
    // Returns the converter that turns one flattened call argument into the
    // value stored for the matching function parameter.
    //
    // parameter_symbol:   symbol-table entry of the parameter; its value holds the
    //                     parameter id (a size_t) given to the converter.
    // node_sub_data_type: data type and AST node of the provided argument.
    //
    // The result is a FunctionArgumentConverter<ParameterT, ArgumentT>, with
    // ParameterT chosen from the parameter's declared data type and ArgumentT from
    // the argument's data type. A string argument is accepted only by a string
    // parameter; any unsupported combination throws parse_error.
    template <typename SymbolType>
    std::unique_ptr<IFunctionArgumentConverter>
    ASTNodeFunctionExpressionBuilder::_getArgumentConverter(SymbolType& parameter_symbol,
                                                            const ASTNodeSubDataType& node_sub_data_type)
    {
      const size_t parameter_id = std::get<size_t>(parameter_symbol.attributes().value());

      // Temporary built only for its check of the argument/parameter type pair.
      ASTNodeNaturalConversionChecker{node_sub_data_type.m_parent_node, node_sub_data_type.m_data_type,
                                      parameter_symbol.attributes().dataType()};

      // Builds the converter for a parameter of C++ type decltype(parameter_tag),
      // dispatching on the argument's data type.
      auto make_converter = [&](const auto& parameter_tag) -> std::unique_ptr<IFunctionArgumentConverter> {
        using ParameterT = std::decay_t<decltype(parameter_tag)>;
        switch (node_sub_data_type.m_data_type) {
        case ASTNodeDataType::bool_t: {
          return std::make_unique<FunctionArgumentConverter<ParameterT, bool>>(parameter_id);
        }
        case ASTNodeDataType::unsigned_int_t: {
          return std::make_unique<FunctionArgumentConverter<ParameterT, uint64_t>>(parameter_id);
        }
        case ASTNodeDataType::int_t: {
          return std::make_unique<FunctionArgumentConverter<ParameterT, int64_t>>(parameter_id);
        }
        case ASTNodeDataType::double_t: {
          return std::make_unique<FunctionArgumentConverter<ParameterT, double>>(parameter_id);
        }
        case ASTNodeDataType::string_t: {
          // Only a string parameter can receive a string argument.
          if constexpr (std::is_same_v<ParameterT, std::string>) {
            return std::make_unique<FunctionArgumentConverter<std::string, std::string>>(parameter_id);
          } else {
            throw parse_error("invalid argument type", std::vector{node_sub_data_type.m_parent_node.begin()});
          }
        }
          // LCOV_EXCL_START
        default: {
          throw parse_error("invalid argument type", std::vector{node_sub_data_type.m_parent_node.begin()});
        }
          // LCOV_EXCL_STOP
        }
      };

      // The declared parameter type selects the C++ parameter type of the converter.
      switch (parameter_symbol.attributes().dataType()) {
      case ASTNodeDataType::bool_t:
        return make_converter(bool{});
      case ASTNodeDataType::unsigned_int_t:
        return make_converter(uint64_t{});
      case ASTNodeDataType::int_t:
        return make_converter(int64_t{});
      case ASTNodeDataType::double_t:
        return make_converter(double{});
      case ASTNodeDataType::string_t:
        return make_converter(std::string{});
        // LCOV_EXCL_START
      default:
        throw parse_error("unexpected error: undefined parameter type", std::vector{m_node.begin()});
        // LCOV_EXCL_STOP
      }
    }
    
    // Resolves the symbol of one parameter name and appends to function_processor
    // the converter for the argument described by node_sub_data_type.
    void
    ASTNodeFunctionExpressionBuilder::_storeArgumentConverter(ASTNode& parameter_variable,
                                                              ASTNodeSubDataType& node_sub_data_type,
                                                              FunctionProcessor& function_processor)
    {
      // Each parameter handled here must be a single name node.
      Assert(parameter_variable.is_type<language::name>(), "unexpected parameter type!");

      auto [i_symbol, found] =
        parameter_variable.m_symbol_table->find(parameter_variable.string(), parameter_variable.begin());
      Assert(found);

      std::unique_ptr<IFunctionArgumentConverter> argument_converter =
        this->_getArgumentConverter(*i_symbol, node_sub_data_type);
      function_processor.addArgumentConverter(std::move(argument_converter));
    }
    
    // Creates the FunctionProcessor for the call `node` and registers one argument
    // converter per parameter of the called function.
    //
    // The call's arguments (node.children[1]) are flattened into a list of data
    // types whose length must match the number of declared parameters: the
    // children of a name_list, or exactly one for a single name. A mismatch throws
    // parse_error.
    std::unique_ptr<FunctionProcessor>
    ASTNodeFunctionExpressionBuilder::_buildArgumentConverter(FunctionDescriptor& function_descriptor, ASTNode& node)
    {
      const ASTNode& definition_node = function_descriptor.definitionNode();
      ASTNode& parameter_variables   = *definition_node.children[0];
      ASTNode& function_expression   = *definition_node.children[1];

      Assert(function_expression.m_symbol_table->hasContext());
      const SymbolTable::Context& context = function_expression.m_symbol_table->context();

      ASTNode& argument_nodes = *node.children[1];

      std::unique_ptr function_processor = std::make_unique<FunctionProcessor>(argument_nodes, context);

      ASTNodeDataTypeFlattener::FlattenedDataTypeList flattened_datatype_list;
      ASTNodeDataTypeFlattener{argument_nodes, flattened_datatype_list};

      const size_t arguments_number = flattened_datatype_list.size();

      // Dispatch on the same test that defines the parameter count: a name_list
      // is always walked child by child (whatever its size, even 0 or 1), and a
      // lone name is always handled directly. Branching on the argument count
      // instead would pass a one-element name_list as if it were a name, and
      // would index an empty argument list for an empty name_list.
      const bool has_parameter_list  = parameter_variables.is_type<language::name_list>();
      const size_t parameters_number = has_parameter_list ? parameter_variables.children.size() : 1;

      if (arguments_number != parameters_number) {
        std::ostringstream error_message;
        error_message << "bad number of arguments: expecting " << rang::fgB::yellow << parameters_number
                      << rang::style::reset << ", provided " << rang::fgB::yellow << arguments_number << rang::style::reset;
        throw parse_error(error_message.str(), argument_nodes.begin());
      }

      if (has_parameter_list) {
        for (size_t i = 0; i < parameters_number; ++i) {
          ASTNode& parameter_variable = *parameter_variables.children[i];
          this->_storeArgumentConverter(parameter_variable, flattened_datatype_list[i], *function_processor);
        }
      } else {
        this->_storeArgumentConverter(parameter_variables, flattened_datatype_list[0], *function_processor);
      }

      return function_processor;
    }
    
    std::unique_ptr<INodeProcessor>
    ASTNodeFunctionExpressionBuilder::_getFunctionProcessor(const ASTNodeDataType expression_value_type,
                                                            const ASTNodeDataType return_value_type,
                                                            ASTNode& node,
                                                            ASTNode& function_component_expression)
    {
      auto get_function_processor_for_expression_value = [&](const auto& return_v) -> std::unique_ptr<INodeProcessor> {
        using ReturnT = std::decay_t<decltype(return_v)>;
        switch (expression_value_type) {
        case ASTNodeDataType::bool_t: {
          return std::make_unique<FunctionExpressionProcessor<ReturnT, bool>>(function_component_expression);
        }
        case ASTNodeDataType::unsigned_int_t: {
          return std::make_unique<FunctionExpressionProcessor<ReturnT, uint64_t>>(function_component_expression);
        }
        case ASTNodeDataType::int_t: {
          return std::make_unique<FunctionExpressionProcessor<ReturnT, int64_t>>(function_component_expression);
        }
        case ASTNodeDataType::double_t: {
          return std::make_unique<FunctionExpressionProcessor<ReturnT, double>>(function_component_expression);
        }
        case ASTNodeDataType::string_t: {
          if constexpr (std::is_same_v<ReturnT, std::string>) {
            return std::make_unique<FunctionExpressionProcessor<ReturnT, std::string>>(function_component_expression);
          } else {
            throw parse_error("invalid string conversion", std::vector{node.children[1]->begin()});
          }
        }
          // LCOV_EXCL_START
        default: {
          throw parse_error("unexpected error: undefined expression value type for function",
                            std::vector{node.children[1]->begin()});
        }
          // LCOV_EXCL_STOP
        }
      };
    
      auto get_function_processor_for_value = [&]() {
        switch (return_value_type) {
        case ASTNodeDataType::bool_t: {
          return get_function_processor_for_expression_value(bool{});
        }
        case ASTNodeDataType::unsigned_int_t: {
          return get_function_processor_for_expression_value(uint64_t{});
        }
        case ASTNodeDataType::int_t: {
          return get_function_processor_for_expression_value(int64_t{});
        }
        case ASTNodeDataType::double_t: {
          return get_function_processor_for_expression_value(double{});
        }
        case ASTNodeDataType::string_t: {
          return get_function_processor_for_expression_value(std::string{});
        }
          // LCOV_EXCL_START
        default: {
          throw parse_error("unexpected error: undefined return type for function", std::vector{node.begin()});
        }
          // LCOV_EXCL_STOP
        }
      };
    
      return get_function_processor_for_value();
    }
    
    // Builds the processor of a function call expression. `node` is the call
    // node: children[0] names the called function and children[1] holds its
    // arguments. The function is looked up in the symbol table, one argument
    // converter is registered per parameter, then one expression processor is
    // added per component of the function's image. The resulting processor is
    // stored in node.m_node_processor.
    ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node) : m_node(node)
    {
      auto [i_function_symbol, found] = node.m_symbol_table->find(node.children[0]->string(), node.begin());
      Assert(found);
      Assert(i_function_symbol->attributes().dataType() == ASTNodeDataType::function_t);

      const uint64_t function_id = std::get<uint64_t>(i_function_symbol->attributes().value());

      FunctionDescriptor& function_descriptor = node.m_symbol_table->functionTable()[function_id];

      std::unique_ptr function_processor = this->_buildArgumentConverter(function_descriptor, node);

      // Adds to function_processor the processor of one image component:
      // expression_node is the component's expression and image_domain_node the
      // set it is declared to return (B_set, Z_set, N_set, R_set or string_type).
      auto add_component_expression = [&](ASTNode& expression_node, ASTNode& image_domain_node) {
        const ASTNodeDataType expression_value_type = expression_node.m_data_type;
        ASTNodeDataType return_value_type           = ASTNodeDataType::undefined_t;

        if (image_domain_node.is_type<language::B_set>()) {
          return_value_type = ASTNodeDataType::bool_t;
        } else if (image_domain_node.is_type<language::Z_set>()) {
          return_value_type = ASTNodeDataType::int_t;
        } else if (image_domain_node.is_type<language::N_set>()) {
          return_value_type = ASTNodeDataType::unsigned_int_t;
        } else if (image_domain_node.is_type<language::R_set>()) {
          return_value_type = ASTNodeDataType::double_t;
        } else if (image_domain_node.is_type<language::string_type>()) {
          return_value_type = ASTNodeDataType::string_t;
        }

        Assert(return_value_type != ASTNodeDataType::undefined_t);

        function_processor->addFunctionExpressionProcessor(
          this->_getFunctionProcessor(expression_value_type, return_value_type, node, expression_node));
      };

      ASTNode& function_image_domain = *function_descriptor.domainMappingNode().children[1];
      ASTNode& function_expression   = *function_descriptor.definitionNode().children[1];

      if (function_expression.is_type<language::expression_list>()) {
        // Tuple-valued function: the i-th expression component is paired with the
        // i-th type of the image domain, so both lists must have the same length
        // (otherwise the indexing below would run past the image types).
        Assert(function_image_domain.is_type<language::type_expression>());
        Assert(function_image_domain.children.size() == function_expression.children.size(),
               "number of image types does not match number of expression components");

        for (size_t i = 0; i < function_expression.children.size(); ++i) {
          add_component_expression(*function_expression.children[i], *function_image_domain.children[i]);
        }
      } else {
        add_component_expression(function_expression, function_image_domain);
      }

      node.m_node_processor = std::move(function_processor);
    }