Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • develop
  • feature/local-dt-fsi
  • feature/kinetic-schemes
  • origin/stage/bouguettaia
  • feature/variational-hydro
  • save_clemence
  • feature/reconstruction
  • feature/composite-scheme-sources
  • feature/composite-scheme-other-fluxes
  • feature/serraille
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • master
  • feature/escobar-smoother
  • feature/hypoelasticity-clean
  • feature/hypoelasticity
  • feature/Navier-Stokes
  • feature/Nodal_diffusion
  • feature/explicit-gp-cfl
  • Nodal_diffusion
  • feature/discontinuous-galerkin
  • test/voronoi1d
  • navier-stokes
  • Kidder
  • v0
  • v0.0.1
  • v0.0.2
  • v0.0.3
  • v0.0.4
  • v0.1.0
  • v0.2.0
  • v0.3.0
  • v0.4.0
  • v0.4.1
  • v0.5.0
41 results

Target

Select target project
  • code / pugs
1 result
Select Git revision
  • develop
  • feature/local-dt-fsi
  • feature/kinetic-schemes
  • origin/stage/bouguettaia
  • feature/variational-hydro
  • save_clemence
  • feature/reconstruction
  • feature/composite-scheme-sources
  • feature/composite-scheme-other-fluxes
  • feature/serraille
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • master
  • feature/escobar-smoother
  • feature/hypoelasticity-clean
  • feature/hypoelasticity
  • feature/Navier-Stokes
  • feature/Nodal_diffusion
  • feature/explicit-gp-cfl
  • Nodal_diffusion
  • feature/discontinuous-galerkin
  • test/voronoi1d
  • navier-stokes
  • Kidder
  • v0
  • v0.0.1
  • v0.0.2
  • v0.0.3
  • v0.0.4
  • v0.1.0
  • v0.2.0
  • v0.3.0
  • v0.4.0
  • v0.4.1
  • v0.5.0
41 results
Show changes

Commits on Source 2

19 files
+ 235
645
Compare changes
  • Side-by-side
  • Inline

Files

+8 −5
Original line number Original line Diff line number Diff line
@@ -4,7 +4,6 @@
#include <language/ast/ASTNodeDataTypeFlattener.hpp>
#include <language/ast/ASTNodeDataTypeFlattener.hpp>
#include <language/node_processor/BuiltinFunctionProcessor.hpp>
#include <language/node_processor/BuiltinFunctionProcessor.hpp>
#include <language/utils/ASTNodeNaturalConversionChecker.hpp>
#include <language/utils/ASTNodeNaturalConversionChecker.hpp>
#include <language/utils/BuiltinFunctionEmbedderUtils.hpp>
#include <language/utils/ParseError.hpp>
#include <language/utils/ParseError.hpp>
#include <language/utils/SymbolTable.hpp>
#include <language/utils/SymbolTable.hpp>


@@ -495,8 +494,7 @@ ASTNodeBuiltinFunctionExpressionBuilder::_buildArgumentProcessors(
  if (arguments_number != parameters_number) {
  if (arguments_number != parameters_number) {
    std::ostringstream error_message;
    std::ostringstream error_message;
    error_message << "bad number of arguments: expecting " << rang::fgB::yellow << parameters_number
    error_message << "bad number of arguments: expecting " << rang::fgB::yellow << parameters_number
                  << rang::style::reset << rang::style::bold << ", provided " << rang::fgB::yellow << arguments_number
                  << rang::style::reset << ", provided " << rang::fgB::yellow << arguments_number << rang::style::reset;
                  << rang::style::reset;
    throw ParseError(error_message.str(), argument_nodes.begin());
    throw ParseError(error_message.str(), argument_nodes.begin());
  }
  }


@@ -507,9 +505,14 @@ ASTNodeBuiltinFunctionExpressionBuilder::_buildArgumentProcessors(


ASTNodeBuiltinFunctionExpressionBuilder::ASTNodeBuiltinFunctionExpressionBuilder(ASTNode& node)
ASTNodeBuiltinFunctionExpressionBuilder::ASTNodeBuiltinFunctionExpressionBuilder(ASTNode& node)
{
{
  Assert(node.children[0]->m_data_type == ASTNodeDataType::builtin_function_t);
  auto [i_function_symbol, found] = node.m_symbol_table->find(node.children[0]->string(), node.begin());
  Assert(found);
  Assert(i_function_symbol->attributes().dataType() == ASTNodeDataType::builtin_function_t);


  std::shared_ptr builtin_function_embedder = getBuiltinFunctionEmbedder(node);
  uint64_t builtin_function_id = std::get<uint64_t>(i_function_symbol->attributes().value());

  auto& builtin_function_embedder_table     = node.m_symbol_table->builtinFunctionEmbedderTable();
  std::shared_ptr builtin_function_embedder = builtin_function_embedder_table[builtin_function_id];


  std::vector<ASTNodeDataType> builtin_function_parameter_type_list =
  std::vector<ASTNodeDataType> builtin_function_parameter_type_list =
    builtin_function_embedder->getParameterDataTypes();
    builtin_function_embedder->getParameterDataTypes();
+11 −9
Original line number Original line Diff line number Diff line
@@ -3,7 +3,6 @@
#include <language/PEGGrammar.hpp>
#include <language/PEGGrammar.hpp>
#include <language/utils/ASTNodeNaturalConversionChecker.hpp>
#include <language/utils/ASTNodeNaturalConversionChecker.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/BuiltinFunctionEmbedderUtils.hpp>
#include <language/utils/OperatorRepository.hpp>
#include <language/utils/OperatorRepository.hpp>
#include <language/utils/ParseError.hpp>
#include <language/utils/ParseError.hpp>
#include <language/utils/SymbolTable.hpp>
#include <language/utils/SymbolTable.hpp>
@@ -285,13 +284,8 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const
        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;
        std::shared_ptr<SymbolTable>& symbol_table = n.m_symbol_table;


        auto [i_symbol, found] = symbol_table->find(n.string(), n.begin());
        auto [i_symbol, found] = symbol_table->find(n.string(), n.begin());
        if (found) {
        Assert(found);
        n.m_data_type = i_symbol->attributes().dataType();
        n.m_data_type = i_symbol->attributes().dataType();
        } else if (symbol_table->has(n.string(), n.begin())) {
          n.m_data_type = ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>();
        } else {
          throw UnexpectedError("could not find symbol " + n.string());
        }
      }
      }
    }
    }
    for (auto& child : n.children) {
    for (auto& child : n.children) {
@@ -528,7 +522,15 @@ ASTNodeDataTypeBuilder::_buildNodeDataTypes(ASTNode& n) const


        n.m_data_type = image_domain_node.m_data_type.contentType();
        n.m_data_type = image_domain_node.m_data_type.contentType();
      } else if (n.children[0]->m_data_type == ASTNodeDataType::builtin_function_t) {
      } else if (n.children[0]->m_data_type == ASTNodeDataType::builtin_function_t) {
        auto builtin_function_embedder = getBuiltinFunctionEmbedder(n);
        const std::string builtin_function_name = n.children[0]->string();
        auto& symbol_table                      = *n.m_symbol_table;

        auto [i_symbol, success] = symbol_table.find(builtin_function_name, n.begin());
        Assert(success);

        uint64_t builtin_function_id   = std::get<uint64_t>(i_symbol->attributes().value());
        auto builtin_function_embedder = symbol_table.builtinFunctionEmbedderTable()[builtin_function_id];
        Assert(builtin_function_embedder);


        n.m_data_type = builtin_function_embedder->getReturnDataType();
        n.m_data_type = builtin_function_embedder->getReturnDataType();
      } else {
      } else {
+5 −10
Original line number Original line Diff line number Diff line
@@ -68,18 +68,13 @@ ASTNodeExpressionBuilder::_buildExpression(ASTNode& n)
    n.m_node_processor = std::make_unique<FakeProcessor>();
    n.m_node_processor = std::make_unique<FakeProcessor>();


  } else if (n.is_type<language::name>()) {
  } else if (n.is_type<language::name>()) {
    if (n.m_data_type == ASTNodeDataType::builtin_function_t) {
      n.m_node_processor = std::make_unique<FakeProcessor>();
    } else {
    // Dealing with contexts
    // Dealing with contexts
    auto [i_symbol, success] = n.m_symbol_table->find(n.string(), n.begin());
    auto [i_symbol, success] = n.m_symbol_table->find(n.string(), n.begin());
      Assert(success, "could not find symbol");
    if (i_symbol->attributes().hasLocalContext()) {
    if (i_symbol->attributes().hasLocalContext()) {
      n.m_node_processor = std::make_unique<LocalNameProcessor>(n);
      n.m_node_processor = std::make_unique<LocalNameProcessor>(n);
    } else {
    } else {
      n.m_node_processor = std::make_unique<NameProcessor>(n);
      n.m_node_processor = std::make_unique<NameProcessor>(n);
    }
    }
    }
  } else if (n.is_type<language::unary_minus>() or n.is_type<language::unary_not>()) {
  } else if (n.is_type<language::unary_minus>() or n.is_type<language::unary_not>()) {
    ASTNodeUnaryOperatorExpressionBuilder{n};
    ASTNodeUnaryOperatorExpressionBuilder{n};


+4 −1
Original line number Original line Diff line number Diff line
@@ -7,7 +7,10 @@


ASTNodeFunctionEvaluationExpressionBuilder::ASTNodeFunctionEvaluationExpressionBuilder(ASTNode& node)
ASTNodeFunctionEvaluationExpressionBuilder::ASTNodeFunctionEvaluationExpressionBuilder(ASTNode& node)
{
{
  switch (node.children[0]->m_data_type) {
  auto [i_function_symbol, found] = node.m_symbol_table->find(node.children[0]->string(), node.begin());
  Assert(found);

  switch (i_function_symbol->attributes().dataType()) {
  case ASTNodeDataType::function_t: {
  case ASTNodeDataType::function_t: {
    ASTNodeFunctionExpressionBuilder{node};
    ASTNodeFunctionExpressionBuilder{node};
    break;
    break;
+1 −2
Original line number Original line Diff line number Diff line
@@ -233,8 +233,7 @@ ASTNodeFunctionExpressionBuilder::_buildArgumentConverter(FunctionDescriptor& fu
  if (arguments_number != parameters_number) {
  if (arguments_number != parameters_number) {
    std::ostringstream error_message;
    std::ostringstream error_message;
    error_message << "bad number of arguments: expecting " << rang::fgB::yellow << parameters_number
    error_message << "bad number of arguments: expecting " << rang::fgB::yellow << parameters_number
                  << rang::style::reset << rang::style::bold << ", provided " << rang::fgB::yellow << arguments_number
                  << rang::style::reset << ", provided " << rang::fgB::yellow << arguments_number << rang::style::reset;
                  << rang::style::reset;
    throw ParseError(error_message.str(), argument_nodes.begin());
    throw ParseError(error_message.str(), argument_nodes.begin());
  }
  }


+2 −3
Original line number Original line Diff line number Diff line
@@ -107,9 +107,8 @@ ASTSymbolInitializationChecker::_checkSymbolInitialization(ASTNode& node)
    }
    }
  } else if (node.is_type<language::name>()) {
  } else if (node.is_type<language::name>()) {
    auto [i_symbol, found] = node.m_symbol_table->find(node.string(), node.begin());
    auto [i_symbol, found] = node.m_symbol_table->find(node.string(), node.begin());
    Assert(node.m_symbol_table->has(node.string(), node.begin()),
    Assert(found, "unexpected error, should have been detected through declaration checking");
           "unexpected error, should have been detected through declaration checking");
    if (not i_symbol->attributes().isInitialized()) {
    if (found and not i_symbol->attributes().isInitialized()) {
      std::ostringstream error_message;
      std::ostringstream error_message;
      error_message << "uninitialized symbol '" << rang::fg::red << node.string() << rang::fg::reset << '\'';
      error_message << "uninitialized symbol '" << rang::fg::red << node.string() << rang::fg::reset << '\'';
      throw ParseError(error_message.str(), std::vector{node.begin()});
      throw ParseError(error_message.str(), std::vector{node.begin()});
+4 −25
Original line number Original line Diff line number Diff line
@@ -22,19 +22,11 @@ ASTSymbolTableBuilder::buildSymbolTable(ASTNode& n, std::shared_ptr<SymbolTable>


    n.m_symbol_table          = local_symbol_table;
    n.m_symbol_table          = local_symbol_table;
    const std::string& symbol = n.children[0]->string();
    const std::string& symbol = n.children[0]->string();

    if (symbol_table->getBuiltinFunctionSymbolList(symbol, n.children[0]->begin()).size() > 0) {
      std::ostringstream error_message;
      error_message << "symbol '" << rang::fg::red << symbol << rang::fg::reset
                    << "' already denotes a builtin function!";
      throw ParseError(error_message.str(), std::vector{n.children[0]->begin()});
    }

    auto [i_symbol, success]  = symbol_table->add(symbol, n.children[0]->begin());
    auto [i_symbol, success]  = symbol_table->add(symbol, n.children[0]->begin());
    if (not success) {
    if (not success) {
      std::ostringstream error_message;
      std::ostringstream error_message;
      error_message << "symbol '" << rang::fg::red << symbol << rang::fg::reset << "' was already defined!";
      error_message << "symbol '" << rang::fg::red << symbol << rang::fg::reset << "' was already defined!";
      throw ParseError(error_message.str(), std::vector{n.children[0]->begin()});
      throw ParseError(error_message.str(), std::vector{n.begin()});
    }
    }


    for (auto& child : n.children) {
    for (auto& child : n.children) {
@@ -50,13 +42,6 @@ ASTSymbolTableBuilder::buildSymbolTable(ASTNode& n, std::shared_ptr<SymbolTable>
    if (n.has_content()) {
    if (n.has_content()) {
      if (n.is_type<language::var_declaration>()) {
      if (n.is_type<language::var_declaration>()) {
        auto register_symbol = [&](const ASTNode& argument_node) {
        auto register_symbol = [&](const ASTNode& argument_node) {
          if (symbol_table->getBuiltinFunctionSymbolList(argument_node.string(), argument_node.begin()).size() > 0) {
            std::ostringstream error_message;
            error_message << "symbol '" << rang::fg::red << argument_node.string() << rang::fg::reset
                          << "' already denotes a builtin function!";
            throw ParseError(error_message.str(), std::vector{argument_node.begin()});
          }

          auto [i_symbol, success] = symbol_table->add(argument_node.string(), argument_node.begin());
          auto [i_symbol, success] = symbol_table->add(argument_node.string(), argument_node.begin());
          if (not success) {
          if (not success) {
            std::ostringstream error_message;
            std::ostringstream error_message;
@@ -76,13 +61,6 @@ ASTSymbolTableBuilder::buildSymbolTable(ASTNode& n, std::shared_ptr<SymbolTable>
        }
        }
      } else if (n.is_type<language::function_definition>()) {
      } else if (n.is_type<language::function_definition>()) {
        auto register_and_initialize_symbol = [&](const ASTNode& argument_node) {
        auto register_and_initialize_symbol = [&](const ASTNode& argument_node) {
          if (symbol_table->getBuiltinFunctionSymbolList(argument_node.string(), argument_node.begin()).size() > 0) {
            std::ostringstream error_message;
            error_message << "symbol '" << rang::fg::red << argument_node.string() << rang::fg::reset
                          << "' already denotes a builtin function!";
            throw ParseError(error_message.str(), std::vector{argument_node.begin()});
          }

          auto [i_symbol, success] = symbol_table->add(argument_node.string(), argument_node.begin());
          auto [i_symbol, success] = symbol_table->add(argument_node.string(), argument_node.begin());
          if (not success) {
          if (not success) {
            std::ostringstream error_message;
            std::ostringstream error_message;
@@ -103,7 +81,8 @@ ASTSymbolTableBuilder::buildSymbolTable(ASTNode& n, std::shared_ptr<SymbolTable>
          }
          }
        }
        }
      } else if (n.is_type<language::name>()) {
      } else if (n.is_type<language::name>()) {
        if (not symbol_table->has(n.string(), n.begin())) {
        auto [i_symbol, found] = symbol_table->find(n.string(), n.begin());
        if (not found) {
          std::ostringstream error_message;
          std::ostringstream error_message;
          error_message << "undefined symbol '" << rang::fg::red << n.string() << rang::fg::reset << '\'';
          error_message << "undefined symbol '" << rang::fg::red << n.string() << rang::fg::reset << '\'';
          throw ParseError(error_message.str(), std::vector{n.begin()});
          throw ParseError(error_message.str(), std::vector{n.begin()});
+1 −47
Original line number Original line Diff line number Diff line
@@ -12,54 +12,8 @@ void
BuiltinModule::_addBuiltinFunction(const std::string& name,
BuiltinModule::_addBuiltinFunction(const std::string& name,
                                   std::shared_ptr<IBuiltinFunctionEmbedder> builtin_function_embedder)
                                   std::shared_ptr<IBuiltinFunctionEmbedder> builtin_function_embedder)
{
{
  auto is_keyword = [](const std::string& s) -> bool {
    if (s.size() == 0) {
      return false;
    } else {
      if (not(std::isalpha(s[0]) or s[0] == '_')) {
        return false;
      }
      for (size_t i = 1; i < s.size(); ++i) {
        if (not(std::isalnum(s[0]) or s[0] == '_')) {
          return false;
        }
      }
    }

    return true;
  };

  if (not is_keyword(name)) {
    std::ostringstream os;
    os << "while defining module " << this->name() << " invalid builtin function name: '" << name << "'\n";
    throw UnexpectedError(os.str());
  }

  auto parameter_data_type_list = builtin_function_embedder->getParameterDataTypes();

  std::string mangled_name = [&] {
    std::ostringstream os;
    os << name << '(';
    switch (parameter_data_type_list.size()) {
    case 0: {
      break;
    }
    case 1: {
      os << dataTypeName(parameter_data_type_list[0]);
      break;
    }
    default:
      os << dataTypeName(parameter_data_type_list[0]);
      for (size_t i = 1; i < parameter_data_type_list.size(); ++i) {
        os << ',' << dataTypeName(parameter_data_type_list[i]);
      }
    }
    os << ')';
    return os.str();
  }();

  auto [i_builtin_function, success] =
  auto [i_builtin_function, success] =
    m_name_builtin_function_map.insert(std::make_pair(mangled_name, builtin_function_embedder));
    m_name_builtin_function_map.insert(std::make_pair(name, builtin_function_embedder));
  if (not success) {
  if (not success) {
    throw NormalError("builtin-function '" + name + "' cannot be added!\n");
    throw NormalError("builtin-function '" + name + "' cannot be added!\n");
  }
  }
+2 −34
Original line number Original line Diff line number Diff line
@@ -20,29 +20,6 @@
void
void
ModuleRepository::_subscribe(std::unique_ptr<IModule> m)
ModuleRepository::_subscribe(std::unique_ptr<IModule> m)
{
{
  auto is_keyword = [](const std::string& s) -> bool {
    if (s.size() == 0) {
      return false;
    } else {
      if (not(std::isalpha(s[0]) or s[0] == '_')) {
        return false;
      }
      for (size_t i = 1; i < s.size(); ++i) {
        if (not(std::isalnum(s[0]) or s[0] == '_')) {
          return false;
        }
      }
    }

    return true;
  };

  if (not is_keyword(std::string{m->name()})) {
    std::ostringstream os;
    os << "cannot subscribe module with invalid name: '" << m->name() << "'\n";
    throw UnexpectedError(os.str());
  }

  auto [i_module, success] = m_module_set.emplace(m->name(), std::move(m));
  auto [i_module, success] = m_module_set.emplace(m->name(), std::move(m));
  Assert(success, "module has already been subscribed");
  Assert(success, "module has already been subscribed");
}
}
@@ -170,15 +147,6 @@ ModuleRepository::registerOperators(const std::string& module_name)
std::string
std::string
ModuleRepository::getModuleInfo(const std::string& module_name) const
ModuleRepository::getModuleInfo(const std::string& module_name) const
{
{
  auto demangleBuiltinFunction = [](const std::string& mangled_name) -> std::string {
    size_t i = 0;
    for (; i < mangled_name.size(); ++i) {
      if (mangled_name[i] == '(')
        break;
    }
    return mangled_name.substr(0, i);
  };

  std::stringstream os;
  std::stringstream os;
  auto i_module = m_module_set.find(module_name);
  auto i_module = m_module_set.find(module_name);
  if (i_module != m_module_set.end()) {
  if (i_module != m_module_set.end()) {
@@ -187,8 +155,8 @@ ModuleRepository::getModuleInfo(const std::string& module_name) const
    const auto& builtin_function_map = i_module->second->getNameBuiltinFunctionMap();
    const auto& builtin_function_map = i_module->second->getNameBuiltinFunctionMap();
    if (builtin_function_map.size() > 0) {
    if (builtin_function_map.size() > 0) {
      os << "  functions\n";
      os << "  functions\n";
      for (auto& [mangled_name, function] : builtin_function_map) {
      for (auto& [name, function] : builtin_function_map) {
        os << "    " << rang::fgB::green << demangleBuiltinFunction(mangled_name) << rang::style::reset << ": ";
        os << "    " << rang::fgB::green << name << rang::style::reset << ": ";
        os << dataTypeName(function->getParameterDataTypes());
        os << dataTypeName(function->getParameterDataTypes());
        os << rang::fgB::yellow << " -> " << rang::style::reset;
        os << rang::fgB::yellow << " -> " << rang::style::reset;
        os << dataTypeName(function->getReturnDataType()) << '\n';
        os << dataTypeName(function->getReturnDataType()) << '\n';
+0 −209
Original line number Original line Diff line number Diff line
#include <language/utils/BuiltinFunctionEmbedderUtils.hpp>

#include <language/PEGGrammar.hpp>
#include <language/ast/ASTNode.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/ParseError.hpp>
#include <language/utils/SymbolTable.hpp>

void
flattenDataTypes(const ASTNodeDataType& data_type, std::vector<ASTNodeDataType>& arg_type_list)
{
  if ((data_type == ASTNodeDataType::list_t) and (*data_type.contentTypeList()[0] == ASTNodeDataType::typename_t)) {
    for (auto data_type : data_type.contentTypeList()) {
      arg_type_list.push_back(data_type->contentType());
    }
  } else {
    arg_type_list.push_back(data_type);
  }
}

std::shared_ptr<IBuiltinFunctionEmbedder>
getBuiltinFunctionEmbedder(ASTNode& n)
{
  const std::string builtin_function_name = n.children[0]->string();
  auto& symbol_table                      = *n.m_symbol_table;

  auto& args_node = *n.children[1];

  std::vector<ASTNodeDataType> arg_type_list;
  if (args_node.is_type<language::function_argument_list>()) {
    for (auto& arg : args_node.children) {
      flattenDataTypes(arg->m_data_type, arg_type_list);
    }
  } else {
    flattenDataTypes(args_node.m_data_type, arg_type_list);
  }

  std::ostringstream mangled_name;
  mangled_name << builtin_function_name;
  if (size(arg_type_list) == 0) {
    mangled_name << "()";
  } else if (size(arg_type_list) == 1) {
    mangled_name << '(' << dataTypeName(arg_type_list[0]) << ')';
  } else {
    mangled_name << dataTypeName(arg_type_list);
  }

  std::vector builtin_function_candidate_list =
    symbol_table.getBuiltinFunctionSymbolList(builtin_function_name, n.begin());

  auto is_castable_to_vector = [](const ASTNodeDataType& arg_type, const ASTNodeDataType& target_type) {
    bool is_castable = true;
    if (target_type.dimension() > 1) {
      switch (arg_type) {
      case ASTNodeDataType::int_t: {
        break;
      }
      case ASTNodeDataType::list_t: {
        if (arg_type.contentTypeList().size() != target_type.dimension()) {
          is_castable = false;
          break;
        }
        for (auto list_arg : arg_type.contentTypeList()) {
          is_castable &= isNaturalConversion(*list_arg, ASTNodeDataType::build<ASTNodeDataType::double_t>());
        }
        break;
      }
      default: {
        is_castable &= false;
      }
      }
    } else {
      is_castable &= isNaturalConversion(arg_type, ASTNodeDataType::build<ASTNodeDataType::double_t>());
    }
    return is_castable;
  };

  auto is_castable_to_matrix = [](const ASTNodeDataType& arg_type, const ASTNodeDataType& target_type) {
    bool is_castable = true;
    if (target_type.nbRows() > 1) {
      switch (arg_type) {
      case ASTNodeDataType::int_t: {
        break;
      }
      case ASTNodeDataType::list_t: {
        if (arg_type.contentTypeList().size() != target_type.nbRows() * target_type.nbColumns()) {
          is_castable = false;
          break;
        }
        for (auto list_arg : arg_type.contentTypeList()) {
          is_castable &= isNaturalConversion(*list_arg, ASTNodeDataType::build<ASTNodeDataType::double_t>());
        }
        break;
      }
      default: {
        is_castable &= false;
      }
      }
    } else {
      is_castable &= isNaturalConversion(arg_type, ASTNodeDataType::build<ASTNodeDataType::double_t>());
    }
    return is_castable;
  };

  std::vector<uint64_t> callable_id_list;
  for (auto candidate : builtin_function_candidate_list) {
    uint64_t builtin_function_id = std::get<uint64_t>(candidate.attributes().value());

    auto& builtin_function_embedder_table     = n.m_symbol_table->builtinFunctionEmbedderTable();
    std::shared_ptr builtin_function_embedder = builtin_function_embedder_table[builtin_function_id];

    if (builtin_function_embedder->numberOfParameters() == arg_type_list.size()) {
      std::vector<ASTNodeDataType> builtin_function_parameter_type_list =
        builtin_function_embedder->getParameterDataTypes();
      bool is_castable = true;
      for (size_t i_arg = 0; i_arg < arg_type_list.size(); ++i_arg) {
        const ASTNodeDataType& target_type = builtin_function_parameter_type_list[i_arg];
        std::vector<ASTNodeDataType> sub_arg_type_list;
        if (arg_type_list[i_arg] == ASTNodeDataType::list_t) {
          for (auto& arg_type : arg_type_list[i_arg].contentTypeList()) {
            sub_arg_type_list.push_back(*arg_type);
          }
        } else {
          sub_arg_type_list.push_back(arg_type_list[i_arg]);
        }
        for (auto arg_type : sub_arg_type_list) {
          if (not isNaturalConversion(arg_type, target_type)) {
            switch (target_type) {
            case ASTNodeDataType::vector_t: {
              is_castable &= is_castable_to_vector(arg_type, target_type);
              break;
            }
            case ASTNodeDataType::matrix_t: {
              is_castable &= is_castable_to_matrix(arg_type, target_type);
              break;
            }
            case ASTNodeDataType::tuple_t: {
              ASTNodeDataType tuple_content_type = target_type.contentType();
              if (not isNaturalConversion(arg_type, tuple_content_type)) {
                switch (tuple_content_type) {
                case ASTNodeDataType::vector_t: {
                  is_castable &= is_castable_to_vector(arg_type, tuple_content_type);
                  break;
                }
                case ASTNodeDataType::matrix_t: {
                  is_castable &= is_castable_to_matrix(arg_type, tuple_content_type);
                  break;
                }
                default:
                  is_castable &= false;
                }
              }
              break;
            }
            default:
              is_castable &= false;
            }
          }
        }
      }
      if (is_castable) {
        callable_id_list.push_back(builtin_function_id);
      }
    }
  }

  uint64_t builtin_function_id = [&] {
    switch (size(callable_id_list)) {
    case 0: {
      std::ostringstream error_msg;
      error_msg << "no matching function to call " << rang::fgB::red << builtin_function_name << rang::style::reset
                << rang::style::bold << ": " << rang::fgB::yellow << dataTypeName(arg_type_list) << rang::style::reset
                << rang::style::bold << "\nnote: candidates are";

      for (auto candidate : builtin_function_candidate_list) {
        uint64_t builtin_function_id = std::get<uint64_t>(candidate.attributes().value());

        auto& builtin_function_embedder_table     = n.m_symbol_table->builtinFunctionEmbedderTable();
        std::shared_ptr builtin_function_embedder = builtin_function_embedder_table[builtin_function_id];

        error_msg << "\n " << builtin_function_name << ": "
                  << dataTypeName(builtin_function_embedder->getParameterDataTypes()) << " -> "
                  << dataTypeName(builtin_function_embedder->getReturnDataType());
      }

      throw ParseError(error_msg.str(), n.begin());
    }
    case 1: {
      return callable_id_list[0];
    }
    default: {
      std::ostringstream error_msg;
      error_msg << "ambiguous function to call " << mangled_name.str() << "\nnote: candidates are";

      auto& builtin_function_embedder_table = n.m_symbol_table->builtinFunctionEmbedderTable();
      for (auto callable_id : callable_id_list) {
        std::shared_ptr builtin_function_embedder = builtin_function_embedder_table[callable_id];

        error_msg << "\n " << builtin_function_name << ": "
                  << dataTypeName(builtin_function_embedder->getParameterDataTypes()) << " -> "
                  << dataTypeName(builtin_function_embedder->getReturnDataType());
      }
      throw ParseError(error_msg.str(), n.begin());
    }
    }
  }();

  return n.m_symbol_table->builtinFunctionEmbedderTable()[builtin_function_id];
}
+0 −11
Original line number Original line Diff line number Diff line
#ifndef BUILTIN_FUNCTION_EMBEDDER_UTILS_HPP
#define BUILTIN_FUNCTION_EMBEDDER_UTILS_HPP

class IBuiltinFunctionEmbedder;
class ASTNode;

#include <memory>

std::shared_ptr<IBuiltinFunctionEmbedder> getBuiltinFunctionEmbedder(ASTNode& n);

#endif   // BUILTIN_FUNCTION_EMBEDDER_UTILS_HPP
+0 −1
Original line number Original line Diff line number Diff line
@@ -20,7 +20,6 @@ add_library(PugsLanguageUtils
  BinaryOperatorRegisterForRnxn.cpp
  BinaryOperatorRegisterForRnxn.cpp
  BinaryOperatorRegisterForString.cpp
  BinaryOperatorRegisterForString.cpp
  BinaryOperatorRegisterForZ.cpp
  BinaryOperatorRegisterForZ.cpp
  BuiltinFunctionEmbedderUtils.cpp
  DataVariant.cpp
  DataVariant.cpp
  EmbeddedData.cpp
  EmbeddedData.cpp
  EmbeddedIDiscreteFunctionOperators.cpp
  EmbeddedIDiscreteFunctionOperators.cpp
+2 −58
Original line number Original line Diff line number Diff line
@@ -279,55 +279,6 @@ class SymbolTable
    }
    }
  }
  }


  std::vector<Symbol>
  getBuiltinFunctionSymbolList(const std::string& symbol, const TAO_PEGTL_NAMESPACE::position& use_position)
  {
    std::vector<Symbol> builtin_function_symbol_list;

    for (auto i_stored_symbol : m_symbol_list) {
      if (use_position.byte < i_stored_symbol.attributes().position().byte)
        continue;

      // Symbol must be defined before the call
      std::string_view stored_symbol_name = i_stored_symbol.name();
      if ((stored_symbol_name.size() > symbol.size()) and (stored_symbol_name[symbol.size()] == '(')) {
        if (stored_symbol_name.substr(0, symbol.size()) == symbol) {
          builtin_function_symbol_list.push_back(i_stored_symbol);
        }
      }
    }

    if (m_parent_table) {
      return m_parent_table->getBuiltinFunctionSymbolList(symbol, use_position);
    } else {
      return builtin_function_symbol_list;
    }
  }

  bool
  has(const std::string& symbol, const TAO_PEGTL_NAMESPACE::position& use_position)
  {
    for (auto i_stored_symbol : m_symbol_list) {
      if (use_position.byte < i_stored_symbol.attributes().position().byte)
        continue;

      // Symbol must be defined before the call
      std::string_view stored_symbol_name = i_stored_symbol.name();
      if ((stored_symbol_name.size() == symbol.size()) or
          (stored_symbol_name.size() > symbol.size() and (stored_symbol_name[symbol.size()] == '('))) {
        if (stored_symbol_name.substr(0, symbol.size()) == symbol) {
          return true;
        }
      }
    }

    if (m_parent_table) {
      return m_parent_table->has(symbol, use_position);
    } else {
      return false;
    }
  }

  auto
  auto
  find(const std::string& symbol, const TAO_PEGTL_NAMESPACE::position& use_position)
  find(const std::string& symbol, const TAO_PEGTL_NAMESPACE::position& use_position)
  {
  {
@@ -355,15 +306,8 @@ class SymbolTable
  add(const std::string& symbol_name, const TAO_PEGTL_NAMESPACE::position& symbol_position)
  add(const std::string& symbol_name, const TAO_PEGTL_NAMESPACE::position& symbol_position)
  {
  {
    for (auto i_stored_symbol = m_symbol_list.begin(); i_stored_symbol != m_symbol_list.end(); ++i_stored_symbol) {
    for (auto i_stored_symbol = m_symbol_list.begin(); i_stored_symbol != m_symbol_list.end(); ++i_stored_symbol) {
      std::string_view stored_symbol_name = i_stored_symbol->name();
      if (i_stored_symbol->name() == symbol_name) {
      if (stored_symbol_name.size() == symbol_name.size()) {
        if (stored_symbol_name == symbol_name) {
        return std::make_pair(i_stored_symbol, false);
        return std::make_pair(i_stored_symbol, false);
        } else if (stored_symbol_name.size() > symbol_name.size() and (stored_symbol_name[symbol_name.size()] == '(')) {
          if (stored_symbol_name.substr(0, symbol_name.size()) == symbol_name) {
            return std::make_pair(i_stored_symbol, false);
          }
        }
      }
      }
    }
    }


+2 −2
Original line number Original line Diff line number Diff line
@@ -102,7 +102,7 @@ import unknown_module;
      REQUIRE_THROWS_AS(ASTModulesImporter{*ast}, ParseError);
      REQUIRE_THROWS_AS(ASTModulesImporter{*ast}, ParseError);
    }
    }


    SECTION("symbol already defined (same builtin function)")
    SECTION("symbol already defined")
    {
    {
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
@@ -111,7 +111,7 @@ import math;
      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
      TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};
      auto ast = ASTBuilder::build(input);
      auto ast = ASTBuilder::build(input);


      ast->m_symbol_table->add("sin(R)", ast->begin());
      ast->m_symbol_table->add("sin", ast->begin());


      REQUIRE_THROWS_AS(ASTModulesImporter{*ast}, ParseError);
      REQUIRE_THROWS_AS(ASTModulesImporter{*ast}, ParseError);
    }
    }
+92 −121

File changed.

Preview size limit exceeded, changes collapsed.

+1 −1
Original line number Original line Diff line number Diff line
@@ -51,7 +51,7 @@ sin(3);
    std::string_view result = R"(
    std::string_view result = R"(
(root:ASTNodeListProcessor)
(root:ASTNodeListProcessor)
 `-(language::function_evaluation:BuiltinFunctionProcessor)
 `-(language::function_evaluation:BuiltinFunctionProcessor)
     +-(language::name:sin:FakeProcessor)
     +-(language::name:sin:NameProcessor)
     `-(language::integer:3:ValueProcessor)
     `-(language::integer:3:ValueProcessor)
)";
)";


+22 −22
Original line number Original line Diff line number Diff line
@@ -81,7 +81,7 @@ TEST_CASE("BuiltinFunctionProcessor", "[language]")


    std::set<std::string> tested_function_set;
    std::set<std::string> tested_function_set;
    {   // sqrt
    {   // sqrt
      tested_function_set.insert("sqrt(R)");
      tested_function_set.insert("sqrt");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = sqrt(4);
let x:R, x = sqrt(4);
@@ -90,7 +90,7 @@ let x:R, x = sqrt(4);
    }
    }


    {   // abs
    {   // abs
      tested_function_set.insert("abs(R)");
      tested_function_set.insert("abs");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = abs(-3.4);
let x:R, x = abs(-3.4);
@@ -99,7 +99,7 @@ let x:R, x = abs(-3.4);
    }
    }


    {   // sin
    {   // sin
      tested_function_set.insert("sin(R)");
      tested_function_set.insert("sin");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = sin(1.3);
let x:R, x = sin(1.3);
@@ -108,7 +108,7 @@ let x:R, x = sin(1.3);
    }
    }


    {   // cos
    {   // cos
      tested_function_set.insert("cos(R)");
      tested_function_set.insert("cos");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = cos(1.3);
let x:R, x = cos(1.3);
@@ -117,7 +117,7 @@ let x:R, x = cos(1.3);
    }
    }


    {   // tan
    {   // tan
      tested_function_set.insert("tan(R)");
      tested_function_set.insert("tan");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = tan(1.3);
let x:R, x = tan(1.3);
@@ -126,7 +126,7 @@ let x:R, x = tan(1.3);
    }
    }


    {   // asin
    {   // asin
      tested_function_set.insert("asin(R)");
      tested_function_set.insert("asin");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = asin(0.7);
let x:R, x = asin(0.7);
@@ -135,7 +135,7 @@ let x:R, x = asin(0.7);
    }
    }


    {   // acos
    {   // acos
      tested_function_set.insert("acos(R)");
      tested_function_set.insert("acos");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = acos(0.7);
let x:R, x = acos(0.7);
@@ -144,7 +144,7 @@ let x:R, x = acos(0.7);
    }
    }


    {   // atan
    {   // atan
      tested_function_set.insert("atan(R)");
      tested_function_set.insert("atan");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = atan(0.7);
let x:R, x = atan(0.7);
@@ -153,7 +153,7 @@ let x:R, x = atan(0.7);
    }
    }


    {   // atan2
    {   // atan2
      tested_function_set.insert("atan2(R,R)");
      tested_function_set.insert("atan2");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = atan2(0.7, 0.4);
let x:R, x = atan2(0.7, 0.4);
@@ -162,7 +162,7 @@ let x:R, x = atan2(0.7, 0.4);
    }
    }


    {   // sinh
    {   // sinh
      tested_function_set.insert("sinh(R)");
      tested_function_set.insert("sinh");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = sinh(0.6);
let x:R, x = sinh(0.6);
@@ -171,7 +171,7 @@ let x:R, x = sinh(0.6);
    }
    }


    {   // cosh
    {   // cosh
      tested_function_set.insert("cosh(R)");
      tested_function_set.insert("cosh");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = cosh(1.7);
let x:R, x = cosh(1.7);
@@ -180,7 +180,7 @@ let x:R, x = cosh(1.7);
    }
    }


    {   // tanh
    {   // tanh
      tested_function_set.insert("tanh(R)");
      tested_function_set.insert("tanh");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = tanh(0.6);
let x:R, x = tanh(0.6);
@@ -189,7 +189,7 @@ let x:R, x = tanh(0.6);
    }
    }


    {   // asinh
    {   // asinh
      tested_function_set.insert("asinh(R)");
      tested_function_set.insert("asinh");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = asinh(0.6);
let x:R, x = asinh(0.6);
@@ -198,7 +198,7 @@ let x:R, x = asinh(0.6);
    }
    }


    {   // acosh
    {   // acosh
      tested_function_set.insert("acosh(R)");
      tested_function_set.insert("acosh");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = acosh(1.7);
let x:R, x = acosh(1.7);
@@ -207,7 +207,7 @@ let x:R, x = acosh(1.7);
    }
    }


    {   // tanh
    {   // tanh
      tested_function_set.insert("atanh(R)");
      tested_function_set.insert("atanh");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = atanh(0.6);
let x:R, x = atanh(0.6);
@@ -216,7 +216,7 @@ let x:R, x = atanh(0.6);
    }
    }


    {   // exp
    {   // exp
      tested_function_set.insert("exp(R)");
      tested_function_set.insert("exp");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = exp(1.7);
let x:R, x = exp(1.7);
@@ -225,7 +225,7 @@ let x:R, x = exp(1.7);
    }
    }


    {   // log
    {   // log
      tested_function_set.insert("log(R)");
      tested_function_set.insert("log");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = log(1.6);
let x:R, x = log(1.6);
@@ -234,7 +234,7 @@ let x:R, x = log(1.6);
    }
    }


    {   // pow
    {   // pow
      tested_function_set.insert("pow(R,R)");
      tested_function_set.insert("pow");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let x:R, x = pow(1.6, 2.3);
let x:R, x = pow(1.6, 2.3);
@@ -243,7 +243,7 @@ let x:R, x = pow(1.6, 2.3);
    }
    }


    {   // ceil
    {   // ceil
      tested_function_set.insert("ceil(R)");
      tested_function_set.insert("ceil");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let z:Z, z = ceil(-1.2);
let z:Z, z = ceil(-1.2);
@@ -252,7 +252,7 @@ let z:Z, z = ceil(-1.2);
    }
    }


    {   // floor
    {   // floor
      tested_function_set.insert("floor(R)");
      tested_function_set.insert("floor");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let z:Z, z = floor(-1.2);
let z:Z, z = floor(-1.2);
@@ -261,7 +261,7 @@ let z:Z, z = floor(-1.2);
    }
    }


    {   // trunc
    {   // trunc
      tested_function_set.insert("trunc(R)");
      tested_function_set.insert("trunc");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let z:Z, z = trunc(-0.2) + trunc(0.7);
let z:Z, z = trunc(-0.2) + trunc(0.7);
@@ -270,7 +270,7 @@ let z:Z, z = trunc(-0.2) + trunc(0.7);
    }
    }


    {   // round
    {   // round
      tested_function_set.insert("round(R)");
      tested_function_set.insert("round");
      std::string_view data = R"(
      std::string_view data = R"(
import math;
import math;
let z:Z, z = round(-1.2);
let z:Z, z = round(-1.2);