From d54056a4eb95f4c1ad06b93a0b4b101703fc8403 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Sun, 7 Jan 2024 11:36:47 +0100 Subject: [PATCH] Begin checkpointing mechanism - The checkpoint/resume system is functional for basic types - For now, only the last checkpoint can be used --- CMakeLists.txt | 3 + src/language/PugsParser.cpp | 16 +- src/language/ast/ASTExecutionStack.cpp | 16 +- ...STNodeBuiltinFunctionExpressionBuilder.cpp | 2 +- src/language/modules/CMakeLists.txt | 1 + src/language/modules/CoreModule.cpp | 36 ++- src/language/modules/DevUtilsModule.cpp | 4 +- src/language/modules/ModuleRepository.cpp | 1 + .../node_processor/ASTNodeListProcessor.hpp | 20 +- .../node_processor/DoWhileProcessor.hpp | 13 + src/language/node_processor/ForProcessor.hpp | 20 +- src/language/node_processor/IfProcessor.hpp | 57 +++-- .../node_processor/WhileProcessor.hpp | 16 +- src/language/utils/ASTCheckpoint.hpp | 36 +++ src/language/utils/ASTCheckpointsInfo.cpp | 53 ++++ src/language/utils/ASTCheckpointsInfo.hpp | 53 ++++ src/language/utils/ASTExecutionInfo.cpp | 14 +- src/language/utils/ASTExecutionInfo.hpp | 4 +- src/language/utils/CMakeLists.txt | 1 + src/language/utils/SymbolTable.hpp | 41 +++- src/main.cpp | 3 + src/utils/CMakeLists.txt | 2 + src/utils/PugsUtils.cpp | 7 + src/utils/checkpointing/CMakeLists.txt | 18 ++ src/utils/checkpointing/Checkpoint.cpp | 108 +++++++++ src/utils/checkpointing/Checkpoint.hpp | 6 + src/utils/checkpointing/Resume.cpp | 229 ++++++++++++++++++ src/utils/checkpointing/Resume.hpp | 6 + src/utils/checkpointing/ResumingManager.cpp | 44 ++++ src/utils/checkpointing/ResumingManager.hpp | 72 ++++++ src/utils/checkpointing/ResumingUtils.cpp | 27 +++ src/utils/checkpointing/ResumingUtils.hpp | 8 + tests/CMakeLists.txt | 4 +- tests/mpi_test_main.cpp | 5 + tests/test_ASTModulesImporter.cpp | 2 +- tests/test_main.cpp | 5 + 36 files changed, 895 insertions(+), 58 deletions(-) create mode 100644 
src/language/utils/ASTCheckpoint.hpp create mode 100644 src/language/utils/ASTCheckpointsInfo.cpp create mode 100644 src/language/utils/ASTCheckpointsInfo.hpp create mode 100644 src/utils/checkpointing/CMakeLists.txt create mode 100644 src/utils/checkpointing/Checkpoint.cpp create mode 100644 src/utils/checkpointing/Checkpoint.hpp create mode 100644 src/utils/checkpointing/Resume.cpp create mode 100644 src/utils/checkpointing/Resume.hpp create mode 100644 src/utils/checkpointing/ResumingManager.cpp create mode 100644 src/utils/checkpointing/ResumingManager.hpp create mode 100644 src/utils/checkpointing/ResumingUtils.cpp create mode 100644 src/utils/checkpointing/ResumingUtils.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 560dd74c0..3fc961a86 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -597,6 +597,7 @@ target_link_libraries( PugsMesh PugsAlgebra PugsAnalysis + PugsCheckpointing PugsDev PugsUtils PugsLanguage @@ -609,6 +610,7 @@ target_link_libraries( PugsUtils PugsOutput PugsLanguageUtils + PugsCheckpointing Kokkos::kokkos ${PETSC_LIBRARIES} ${SLEPC_LIBRARIES} @@ -635,6 +637,7 @@ install(TARGETS PugsMesh PugsAlgebra PugsAnalysis + PugsCheckpointing PugsDev PugsUtils PugsLanguage diff --git a/src/language/PugsParser.cpp b/src/language/PugsParser.cpp index 3e193aef2..b5b446f69 100644 --- a/src/language/PugsParser.cpp +++ b/src/language/PugsParser.cpp @@ -14,9 +14,8 @@ #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolInitializationChecker.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> -#include <language/utils/ASTDotPrinter.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/ASTExecutionInfo.hpp> -#include <language/utils/ASTPrinter.hpp> #include <language/utils/Exit.hpp> #include <language/utils/OperatorRepository.hpp> #include <language/utils/SymbolTable.hpp> @@ -26,6 +25,8 @@ #include <utils/PugsAssert.hpp> #include <utils/PugsUtils.hpp> #include <utils/SignalManager.hpp> +#include 
<utils/checkpointing/ResumingManager.hpp> +#include <utils/checkpointing/ResumingUtils.hpp> #include <pegtl/contrib/analyze.hpp> #include <pegtl/contrib/parse_tree.hpp> @@ -92,6 +93,8 @@ parser(const std::string& filename) ASTExecutionInfo execution_info{*root_node, module_importer.moduleRepository()}; + ASTCheckpointsInfo checkpoint_info{*root_node}; + ExecutionPolicy exec_all; try { root_node->execute(exec_all); @@ -99,6 +102,7 @@ parser(const std::string& filename) catch (language::Exit& e) { ExecutionStatManager::getInstance().setExitCode(e.code()); } + root_node->m_symbol_table->clearValues(); OperatorRepository::destroy(); @@ -106,8 +110,12 @@ parser(const std::string& filename) std::string file_content; if (parallel::rank() == 0) { - std::ifstream file{filename}; - file_content = std::string{std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; + if (ResumingManager::getInstance().isResuming()) { + file_content = resumingDatafile(filename); + } else { + std::ifstream file{filename}; + file_content = std::string{std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; + } } parallel::broadcast(file_content, 0); diff --git a/src/language/ast/ASTExecutionStack.cpp b/src/language/ast/ASTExecutionStack.cpp index 922e40e32..48add07ec 100644 --- a/src/language/ast/ASTExecutionStack.cpp +++ b/src/language/ast/ASTExecutionStack.cpp @@ -15,13 +15,17 @@ ASTExecutionStack::errorMessageAt(const std::string& message) const auto& stack = ASTExecutionStack::getInstance().m_stack; std::ostringstream error_msg; - auto p = stack[stack.size() - 1]->begin(); - error_msg << rang::style::bold << p.source << ':' << p.line << ':' << p.column << ": " << rang::style::reset - << message << rang::fg::reset << '\n'; + if (stack.size() > 0) { + auto p = stack[stack.size() - 1]->begin(); + error_msg << rang::style::bold << p.source << ':' << p.line << ':' << p.column << ": " << rang::style::reset + << message << rang::fg::reset << '\n'; - if 
(m_file_input.use_count() > 0) { - error_msg << m_file_input->line_at(p) << '\n' - << std::string(p.column - 1, ' ') << rang::fgB::yellow << '^' << rang::fg::reset << '\n'; + if (m_file_input.use_count() > 0) { + error_msg << m_file_input->line_at(p) << '\n' + << std::string(p.column - 1, ' ') << rang::fgB::yellow << '^' << rang::fg::reset << '\n'; + } + } else { + error_msg << message; } return error_msg.str(); diff --git a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp index 169540df2..ad4ced30f 100644 --- a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -281,7 +281,7 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } // LCOV_EXCL_START default: { - throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of vector"); + throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of matrix"); } // LCOV_EXCL_STOP } diff --git a/src/language/modules/CMakeLists.txt b/src/language/modules/CMakeLists.txt index 5f87dcf34..1ecbe1db5 100644 --- a/src/language/modules/CMakeLists.txt +++ b/src/language/modules/CMakeLists.txt @@ -23,6 +23,7 @@ target_link_libraries( ) add_dependencies( + PugsCheckpointing PugsLanguageModules PugsLanguageAlgorithms PugsUtils diff --git a/src/language/modules/CoreModule.cpp b/src/language/modules/CoreModule.cpp index e9be93f3c..72b46b301 100644 --- a/src/language/modules/CoreModule.cpp +++ b/src/language/modules/CoreModule.cpp @@ -33,6 +33,9 @@ #include <utils/Messenger.hpp> #include <utils/PugsUtils.hpp> #include <utils/RandomEngine.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/Resume.hpp> +#include <utils/checkpointing/ResumingManager.hpp> #include <random> @@ -54,7 +57,7 @@ CoreModule::CoreModule() : BuiltinModule(true) []() -> std::string { const ModuleRepository& repository = - 
ASTExecutionInfo::current().moduleRepository(); + ASTExecutionInfo::getInstance().moduleRepository(); return repository.getAvailableModules(); } @@ -65,7 +68,7 @@ CoreModule::CoreModule() : BuiltinModule(true) [](const std::string& module_name) -> std::string { const ModuleRepository& repository = - ASTExecutionInfo::current().moduleRepository(); + ASTExecutionInfo::getInstance().moduleRepository(); return repository.getModuleInfo(module_name); } @@ -99,7 +102,7 @@ CoreModule::CoreModule() : BuiltinModule(true) this->_addBuiltinFunction("exit", std::function( [](const int64_t& exit_code) -> void { - const auto& location = ASTBacktrace::getInstance().sourceLocation(); + const auto& location = ASTExecutionStack::getInstance().sourceLocation(); std::cout << "\n** " << rang::fgB::yellow << "exit" << rang::fg::reset << " explicitly called with code " << rang::fgB::cyan << exit_code << rang::fg::reset << "\n from " << rang::style::underline @@ -111,6 +114,33 @@ CoreModule::CoreModule() : BuiltinModule(true) )); + this->_addBuiltinFunction("checkpoint", std::function( + + []() -> void { + if (ResumingManager::getInstance().isResuming()) { + resume(); + ResumingManager::getInstance().setIsResuming(false); + } else { + checkpoint(); + } + } + + )); + + this->_addBuiltinFunction("checkpoint_and_exit", std::function( + + []() -> void { + if (ResumingManager::getInstance().isResuming()) { + resume(); + ResumingManager::getInstance().setIsResuming(false); + } else { + checkpoint(); + throw language::Exit(0); + } + } + + )); + this->_addNameValue("cout", ast_node_data_type_from<std::shared_ptr<const OStream>>, EmbeddedData{std::make_shared<DataHandler<const OStream>>(std::make_shared<OStream>(std::cout))}); diff --git a/src/language/modules/DevUtilsModule.cpp b/src/language/modules/DevUtilsModule.cpp index 7df6e847e..a17382946 100644 --- a/src/language/modules/DevUtilsModule.cpp +++ b/src/language/modules/DevUtilsModule.cpp @@ -39,7 +39,7 @@ DevUtilsModule::DevUtilsModule() 
this->_addBuiltinFunction("getAST", std::function( []() -> std::string { - const auto& root_node = ASTExecutionInfo::current().rootNode(); + const auto& root_node = ASTExecutionInfo::getInstance().rootNode(); std::ostringstream os; os << ASTPrinter{root_node}; @@ -52,7 +52,7 @@ DevUtilsModule::DevUtilsModule() this->_addBuiltinFunction("saveASTDot", std::function( [](const std::string& dot_filename) -> void { - const auto& root_node = ASTExecutionInfo::current().rootNode(); + const auto& root_node = ASTExecutionInfo::getInstance().rootNode(); std::ofstream fout(dot_filename); diff --git a/src/language/modules/ModuleRepository.cpp b/src/language/modules/ModuleRepository.cpp index 5bdad8bce..e1e53796f 100644 --- a/src/language/modules/ModuleRepository.cpp +++ b/src/language/modules/ModuleRepository.cpp @@ -108,6 +108,7 @@ ModuleRepository::_populateSymbolTable(const ASTNode& module_node, i_symbol->attributes().setDataType(value_descriptor->type()); i_symbol->attributes().setIsInitialized(); + i_symbol->attributes().setIsModuleVariable(); i_symbol->attributes().value() = value_descriptor->value(); } } diff --git a/src/language/node_processor/ASTNodeListProcessor.hpp b/src/language/node_processor/ASTNodeListProcessor.hpp index 142d33f7d..e9564c9a3 100644 --- a/src/language/node_processor/ASTNodeListProcessor.hpp +++ b/src/language/node_processor/ASTNodeListProcessor.hpp @@ -4,7 +4,10 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class ASTNodeListProcessor final : public INodeProcessor { @@ -21,8 +24,21 @@ class ASTNodeListProcessor final : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - for (auto& child : m_node.children) { - child->execute(exec_policy); + ResumingManager& 
resuming_manager = ResumingManager::getInstance(); + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + for (size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + i_child < m_node.children.size(); ++i_child) { + m_node.children[i_child]->execute(exec_policy); + } + } else { + for (auto&& child : m_node.children) { + child->execute(exec_policy); + } } if (not(m_node.is_root() or m_node.is_type<language::for_statement_block>())) diff --git a/src/language/node_processor/DoWhileProcessor.hpp b/src/language/node_processor/DoWhileProcessor.hpp index 593fd4bf1..cbda2ca26 100644 --- a/src/language/node_processor/DoWhileProcessor.hpp +++ b/src/language/node_processor/DoWhileProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class DoWhileProcessor final : public INodeProcessor { @@ -22,7 +25,17 @@ class DoWhileProcessor final : public INodeProcessor { bool continuation_test = true; ExecutionPolicy exec_until_jump; + ResumingManager& resuming_manager = ResumingManager::getInstance(); do { + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + Assert(i_child == 0); + }; 
m_node.children[0]->execute(exec_until_jump); if (not exec_until_jump.exec()) { if (exec_until_jump.jumpType() == ExecutionPolicy::JumpType::break_jump) { diff --git a/src/language/node_processor/ForProcessor.hpp b/src/language/node_processor/ForProcessor.hpp index 674c7d05e..47d50bc5a 100644 --- a/src/language/node_processor/ForProcessor.hpp +++ b/src/language/node_processor/ForProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class ForProcessor final : public INodeProcessor { @@ -21,8 +24,12 @@ class ForProcessor final : public INodeProcessor execute(ExecutionPolicy& exec_policy) { ExecutionPolicy exec_until_jump; - m_node.children[0]->execute(exec_policy); - while ([&]() { + ResumingManager& resuming_manager = ResumingManager::getInstance(); + + if (not resuming_manager.isResuming()) [[likely]] { + m_node.children[0]->execute(exec_policy); + } + while (resuming_manager.isResuming() or [&]() { return static_cast<bool>(std::visit( [](auto&& value) -> bool { using T = std::decay_t<decltype(value)>; @@ -34,6 +41,15 @@ class ForProcessor final : public INodeProcessor }, m_node.children[1]->execute(exec_policy))); }()) { + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + Assert(i_child == 3); + } m_node.children[3]->execute(exec_until_jump); if (not exec_until_jump.exec()) { if (exec_until_jump.jumpType() == ExecutionPolicy::JumpType::break_jump) { diff --git 
a/src/language/node_processor/IfProcessor.hpp b/src/language/node_processor/IfProcessor.hpp index 0d7497eaf..0c5d356b7 100644 --- a/src/language/node_processor/IfProcessor.hpp +++ b/src/language/node_processor/IfProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class IfProcessor final : public INodeProcessor { @@ -20,29 +23,41 @@ class IfProcessor final : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - const bool is_true = static_cast<bool>(std::visit( // LCOV_EXCL_LINE (false negative) - [](const auto& value) -> bool { - using T = std::decay_t<decltype(value)>; - if constexpr (std::is_arithmetic_v<T>) { - return value; - } else { - return false; // LCOV_EXCL_LINE (unreachable: only there for compilation purpose) - } - }, - m_node.children[0]->execute(exec_policy))); - if (is_true) { - Assert(m_node.children[1] != nullptr); - m_node.children[1]->execute(exec_policy); - if (m_node.children[1]->m_symbol_table != m_node.m_symbol_table) - m_node.children[1]->m_symbol_table->clearValues(); + ResumingManager& resuming_manager = ResumingManager::getInstance(); + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + m_node.children[i_child]->execute(exec_policy); } else { - if (m_node.children.size() == 3) { - // else statement - Assert(m_node.children[2] != nullptr); - m_node.children[2]->execute(exec_policy); - if (m_node.children[2]->m_symbol_table != 
m_node.m_symbol_table) - m_node.children[2]->m_symbol_table->clearValues(); + const bool is_true = static_cast<bool>(std::visit( // LCOV_EXCL_LINE (false negative) + [](const auto& value) -> bool { + using T = std::decay_t<decltype(value)>; + if constexpr (std::is_arithmetic_v<T>) { + return value; + } else { + return false; // LCOV_EXCL_LINE (unreachable: only there for compilation purpose) + } + }, + m_node.children[0]->execute(exec_policy))); + if (is_true) { + Assert(m_node.children[1] != nullptr); + m_node.children[1]->execute(exec_policy); + if (m_node.children[1]->m_symbol_table != m_node.m_symbol_table) + m_node.children[1]->m_symbol_table->clearValues(); + + } else { + if (m_node.children.size() == 3) { + // else statement + Assert(m_node.children[2] != nullptr); + m_node.children[2]->execute(exec_policy); + if (m_node.children[2]->m_symbol_table != m_node.m_symbol_table) + m_node.children[2]->m_symbol_table->clearValues(); + } } } diff --git a/src/language/node_processor/WhileProcessor.hpp b/src/language/node_processor/WhileProcessor.hpp index 72c8f4186..86912e126 100644 --- a/src/language/node_processor/WhileProcessor.hpp +++ b/src/language/node_processor/WhileProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class WhileProcessor final : public INodeProcessor { @@ -21,7 +24,8 @@ class WhileProcessor final : public INodeProcessor execute(ExecutionPolicy& exec_policy) { ExecutionPolicy exec_until_jump; - while ([&]() { + ResumingManager& resuming_manager = ResumingManager::getInstance(); + while (resuming_manager.isResuming() or [&]() { return static_cast<bool>(std::visit( [](const auto& value) -> bool { using T = std::decay_t<decltype(value)>; @@ -33,7 +37,17 @@ class WhileProcessor final : public 
INodeProcessor }, m_node.children[0]->execute(exec_policy))); }()) { + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + Assert(i_child == 1); + } m_node.children[1]->execute(exec_until_jump); + if (not exec_until_jump.exec()) { if (exec_until_jump.jumpType() == ExecutionPolicy::JumpType::break_jump) { break; diff --git a/src/language/utils/ASTCheckpoint.hpp b/src/language/utils/ASTCheckpoint.hpp new file mode 100644 index 000000000..2114083da --- /dev/null +++ b/src/language/utils/ASTCheckpoint.hpp @@ -0,0 +1,36 @@ +#ifndef AST_CHECKPOINT_HPP +#define AST_CHECKPOINT_HPP + +#include <vector> + +class ASTNode; +class ASTCheckpoint +{ + private: + const ASTNode* const m_p_node; + const std::vector<size_t> m_ast_location; + + public: + const ASTNode& + node() const + { + return *m_p_node; + } + + const std::vector<size_t>& + getASTLocation() const + { + return m_ast_location; + } + + ASTCheckpoint(const std::vector<size_t>& ast_location, const ASTNode* const p_node) + : m_p_node{p_node}, m_ast_location{ast_location} + {} + + ASTCheckpoint(const ASTCheckpoint&) = default; + ASTCheckpoint(ASTCheckpoint&&) = default; + + ~ASTCheckpoint() = default; +}; + +#endif // AST_CHECKPOINT_HPP diff --git a/src/language/utils/ASTCheckpointsInfo.cpp b/src/language/utils/ASTCheckpointsInfo.cpp new file mode 100644 index 000000000..779290bba --- /dev/null +++ b/src/language/utils/ASTCheckpointsInfo.cpp @@ -0,0 +1,53 @@ +#include <language/utils/ASTCheckpointsInfo.hpp> + +#include <language/PEGGrammar.hpp> +#include <language/ast/ASTNode.hpp> + +const ASTCheckpointsInfo* ASTCheckpointsInfo::m_checkpoints_info_instance = nullptr; + +void 
+ASTCheckpointsInfo::_findASTCheckpoint(std::vector<size_t>& location, const ASTNode& node) +{ + if (node.is_type<language::function_evaluation>()) { + const ASTNode& node_name = *node.children[0]; + if (node_name.is_type<language::name>() and (node_name.m_data_type == ASTNodeDataType::builtin_function_t) and + ((node_name.string() == "checkpoint") or (node_name.string() == "checkpoint_and_exit"))) { + m_ast_checkpoint_list.push_back(ASTCheckpoint{location, &node}); + return; + } + } + + if (node.children.size() > 0) { + location.push_back(0); + for (size_t i_node = 0; i_node < node.children.size(); ++i_node) { + location[location.size() - 1] = i_node; + this->_findASTCheckpoint(location, *node.children[i_node]); + } + location.pop_back(); + } +} + +ASTCheckpointsInfo::ASTCheckpointsInfo(const ASTNode& root_node) +{ + Assert(m_checkpoints_info_instance == nullptr, "Can only define one ASTCheckpointInfo"); + m_checkpoints_info_instance = this; + + Assert(root_node.is_root()); + + std::vector<size_t> location; + this->_findASTCheckpoint(location, root_node); + + Assert(location.size() == 0); +} + +const ASTCheckpointsInfo& +ASTCheckpointsInfo::getInstance() +{ + Assert(m_checkpoints_info_instance != nullptr, "ASTCheckpointInfo is not defined!"); + return *m_checkpoints_info_instance; +} + +ASTCheckpointsInfo::~ASTCheckpointsInfo() +{ + m_checkpoints_info_instance = nullptr; +} diff --git a/src/language/utils/ASTCheckpointsInfo.hpp b/src/language/utils/ASTCheckpointsInfo.hpp new file mode 100644 index 000000000..ca10d3a58 --- /dev/null +++ b/src/language/utils/ASTCheckpointsInfo.hpp @@ -0,0 +1,53 @@ +#ifndef AST_CHECKPOINTS_INFO_HPP +#define AST_CHECKPOINTS_INFO_HPP + +#include <language/utils/ASTCheckpoint.hpp> +#include <utils/Exceptions.hpp> +#include <utils/PugsAssert.hpp> + +#include <string> +#include <vector> + +class ASTNode; + +class ASTCheckpointsInfo +{ + private: + static const ASTCheckpointsInfo* m_checkpoints_info_instance; + + std::vector<ASTCheckpoint> 
m_ast_checkpoint_list; + + void _findASTCheckpoint(std::vector<size_t>& location, const ASTNode& node); + + // The only place where the ASTCheckpointsInfo can be built + friend void parser(const std::string& filename); + + ASTCheckpointsInfo(const ASTNode& root_node); + + public: + size_t + getCheckpointId(const ASTNode& node) const + { + for (size_t i = 0; i < m_ast_checkpoint_list.size(); ++i) { + if (&m_ast_checkpoint_list[i].node() == &node) { + return i; + } + } + throw UnexpectedError("Could not find node"); + } + + const ASTCheckpoint& + getASTCheckpoint(size_t checkpoint_id) const + { + Assert(checkpoint_id < m_ast_checkpoint_list.size()); + return m_ast_checkpoint_list[checkpoint_id]; + } + + static const ASTCheckpointsInfo& getInstance(); + + ASTCheckpointsInfo() = delete; + + ~ASTCheckpointsInfo(); +}; + +#endif // AST_CHECKPOINTS_INFO_HPP diff --git a/src/language/utils/ASTExecutionInfo.cpp b/src/language/utils/ASTExecutionInfo.cpp index 6015b6a8f..8a3131a6e 100644 --- a/src/language/utils/ASTExecutionInfo.cpp +++ b/src/language/utils/ASTExecutionInfo.cpp @@ -2,24 +2,24 @@ #include <language/ast/ASTNode.hpp> -const ASTExecutionInfo* ASTExecutionInfo::m_current_execution_info = nullptr; +const ASTExecutionInfo* ASTExecutionInfo::m_execution_info_instance = nullptr; ASTExecutionInfo::ASTExecutionInfo(const ASTNode& root_node, const ModuleRepository& module_repository) : m_root_node{root_node}, m_module_repository{module_repository} { - Assert(m_current_execution_info == nullptr, "Can only define one ASTExecutionInfo"); + Assert(m_execution_info_instance == nullptr, "Can only define one ASTExecutionInfo"); - m_current_execution_info = this; + m_execution_info_instance = this; } const ASTExecutionInfo& -ASTExecutionInfo::current() +ASTExecutionInfo::getInstance() { - Assert(m_current_execution_info != nullptr, "ASTExecutionInfo is not defined!"); - return *m_current_execution_info; + Assert(m_execution_info_instance != nullptr, "ASTExecutionInfo is not 
defined!"); + return *m_execution_info_instance; } ASTExecutionInfo::~ASTExecutionInfo() { - m_current_execution_info = nullptr; + m_execution_info_instance = nullptr; } diff --git a/src/language/utils/ASTExecutionInfo.hpp b/src/language/utils/ASTExecutionInfo.hpp index da4b81742..1f389761a 100644 --- a/src/language/utils/ASTExecutionInfo.hpp +++ b/src/language/utils/ASTExecutionInfo.hpp @@ -8,7 +8,7 @@ class ASTNode; class ASTExecutionInfo { private: - static const ASTExecutionInfo* m_current_execution_info; + static const ASTExecutionInfo* m_execution_info_instance; const ASTNode& m_root_node; @@ -34,7 +34,7 @@ class ASTExecutionInfo return m_module_repository; } - static const ASTExecutionInfo& current(); + static const ASTExecutionInfo& getInstance(); ASTExecutionInfo() = delete; diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 33a41962e..0455cd053 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -8,6 +8,7 @@ add_library(PugsLanguageUtils AffectationRegisterForRnxn.cpp AffectationRegisterForString.cpp AffectationRegisterForZ.cpp + ASTCheckpointsInfo.cpp ASTDotPrinter.cpp ASTExecutionInfo.cpp ASTNodeDataType.cpp diff --git a/src/language/utils/SymbolTable.hpp b/src/language/utils/SymbolTable.hpp index 9f7bc8129..38351f7ec 100644 --- a/src/language/utils/SymbolTable.hpp +++ b/src/language/utils/SymbolTable.hpp @@ -24,6 +24,7 @@ class SymbolTable int32_t m_context_id; bool m_is_initialized{false}; + bool m_is_module_variable{false}; // variable that is created by a module (ex: cout,...) 
ASTNodeDataType m_data_type; DataVariant m_value; @@ -35,7 +36,7 @@ class SymbolTable return m_context_id != -1; } - const int32_t& + int32_t contextId() const { return m_context_id; @@ -53,7 +54,7 @@ class SymbolTable return m_value; } - const bool& + bool isInitialized() const { return m_is_initialized; @@ -65,6 +66,18 @@ class SymbolTable m_is_initialized = true; } + bool + isModuleVariable() const + { + return m_is_module_variable; + } + + void + setIsModuleVariable() + { + m_is_module_variable = true; + } + const ASTNodeDataType& dataType() const { @@ -98,7 +111,7 @@ class SymbolTable return os; } - Attributes& operator=(Attributes&&) = default; + Attributes& operator=(Attributes&&) = default; Attributes& operator=(const Attributes&) = default; Attributes(const TAO_PEGTL_NAMESPACE::position& position, int32_t context_id) @@ -138,7 +151,7 @@ class SymbolTable Symbol(const std::string& name, const Attributes& attributes) : m_name(name), m_attributes(attributes) {} - Symbol& operator=(Symbol&&) = default; + Symbol& operator=(Symbol&&) = default; Symbol& operator=(const Symbol&) = default; Symbol(const Symbol&) = default; @@ -179,7 +192,7 @@ class SymbolTable Context() : m_id{next_context_id++} {} Context& operator=(const Context&) = default; // clazy:exclude=function-args-by-value - Context& operator=(Context&&) = default; + Context& operator=(Context&&) = default; Context(const Context&) = default; Context(Context&&) = default; @@ -198,6 +211,24 @@ class SymbolTable std::shared_ptr<EmbedderTable<TypeDescriptor>> m_type_embedder_table; public: + const std::vector<Symbol>& + symbolList() const + { + return m_symbol_list; + } + + bool + hasParentTable() const + { + return m_parent_table.use_count() > 0; + } + + const std::shared_ptr<SymbolTable> + parentTable() const + { + return m_parent_table; + } + bool hasContext() const { diff --git a/src/main.cpp b/src/main.cpp index 7d0112c80..a373bd2cd 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -9,11 +9,13 @@ #include 
<utils/GlobalVariableManager.hpp> #include <utils/PugsUtils.hpp> #include <utils/RandomEngine.hpp> +#include <utils/checkpointing/ResumingManager.hpp> int main(int argc, char* argv[]) { ExecutionStatManager::create(); + ResumingManager::create(); ParallelChecker::create(); std::string filename = initialize(argc, argv); @@ -42,6 +44,7 @@ main(int argc, char* argv[]) finalize(); ParallelChecker::destroy(); + ResumingManager::destroy(); int return_code = ExecutionStatManager::getInstance().exitCode(); ExecutionStatManager::destroy(); diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index b5a383c9c..f06fa6871 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -1,5 +1,7 @@ # ------------------- Source files -------------------- +add_subdirectory(checkpointing) + add_library( PugsUtils BuildInfo.cpp diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp index 104e70cdd..f1d4f488e 100644 --- a/src/utils/PugsUtils.cpp +++ b/src/utils/PugsUtils.cpp @@ -13,6 +13,7 @@ #include <utils/RevisionInfo.hpp> #include <utils/SLEPcWrapper.hpp> #include <utils/SignalManager.hpp> +#include <utils/checkpointing/ResumingManager.hpp> #include <utils/pugs_build_info.hpp> #include <rang.hpp> @@ -98,6 +99,9 @@ initialize(int& argc, char* argv[]) app.add_option("filename", filename, "pugs script file")->check(CLI::ExistingFile)->required(); + bool is_resuming = false; + app.add_flag("--resume", is_resuming, "Resume last checkpoint"); + app.set_version_flag("-v,--version", []() { ConsoleManager::init(true); std::stringstream os; @@ -162,6 +166,9 @@ initialize(int& argc, char* argv[]) CommunicatorManager::setSplitColor(mpi_split_color); } + ResumingManager::getInstance().setIsResuming(is_resuming); + ResumingManager::getInstance().setFilename(filename); + ExecutionStatManager::getInstance().setPrint(print_exec_stat); BacktraceManager::setShow(show_backtrace); ConsoleManager::setShowPreamble(show_preamble); diff --git 
a/src/utils/checkpointing/CMakeLists.txt b/src/utils/checkpointing/CMakeLists.txt
new file mode 100644
index 000000000..a2a0cdc82
--- /dev/null
+++ b/src/utils/checkpointing/CMakeLists.txt
@@ -0,0 +1,18 @@
+# ------------------- Source files --------------------
+
+add_library(
+  PugsCheckpointing
+  Checkpoint.cpp
+  Resume.cpp
+  ResumingManager.cpp
+  ResumingUtils.cpp)
+
+# Additional dependencies
+add_dependencies(PugsCheckpointing
+  PugsLanguageAST
+  PugsUtils)
+
+target_link_libraries(
+  PugsCheckpointing
+  ${HIGHFIVE_TARGET}
+)
diff --git a/src/utils/checkpointing/Checkpoint.cpp b/src/utils/checkpointing/Checkpoint.cpp
new file mode 100644
index 000000000..16ea2d830
--- /dev/null
+++ b/src/utils/checkpointing/Checkpoint.cpp
@@ -0,0 +1,123 @@
+#include <utils/checkpointing/Checkpoint.hpp>
+
+#include <utils/pugs_config.hpp>
+
+#ifdef PUGS_HAS_HDF5
+
+#include <utils/HighFivePugsUtils.hpp>
+
+#include <language/ast/ASTExecutionStack.hpp>
+#include <language/utils/SymbolTable.hpp>
+
+#include <iostream>
+#endif   // PUGS_HAS_HDF5
+
+#include <language/utils/ASTCheckpointsInfo.hpp>
+#include <utils/Exceptions.hpp>
+#include <utils/checkpointing/ResumingManager.hpp>
+
+#ifdef PUGS_HAS_HDF5
+/**
+ * Writes a new checkpoint into "checkpoint.h5".
+ *
+ * A group "checkpoint_<n>" is created (the file is truncated when <n> is 0).
+ * It stores the checkpoint id associated with the current AST node, the file
+ * content held by the execution stack ("data.pgs") and one nested "symbol
+ * table" group per symbol table, from the current node's table up through its
+ * parents. Only non-module variables holding arithmetic, string, tiny vector
+ * or tiny matrix values (or std::vector of those) are written as attributes.
+ *
+ * "last_checkpoint" is then relinked to the new group, the file attribute
+ * "checkpoint_number" is updated and the checkpoint counter is incremented.
+ *
+ * @throws NotImplementedError if a variable holds an unsupported data type
+ */
+void
+checkpoint()
+{
+  auto create_props = HighFive::FileCreateProps{};
+  create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0));
+
+  uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber();
+
+  const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite;
+
+  HighFive::File file("checkpoint.h5", file_openmode, create_props);
+
+  HighFive::Group checkpoint = file.createGroup("checkpoint_" + std::to_string(checkpoint_number));
+
+  uint64_t checkpoint_id =
+    ASTCheckpointsInfo::getInstance().getCheckpointId((ASTExecutionStack::getInstance().currentNode()));
+
+  checkpoint.createAttribute("checkpoint_id", checkpoint_id);
+  checkpoint.createDataSet("data.pgs", ASTExecutionStack::getInstance().fileContent());
+
+  std::shared_ptr<const SymbolTable> p_symbol_table = ASTExecutionStack::getInstance().currentNode().m_symbol_table;
+  auto symbol_table_group                           = checkpoint;
+  while (p_symbol_table.use_count() > 0) {
+    symbol_table_group = symbol_table_group.createGroup("symbol table");
+
+    const SymbolTable& symbol_table = *p_symbol_table;
+
+    const auto& symbol_list = symbol_table.symbolList();
+
+    for (auto& symbol : symbol_list) {
+      switch (symbol.attributes().dataType()) {
+      case ASTNodeDataType::builtin_function_t:
+      case ASTNodeDataType::function_t:
+      case ASTNodeDataType::type_name_id_t: {
+        break;
+      }
+      default: {
+        // builtin functions never reach this branch: they are skipped by the cases above
+        if (not symbol.attributes().isModuleVariable()) {
+          std::visit(
+            [&](auto&& data) {
+              using DataT = std::decay_t<decltype(data)>;
+              if constexpr (std::is_same_v<DataT, std::monostate>) {
+              } else if constexpr ((std::is_arithmetic_v<DataT>) or (std::is_same_v<DataT, std::string>) or
+                                   (is_tiny_vector_v<DataT>) or (is_tiny_matrix_v<DataT>)) {
+                symbol_table_group.createAttribute(symbol.name(), data);
+              } else if constexpr (is_std_vector_v<DataT>) {
+                using value_type = typename DataT::value_type;
+                if constexpr ((std::is_arithmetic_v<value_type>) or (std::is_same_v<value_type, std::string>) or
+                              (is_tiny_vector_v<value_type>) or (is_tiny_matrix_v<value_type>)) {
+                  symbol_table_group.createAttribute(symbol.name(), data);
+                } else {
+                  throw NotImplementedError("datatype is not handled yet");
+                }
+              } else {
+                throw NotImplementedError("datatype is not handled yet");
+              }
+            },
+            symbol.attributes().value());
+        }
+      }
+      }
+    }
+
+    p_symbol_table = symbol_table.parentTable();
+  }
+
+  if (file.exist("last_checkpoint")) {
+    file.unlink("last_checkpoint");
+  }
+  file.createHardLink("last_checkpoint", checkpoint);
+
+  if (file.hasAttribute("checkpoint_number")) {
+    file.deleteAttribute("checkpoint_number");
+  }
+  file.createAttribute("checkpoint_number", checkpoint_number);
+
+  ++checkpoint_number;
+}
+
+#else   // PUGS_HAS_HDF5
+
+void
+checkpoint()
+{
+  throw NormalError("checkpoint/resume mechanism requires HDF5");
+}
+
+#endif   // PUGS_HAS_HDF5
diff --git a/src/utils/checkpointing/Checkpoint.hpp b/src/utils/checkpointing/Checkpoint.hpp
new file mode 100644
index 000000000..d5c4e9100
--- /dev/null
+++ b/src/utils/checkpointing/Checkpoint.hpp
@@ -0,0 +1,6 @@
+#ifndef CHECKPOINT_HPP
+#define CHECKPOINT_HPP
+
+void checkpoint();
+
+#endif   // CHECKPOINT_HPP
diff --git a/src/utils/checkpointing/Resume.cpp b/src/utils/checkpointing/Resume.cpp
new file mode 100644
index 000000000..56c3c2834
--- /dev/null
+++ b/src/utils/checkpointing/Resume.cpp
@@ -0,0 +1,247 @@
+#include <utils/checkpointing/Resume.hpp>
+
+#include <utils/pugs_config.hpp>
+
+#ifdef PUGS_HAS_HDF5
+
+#include <utils/HighFivePugsUtils.hpp>
+
+#include <language/ast/ASTExecutionStack.hpp>
+#include <language/utils/SymbolTable.hpp>
+
+#include <rang.hpp>
+
+#include <iostream>
+#endif   // PUGS_HAS_HDF5
+
+#include <language/utils/ASTCheckpointsInfo.hpp>
+#include <utils/Exceptions.hpp>
+#include <utils/checkpointing/ResumingManager.hpp>
+
+#ifdef PUGS_HAS_HDF5
+
+/**
+ * Restores the variables saved in the "/last_checkpoint" group of the
+ * resuming file into the symbol tables reachable from the current AST node.
+ *
+ * The checkpoint counter of the ResumingManager is set to the value following
+ * the "checkpoint_number" read from the file. Saved "symbol table" groups are
+ * then walked together with the chain of parent symbol tables, and each saved
+ * attribute is read back according to the data type of the matching symbol.
+ *
+ * @throws NormalError if a saved symbol is not found in the symbol table
+ * @throws NotImplementedError if a symbol has an unsupported data type
+ */
+void
+resume()
+{
+  HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly);
+
+  HighFive::Group checkpoint = file.getGroup("/last_checkpoint");
+
+  HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table");
+
+  const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode();
+  auto p_symbol_table   = p_node->m_symbol_table;
+
+  ResumingManager& resuming_manager = ResumingManager::getInstance();
+
+  resuming_manager.checkpointNumber() = file.getAttribute("checkpoint_number").read<uint64_t>();
+
+  // the post-increment prints the number that was read and prepares the next checkpoint number
+  std::cout << " * "
+            << "Using " << rang::fgB::green << "checkpoint" << rang::fg::reset << " number "
+            << resuming_manager.checkpointNumber()++ << '\n';
+
+  std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line " << rang::fgB::yellow
+            << p_node->begin().line << rang::fg::reset << " [checkpoint id " << rang::fgB::cyan
+            << resuming_manager.checkpointId() << rang::fg::reset << "]\n";
+
+  bool finished = true;
+  do {
+    finished = true;
+
+    for (const auto& symbol_name : saved_symbol_table.listAttributeNames()) {
+      auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
+      if (not found) {
+        // p_symbol must not be dereferenced when the lookup failed
+        throw NormalError("cannot resume: saved symbol '" + symbol_name + "' is not defined at the resuming point");
+      }
+      auto& attribute = p_symbol->attributes();
+      switch (attribute.dataType()) {
+      case ASTNodeDataType::bool_t: {
+        attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>();
+        break;
+      }
+      case ASTNodeDataType::unsigned_int_t: {
+        attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>();
+        break;
+      }
+      case ASTNodeDataType::int_t: {
+        attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>();
+        break;
+      }
+      case ASTNodeDataType::double_t: {
+        attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double>();
+        break;
+      }
+      case ASTNodeDataType::string_t: {
+        attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>();
+        break;
+      }
+      case ASTNodeDataType::vector_t: {
+        switch (attribute.dataType().dimension()) {
+        case 1: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>();
+          break;
+        }
+        case 2: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>();
+          break;
+        }
+        case 3: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>();
+          break;
+        }
+          // LCOV_EXCL_START
+        default: {
+          throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
+        }
+          // LCOV_EXCL_STOP
+        }
+        break;
+      }
+      case ASTNodeDataType::matrix_t: {
+        // LCOV_EXCL_START
+        if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) {
+          throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
+        }
+        // LCOV_EXCL_STOP
+        switch (attribute.dataType().numberOfRows()) {
+        case 1: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>();
+          break;
+        }
+        case 2: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>();
+          break;
+        }
+        case 3: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>();
+          break;
+        }
+          // LCOV_EXCL_START
+        default: {
+          throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
+        }
+          // LCOV_EXCL_STOP
+        }
+        break;
+      }
+      case ASTNodeDataType::tuple_t: {
+        switch (attribute.dataType().contentType()) {
+        case ASTNodeDataType::bool_t: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>();
+          break;
+        }
+        case ASTNodeDataType::unsigned_int_t: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>();
+          break;
+        }
+        case ASTNodeDataType::int_t: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>();
+          break;
+        }
+        case ASTNodeDataType::double_t: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double>>();
+          break;
+        }
+        case ASTNodeDataType::string_t: {
+          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>();
+          break;
+        }
+        case ASTNodeDataType::vector_t: {
+          switch (attribute.dataType().contentType().dimension()) {
+          case 1: {
+            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>();
+            break;
+          }
+          case 2: {
+            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>();
+            break;
+          }
+          case 3: {
+            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>();
+            break;
+          }
+            // LCOV_EXCL_START
+          default: {
+            throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
+          }
+            // LCOV_EXCL_STOP
+          }
+          break;
+        }
+        case ASTNodeDataType::matrix_t: {
+          // LCOV_EXCL_START
+          if (attribute.dataType().contentType().numberOfRows() !=
+              attribute.dataType().contentType().numberOfColumns()) {
+            throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
+          }
+          // LCOV_EXCL_STOP
+          switch (attribute.dataType().contentType().numberOfRows()) {
+          case 1: {
+            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>();
+            break;
+          }
+          case 2: {
+            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>();
+            break;
+          }
+          case 3: {
+            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>();
+            break;
+          }
+            // LCOV_EXCL_START
+          default: {
+            throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
+          }
+            // LCOV_EXCL_STOP
+          }
+          break;
+        }
+        default: {
+          throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType()));
+        }
+        }
+        break;
+      }
+      default: {
+        throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType()));
+      }
+      }
+    }
+
+    const bool symbol_table_has_parent       = p_symbol_table->hasParentTable();
+    const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table");
+
+    Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent));
+
+    if (symbol_table_has_parent and saved_symbol_table_has_parent) {
+      p_symbol_table     = p_symbol_table->parentTable();
+      saved_symbol_table = saved_symbol_table.getGroup("symbol table");
+
+      finished = false;
+    }
+
+  } while (not finished);
+}
+
+#else   // PUGS_HAS_HDF5
+
+void
+resume()
+{
+  throw NormalError("checkpoint/resume mechanism requires HDF5");
+}
+
+#endif   // PUGS_HAS_HDF5
diff --git a/src/utils/checkpointing/Resume.hpp b/src/utils/checkpointing/Resume.hpp
new file mode 100644
index 000000000..9512539bc
--- /dev/null
+++ b/src/utils/checkpointing/Resume.hpp
@@ -0,0 +1,6 @@
+#ifndef RESUME_HPP
+#define RESUME_HPP
+
+void resume();
+
+#endif   // RESUME_HPP
diff --git a/src/utils/checkpointing/ResumingManager.cpp b/src/utils/checkpointing/ResumingManager.cpp
new file mode 100644
index 000000000..f4e8d6810
--- /dev/null
+++ b/src/utils/checkpointing/ResumingManager.cpp
@@ -0,0 +1,50 @@
+#include <utils/checkpointing/ResumingManager.hpp>
+
+#include <utils/pugs_config.hpp>
+
+#ifdef PUGS_HAS_HDF5
+#include <utils/HighFivePugsUtils.hpp>
+#endif   // PUGS_HAS_HDF5
+
+ResumingManager* ResumingManager::m_instance = nullptr;
+
+ResumingManager&
+ResumingManager::getInstance()
+{
+  Assert(m_instance != nullptr, "instance was not created");
+  return *m_instance;
+}
+
+void
+ResumingManager::create()
+{
+  Assert(m_instance == nullptr, "Resuming manager was already created");
+  m_instance = new ResumingManager;
+}
+
+void
+ResumingManager::destroy()
+{
+  Assert(m_instance != nullptr, "Resuming manager was not created");
+  delete m_instance;
+  m_instance = nullptr;
+}
+
+// Returns the "checkpoint_id" attribute of "/last_checkpoint" in the resuming
+// file. The value is read on the first call and cached (0 without HDF5).
+uint64_t
+ResumingManager::checkpointId()
+{
+  if (not m_checkpoint_id) {
+#ifdef PUGS_HAS_HDF5
+    HighFive::File file(m_filename, HighFive::File::ReadOnly);
+
+    HighFive::Group checkpoint = file.getGroup("/last_checkpoint");
+
+    m_checkpoint_id = std::make_unique<uint64_t>(checkpoint.getAttribute("checkpoint_id").read<uint64_t>());
+#else    // PUGS_HAS_HDF5
+    m_checkpoint_id = std::make_unique<uint64_t>(0);
+#endif   // PUGS_HAS_HDF5
+  }
+  return *m_checkpoint_id;
+}
diff --git a/src/utils/checkpointing/ResumingManager.hpp
b/src/utils/checkpointing/ResumingManager.hpp
new file mode 100644
index 000000000..a39c68d02
--- /dev/null
+++ b/src/utils/checkpointing/ResumingManager.hpp
@@ -0,0 +1,78 @@
+#ifndef RESUMING_MANAGER_HPP
+#define RESUMING_MANAGER_HPP
+
+#include <utils/PugsAssert.hpp>
+
+#include <cstdint>
+#include <memory>
+#include <string>
+
+// Singleton holding the state of the checkpoint/resume mechanism: the resuming
+// flag, the filename set through setFilename() and the checkpoint counters.
+class ResumingManager
+{
+ private:
+  bool m_is_resuming = false;
+  std::string m_filename;
+  // numbering of the checkpoints (during execution)
+  uint64_t m_checkpoint_number = 0;
+  // id of the checkpoint defined in the script file
+  std::unique_ptr<uint64_t> m_checkpoint_id;
+  // same type as the reference returned by currentASTLevel()
+  uint64_t m_current_ast_level = 0;
+
+  ResumingManager() = default;
+
+  ResumingManager(ResumingManager&&)      = delete;
+  ResumingManager(const ResumingManager&) = delete;
+
+  ~ResumingManager() = default;
+
+  static ResumingManager* m_instance;
+
+ public:
+  static void create();
+  static void destroy();
+  static ResumingManager& getInstance();
+
+  uint64_t&
+  checkpointNumber()
+  {
+    return m_checkpoint_number;
+  }
+
+  // id of the checkpoint (see m_checkpoint_id); not const since it may set the cached value
+  uint64_t checkpointId();
+
+  uint64_t&
+  currentASTLevel()
+  {
+    return m_current_ast_level;
+  }
+
+  void
+  setIsResuming(const bool is_resuming)
+  {
+    m_is_resuming = is_resuming;
+  }
+
+  bool
+  isResuming() const
+  {
+    return m_is_resuming;
+  }
+
+  void
+  setFilename(const std::string& filename)
+  {
+    m_filename = filename;
+  }
+
+  const std::string&
+  filename() const
+  {
+    return m_filename;
+  }
+};
+
+#endif   // RESUMING_MANAGER_HPP
diff --git a/src/utils/checkpointing/ResumingUtils.cpp b/src/utils/checkpointing/ResumingUtils.cpp
new file mode 100644
index 000000000..af811b807
--- /dev/null
+++ b/src/utils/checkpointing/ResumingUtils.cpp
@@ -0,0 +1,29 @@
+#include <utils/checkpointing/ResumingUtils.hpp>
+
+#include <utils/Exceptions.hpp>
+
+#include <utils/pugs_config.hpp>
+
+#ifdef PUGS_HAS_HDF5
+
+#include <utils/HighFivePugsUtils.hpp>
+
+// Returns the "data.pgs" dataset stored in the "/last_checkpoint" group of
+// the given file.
+std::string
+resumingDatafile(const std::string& filename)
+{
+  HighFive::File file(filename, HighFive::File::ReadOnly);
+#warning change when checkpoint id can be specified
+  return file.getGroup("/last_checkpoint").getDataSet("data.pgs").read<std::string>();
+}
+
+#else   // PUGS_HAS_HDF5
+
+std::string
+resumingDatafile(const std::string&)
+{
+  throw NormalError("Resuming requires HDF5");
+}
+
+#endif   // PUGS_HAS_HDF5
diff --git a/src/utils/checkpointing/ResumingUtils.hpp b/src/utils/checkpointing/ResumingUtils.hpp
new file mode 100644
index 000000000..46cd27ed7
--- /dev/null
+++ b/src/utils/checkpointing/ResumingUtils.hpp
@@ -0,0 +1,8 @@
+#ifndef RESUMING_UTILS_HPP
+#define RESUMING_UTILS_HPP
+
+#include <string>
+
+std::string resumingDatafile(const std::string& filename);
+
+#endif   // RESUMING_UTILS_HPP
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index aaace485d..9f0b29d5f 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -234,6 +234,7 @@ target_link_libraries (unit_tests
   PugsScheme
   PugsOutput
   PugsUtils
+  PugsCheckpointing
   PugsDev
   Kokkos::kokkos
   ${PARMETIS_LIBRARIES}
@@ -261,8 +262,9 @@ target_link_libraries (mpi_unit_tests
   PugsLanguageUtils
   PugsScheme
   PugsOutput
-  PugsDev
   PugsUtils
+  PugsCheckpointing
+  PugsDev
   PugsAlgebra
   PugsMesh
   Kokkos::kokkos
diff --git a/tests/mpi_test_main.cpp b/tests/mpi_test_main.cpp
index fb0da8569..083dce9b1 100644
--- a/tests/mpi_test_main.cpp
+++ b/tests/mpi_test_main.cpp
@@ -13,6 +13,7 @@
 #include <utils/PETScWrapper.hpp>
 #include <utils/RandomEngine.hpp>
 #include <utils/Stringify.hpp>
+#include <utils/checkpointing/ResumingManager.hpp>
 #include <utils/pugs_config.hpp>
 
 #include <MeshDataBaseForTests.hpp>
@@ -93,6 +94,8 @@ main(int argc, char* argv[])
       // Disable outputs from tested classes to the standard output
       std::cout.setstate(std::ios::badbit);
 
+      ResumingManager::create();
+
       SynchronizerManager::create();
       RandomEngine::create();
       QuadratureManager::create();
@@ -118,6 +121,8 @@ main(int argc, char* argv[])
       QuadratureManager::destroy();
       RandomEngine::destroy();
       SynchronizerManager::destroy();
+
+      ResumingManager::destroy();
     }
   }
 
diff
--git a/tests/test_ASTModulesImporter.cpp b/tests/test_ASTModulesImporter.cpp index 24557c449..6156c571e 100644 --- a/tests/test_ASTModulesImporter.cpp +++ b/tests/test_ASTModulesImporter.cpp @@ -19,7 +19,7 @@ test_ASTExecutionInfo(const ASTNode& root_node, const ModuleRepository& module_r REQUIRE(&root_node == &execution_info.rootNode()); REQUIRE(&module_repository == &execution_info.moduleRepository()); - REQUIRE(&ASTExecutionInfo::current() == &execution_info); + REQUIRE(&ASTExecutionInfo::getInstance() == &execution_info); } #define CHECK_AST(data, expected_output) \ diff --git a/tests/test_main.cpp b/tests/test_main.cpp index d8641d318..fe5e8eea3 100644 --- a/tests/test_main.cpp +++ b/tests/test_main.cpp @@ -13,6 +13,7 @@ #include <utils/PETScWrapper.hpp> #include <utils/RandomEngine.hpp> #include <utils/SLEPcWrapper.hpp> +#include <utils/checkpointing/ResumingManager.hpp> #include <MeshDataBaseForTests.hpp> @@ -52,6 +53,8 @@ main(int argc, char* argv[]) // Disable outputs from tested classes to the standard output std::cout.setstate(std::ios::badbit); + ResumingManager::create(); + SynchronizerManager::create(); RandomEngine::create(); QuadratureManager::create(); @@ -77,6 +80,8 @@ main(int argc, char* argv[]) QuadratureManager::destroy(); RandomEngine::destroy(); SynchronizerManager::destroy(); + + ResumingManager::destroy(); } } -- GitLab