Select Git revision
test_PolynomialP.cpp
-
Emmanuel Labourasse authored
Resume.cpp 12.67 KiB
#include <utils/checkpointing/Resume.hpp>
#include <utils/pugs_config.hpp>
#ifdef PUGS_HAS_HDF5
#include <algebra/LinearSolverOptions.hpp>
#include <language/ast/ASTExecutionStack.hpp>
#include <language/utils/ASTCheckpointsInfo.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <language/utils/SymbolTable.hpp>
#include <mesh/Connectivity.hpp>
#include <utils/Exceptions.hpp>
#include <utils/ExecutionStatManager.hpp>
#include <utils/HighFivePugsUtils.hpp>
#include <utils/RandomEngine.hpp>
#include <utils/checkpointing/LinearSolverOptionsHFType.hpp>
#include <utils/checkpointing/ParallelCheckerHFType.hpp>
#include <utils/checkpointing/ResumingData.hpp>
#include <utils/checkpointing/ResumingManager.hpp>
#include <iostream>
void
resume()
{
try {
HighFive::SilenceHDF5 m_silence_hdf5{true};
checkpointing::ResumingData::create();
HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly);
HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");
HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table");
const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode();
auto p_symbol_table = p_node->m_symbol_table;
ResumingManager& resuming_manager = ResumingManager::getInstance();
resuming_manager.checkpointNumber() = checkpoint.getAttribute("checkpoint_number").read<uint64_t>() + 1;
std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line "
<< rang::fgB::yellow << p_node->begin().line << rang::fg::reset << " [using " << rang::fgB::cyan
<< checkpoint.getAttribute("name").read<std::string>() << rang::fg::reset << "]\n";
{
HighFive::Group random_seed_group = checkpoint.getGroup("singleton/random_seed");
RandomEngine::instance().setRandomSeed(random_seed_group.getAttribute("current_seed").read<uint64_t>());
}
{
HighFive::Group global_variables_group = checkpoint.getGroup("singleton/execution_info");
const size_t run_number = global_variables_group.getAttribute("run_number").read<size_t>();
const double cumulative_elapse_time =
global_variables_group.getAttribute("cumulative_elapse_time").read<double>();
const double cumulative_total_cpu_time =
global_variables_group.getAttribute("cumulative_total_cpu_time").read<double>();
ExecutionStatManager::getInstance().setRunNumber(run_number + 1);
ExecutionStatManager::getInstance().setPreviousCumulativeElapseTime(cumulative_elapse_time);
ExecutionStatManager::getInstance().setPreviousCumulativeTotalCPUTime(cumulative_total_cpu_time);
}
{
HighFive::Group random_seed_group = checkpoint.getGroup("singleton/parallel_checker");
// Ordering is important! Must set mode before changing the tag (changing mode is not allowed if tag!=0)
ParallelChecker::instance().setMode(random_seed_group.getAttribute("mode").read<ParallelChecker::Mode>());
ParallelChecker::instance().setTag(random_seed_group.getAttribute("tag").read<size_t>());
}
{
HighFive::Group linear_solver_options_default_group =
checkpoint.getGroup("singleton/linear_solver_options_default");
LinearSolverOptions& default_options = LinearSolverOptions::default_options;
default_options.epsilon() = linear_solver_options_default_group.getAttribute("epsilon").read<double>();
default_options.maximumIteration() =
linear_solver_options_default_group.getAttribute("maximum_iteration").read<size_t>();
default_options.verbose() = linear_solver_options_default_group.getAttribute("verbose").read<bool>();
default_options.library() = linear_solver_options_default_group.getAttribute("library").read<LSLibrary>();
default_options.method() = linear_solver_options_default_group.getAttribute("method").read<LSMethod>();
default_options.precond() = linear_solver_options_default_group.getAttribute("precond").read<LSPrecond>();
}
checkpointing::ResumingData::instance().readData(checkpoint, p_symbol_table);
bool finished = true;
do {
finished = true;
for (auto symbol_name : saved_symbol_table.listAttributeNames()) {
auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
auto& attribute = p_symbol->attributes();
switch (attribute.dataType()) {
case ASTNodeDataType::bool_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>();
break;
}
case ASTNodeDataType::unsigned_int_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>();
break;
}
case ASTNodeDataType::int_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>();
break;
}
case ASTNodeDataType::double_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double_t>();
break;
}
case ASTNodeDataType::string_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>();
break;
}
case ASTNodeDataType::vector_t: {
switch (attribute.dataType().dimension()) {
case 1: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>();
break;
}
case 2: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>();
break;
}
case 3: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>();
break;
}
// LCOV_EXCL_START
default: {
throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
}
// LCOV_EXCL_STOP
}
break;
}
case ASTNodeDataType::matrix_t: {
// LCOV_EXCL_START
if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) {
throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
}
// LCOV_EXCL_STOP
switch (attribute.dataType().numberOfRows()) {
case 1: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>();
break;
}
case 2: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>();
break;
}
case 3: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>();
break;
}
// LCOV_EXCL_START
default: {
throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
}
// LCOV_EXCL_STOP
}
break;
}
case ASTNodeDataType::tuple_t: {
switch (attribute.dataType().contentType()) {
case ASTNodeDataType::bool_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>();
break;
}
case ASTNodeDataType::unsigned_int_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>();
break;
}
case ASTNodeDataType::int_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>();
break;
}
case ASTNodeDataType::double_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double_t>>();
break;
}
case ASTNodeDataType::string_t: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>();
break;
}
case ASTNodeDataType::vector_t: {
switch (attribute.dataType().contentType().dimension()) {
case 1: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>();
break;
}
case 2: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>();
break;
}
case 3: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>();
break;
}
// LCOV_EXCL_START
default: {
throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
}
// LCOV_EXCL_STOP
}
break;
}
case ASTNodeDataType::matrix_t: {
// LCOV_EXCL_START
if (attribute.dataType().contentType().numberOfRows() !=
attribute.dataType().contentType().numberOfColumns()) {
throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
}
// LCOV_EXCL_STOP
switch (attribute.dataType().contentType().numberOfRows()) {
case 1: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>();
break;
}
case 2: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>();
break;
}
case 3: {
attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>();
break;
}
// LCOV_EXCL_START
default: {
throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
}
// LCOV_EXCL_STOP
}
break;
}
// LCOV_EXCL_START
default: {
throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType()));
}
// LCOV_EXCL_STOP
}
break;
}
// LCOV_EXCL_START
default: {
throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType()));
}
// LCOV_EXCL_STOP
}
}
if (saved_symbol_table.exist("embedded")) {
HighFive::Group embedded = saved_symbol_table.getGroup("embedded");
for (auto symbol_name : embedded.listObjectNames()) {
auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
if (p_symbol->attributes().dataType() == ASTNodeDataType::tuple_t) {
HighFive::Group embedded_tuple_group = embedded.getGroup(symbol_name);
const size_t number_of_components = embedded_tuple_group.getNumberObjects();
std::vector<EmbeddedData> embedded_tuple(number_of_components);
for (size_t i_component = 0; i_component < number_of_components; ++i_component) {
embedded_tuple[i_component] =
CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType().contentType(),
p_symbol->name() + "/" + std::to_string(i_component),
saved_symbol_table);
}
p_symbol->attributes().value() = embedded_tuple;
} else {
p_symbol->attributes().value() =
CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType(), p_symbol->name(),
saved_symbol_table);
}
}
}
const bool symbol_table_has_parent = p_symbol_table->hasParentTable();
const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table");
Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent));
if (symbol_table_has_parent and saved_symbol_table_has_parent) {
p_symbol_table = p_symbol_table->parentTable();
saved_symbol_table = saved_symbol_table.getGroup("symbol table");
finished = false;
}
} while (not finished);
checkpointing::ResumingData::destroy();
}
// LCOV_EXCL_START
catch (HighFive::Exception& e) {
throw NormalError(e.what());
}
// LCOV_EXCL_STOP
}
#else // PUGS_HAS_HDF5
#include <utils/Exceptions.hpp>
// Fallback used when pugs is built without HDF5 support: resuming from a
// checkpoint is not available, so this always throws NormalError.
void
resume()
{
  static constexpr char error_message[] = "checkpoint/resume mechanism requires HDF5";
  throw NormalError(error_message);
}
#endif // PUGS_HAS_HDF5