Skip to content
Snippets Groups Projects
Select Git revision
  • f63bc3456a099c763fd830cee6c3cb6a434276b4
  • develop default protected
  • feature/gmsh-reader
  • origin/stage/bouguettaia
  • feature/kinetic-schemes
  • feature/reconstruction
  • feature/local-dt-fsi
  • feature/composite-scheme-sources
  • feature/composite-scheme-other-fluxes
  • feature/serraille
  • feature/variational-hydro
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • master protected
  • feature/escobar-smoother
  • v0.5.0 protected
  • v0.4.1 protected
  • v0.4.0 protected
  • v0.3.0 protected
  • v0.2.0 protected
  • v0.1.0 protected
  • Kidder
  • v0.0.4 protected
  • v0.0.3 protected
  • v0.0.2 protected
  • v0 protected
  • v0.0.1 protected
33 results

Resume.cpp

Blame
  • Resume.cpp 12.67 KiB
    #include <utils/checkpointing/Resume.hpp>
    
    #include <utils/pugs_config.hpp>
    
    #ifdef PUGS_HAS_HDF5
    
    #include <algebra/LinearSolverOptions.hpp>
    #include <language/ast/ASTExecutionStack.hpp>
    #include <language/utils/ASTCheckpointsInfo.hpp>
    #include <language/utils/CheckpointResumeRepository.hpp>
    #include <language/utils/SymbolTable.hpp>
    #include <mesh/Connectivity.hpp>
    #include <utils/Exceptions.hpp>
    #include <utils/ExecutionStatManager.hpp>
    #include <utils/HighFivePugsUtils.hpp>
    #include <utils/RandomEngine.hpp>
    #include <utils/checkpointing/LinearSolverOptionsHFType.hpp>
    #include <utils/checkpointing/ParallelCheckerHFType.hpp>
    #include <utils/checkpointing/ResumingData.hpp>
    #include <utils/checkpointing/ResumingManager.hpp>
    
    #include <iostream>
    
    void
    resume()
    {
      try {
        HighFive::SilenceHDF5 m_silence_hdf5{true};
        checkpointing::ResumingData::create();
        HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly);
    
        HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");
    
        HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table");
    
        const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode();
        auto p_symbol_table   = p_node->m_symbol_table;
    
        ResumingManager& resuming_manager = ResumingManager::getInstance();
    
        resuming_manager.checkpointNumber() = checkpoint.getAttribute("checkpoint_number").read<uint64_t>() + 1;
    
        std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line "
                  << rang::fgB::yellow << p_node->begin().line << rang::fg::reset << " [using " << rang::fgB::cyan
                  << checkpoint.getAttribute("name").read<std::string>() << rang::fg::reset << "]\n";
    
        {
          HighFive::Group random_seed_group = checkpoint.getGroup("singleton/random_seed");
          RandomEngine::instance().setRandomSeed(random_seed_group.getAttribute("current_seed").read<uint64_t>());
        }
        {
          HighFive::Group global_variables_group = checkpoint.getGroup("singleton/execution_info");
          const size_t run_number                = global_variables_group.getAttribute("run_number").read<size_t>();
          const double cumulative_elapse_time =
            global_variables_group.getAttribute("cumulative_elapse_time").read<double>();
          const double cumulative_total_cpu_time =
            global_variables_group.getAttribute("cumulative_total_cpu_time").read<double>();
    
          ExecutionStatManager::getInstance().setRunNumber(run_number + 1);
          ExecutionStatManager::getInstance().setPreviousCumulativeElapseTime(cumulative_elapse_time);
          ExecutionStatManager::getInstance().setPreviousCumulativeTotalCPUTime(cumulative_total_cpu_time);
        }
        {
          HighFive::Group random_seed_group = checkpoint.getGroup("singleton/parallel_checker");
          // Ordering is important! Must set mode before changing the tag (changing mode is not allowed if tag!=0)
          ParallelChecker::instance().setMode(random_seed_group.getAttribute("mode").read<ParallelChecker::Mode>());
          ParallelChecker::instance().setTag(random_seed_group.getAttribute("tag").read<size_t>());
        }
        {
          HighFive::Group linear_solver_options_default_group =
            checkpoint.getGroup("singleton/linear_solver_options_default");
    
          LinearSolverOptions& default_options = LinearSolverOptions::default_options;
    
          default_options.epsilon() = linear_solver_options_default_group.getAttribute("epsilon").read<double>();
          default_options.maximumIteration() =
            linear_solver_options_default_group.getAttribute("maximum_iteration").read<size_t>();
          default_options.verbose() = linear_solver_options_default_group.getAttribute("verbose").read<bool>();
    
          default_options.library() = linear_solver_options_default_group.getAttribute("library").read<LSLibrary>();
          default_options.method()  = linear_solver_options_default_group.getAttribute("method").read<LSMethod>();
          default_options.precond() = linear_solver_options_default_group.getAttribute("precond").read<LSPrecond>();
        }
    
        checkpointing::ResumingData::instance().readData(checkpoint, p_symbol_table);
    
        bool finished = true;
        do {
          finished = true;
    
          for (auto symbol_name : saved_symbol_table.listAttributeNames()) {
            auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
            auto& attribute        = p_symbol->attributes();
            switch (attribute.dataType()) {
            case ASTNodeDataType::bool_t: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>();
              break;
            }
            case ASTNodeDataType::unsigned_int_t: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>();
              break;
            }
            case ASTNodeDataType::int_t: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>();
              break;
            }
            case ASTNodeDataType::double_t: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double_t>();
              break;
            }
            case ASTNodeDataType::string_t: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>();
              break;
            }
            case ASTNodeDataType::vector_t: {
              switch (attribute.dataType().dimension()) {
              case 1: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>();
                break;
              }
              case 2: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>();
                break;
              }
              case 3: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>();
                break;
              }
                // LCOV_EXCL_START
              default: {
                throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
              }
                // LCOV_EXCL_STOP
              }
              break;
            }
            case ASTNodeDataType::matrix_t: {
              // LCOV_EXCL_START
              if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) {
                throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
              }
              // LCOV_EXCL_STOP
              switch (attribute.dataType().numberOfRows()) {
              case 1: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>();
                break;
              }
              case 2: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>();
                break;
              }
              case 3: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>();
                break;
              }
                // LCOV_EXCL_START
              default: {
                throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
              }
                // LCOV_EXCL_STOP
              }
              break;
            }
            case ASTNodeDataType::tuple_t: {
              switch (attribute.dataType().contentType()) {
              case ASTNodeDataType::bool_t: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>();
                break;
              }
              case ASTNodeDataType::unsigned_int_t: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>();
                break;
              }
              case ASTNodeDataType::int_t: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>();
                break;
              }
              case ASTNodeDataType::double_t: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double_t>>();
                break;
              }
              case ASTNodeDataType::string_t: {
                attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>();
                break;
              }
              case ASTNodeDataType::vector_t: {
                switch (attribute.dataType().contentType().dimension()) {
                case 1: {
                  attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>();
                  break;
                }
                case 2: {
                  attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>();
                  break;
                }
                case 3: {
                  attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>();
                  break;
                }
                  // LCOV_EXCL_START
                default: {
                  throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
                }
                  // LCOV_EXCL_STOP
                }
                break;
              }
              case ASTNodeDataType::matrix_t: {
                // LCOV_EXCL_START
                if (attribute.dataType().contentType().numberOfRows() !=
                    attribute.dataType().contentType().numberOfColumns()) {
                  throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
                }
                // LCOV_EXCL_STOP
                switch (attribute.dataType().contentType().numberOfRows()) {
                case 1: {
                  attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>();
                  break;
                }
                case 2: {
                  attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>();
                  break;
                }
                case 3: {
                  attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>();
                  break;
                }
                  // LCOV_EXCL_START
                default: {
                  throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
                }
                  // LCOV_EXCL_STOP
                }
                break;
              }
                // LCOV_EXCL_START
              default: {
                throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType()));
              }
                // LCOV_EXCL_STOP
              }
              break;
            }
              // LCOV_EXCL_START
            default: {
              throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType()));
            }
              // LCOV_EXCL_STOP
            }
          }
    
          if (saved_symbol_table.exist("embedded")) {
            HighFive::Group embedded = saved_symbol_table.getGroup("embedded");
    
            for (auto symbol_name : embedded.listObjectNames()) {
              auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
              if (p_symbol->attributes().dataType() == ASTNodeDataType::tuple_t) {
                HighFive::Group embedded_tuple_group = embedded.getGroup(symbol_name);
                const size_t number_of_components    = embedded_tuple_group.getNumberObjects();
                std::vector<EmbeddedData> embedded_tuple(number_of_components);
    
                for (size_t i_component = 0; i_component < number_of_components; ++i_component) {
                  embedded_tuple[i_component] =
                    CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType().contentType(),
                                                                  p_symbol->name() + "/" + std::to_string(i_component),
                                                                  saved_symbol_table);
                }
                p_symbol->attributes().value() = embedded_tuple;
              } else {
                p_symbol->attributes().value() =
                  CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType(), p_symbol->name(),
                                                                saved_symbol_table);
              }
            }
          }
    
          const bool symbol_table_has_parent       = p_symbol_table->hasParentTable();
          const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table");
    
          Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent));
    
          if (symbol_table_has_parent and saved_symbol_table_has_parent) {
            p_symbol_table     = p_symbol_table->parentTable();
            saved_symbol_table = saved_symbol_table.getGroup("symbol table");
    
            finished = false;
          }
    
        } while (not finished);
    
        checkpointing::ResumingData::destroy();
      }
      // LCOV_EXCL_START
      catch (HighFive::Exception& e) {
        throw NormalError(e.what());
      }
      // LCOV_EXCL_STOP
    }
    
    #else   // PUGS_HAS_HDF5
    
    #include <utils/Exceptions.hpp>
    
    // Fallback used when pugs is built without HDF5 support: checkpoint files
    // cannot be read in this configuration, so any attempt to resume is
    // rejected with a NormalError.
    void
    resume()
    {
      const char* const missing_support_message = "checkpoint/resume mechanism requires HDF5";
      throw NormalError(missing_support_message);
    }
    
    #endif   // PUGS_HAS_HDF5