diff --git a/CMakeLists.txt b/CMakeLists.txt index 560dd74c04bf5f76dee000a9a8883276ff99bdfd..3fc961a86dc2895f361085b9756dbc9a1928f670 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -597,6 +597,7 @@ target_link_libraries( PugsMesh PugsAlgebra PugsAnalysis + PugsCheckpointing PugsDev PugsUtils PugsLanguage @@ -609,6 +610,7 @@ target_link_libraries( PugsUtils PugsOutput PugsLanguageUtils + PugsCheckpointing Kokkos::kokkos ${PETSC_LIBRARIES} ${SLEPC_LIBRARIES} @@ -635,6 +637,7 @@ install(TARGETS PugsMesh PugsAlgebra PugsAnalysis + PugsCheckpointing PugsDev PugsUtils PugsLanguage diff --git a/src/language/PugsParser.cpp b/src/language/PugsParser.cpp index 3e193aef2449038dde3b9dd7843c0f3dc76eeca4..b5b446f69ea47d1827c800afd7bfcb87139efe45 100644 --- a/src/language/PugsParser.cpp +++ b/src/language/PugsParser.cpp @@ -14,9 +14,8 @@ #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolInitializationChecker.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> -#include <language/utils/ASTDotPrinter.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/ASTExecutionInfo.hpp> -#include <language/utils/ASTPrinter.hpp> #include <language/utils/Exit.hpp> #include <language/utils/OperatorRepository.hpp> #include <language/utils/SymbolTable.hpp> @@ -26,6 +25,8 @@ #include <utils/PugsAssert.hpp> #include <utils/PugsUtils.hpp> #include <utils/SignalManager.hpp> +#include <utils/checkpointing/ResumingManager.hpp> +#include <utils/checkpointing/ResumingUtils.hpp> #include <pegtl/contrib/analyze.hpp> #include <pegtl/contrib/parse_tree.hpp> @@ -92,6 +93,8 @@ parser(const std::string& filename) ASTExecutionInfo execution_info{*root_node, module_importer.moduleRepository()}; + ASTCheckpointsInfo checkpoint_info{*root_node}; + ExecutionPolicy exec_all; try { root_node->execute(exec_all); @@ -99,6 +102,7 @@ parser(const std::string& filename) catch (language::Exit& e) { ExecutionStatManager::getInstance().setExitCode(e.code()); } + root_node->m_symbol_table->clearValues(); OperatorRepository::destroy(); @@ -106,8 +110,12 @@ parser(const std::string& filename) std::string file_content; if (parallel::rank() == 0) { - std::ifstream file{filename}; - file_content = std::string{std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; + if (ResumingManager::getInstance().isResuming()) { + file_content = resumingDatafile(filename); + } else { + std::ifstream file{filename}; + file_content = std::string{std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()}; + } } parallel::broadcast(file_content, 0); diff --git a/src/language/ast/ASTExecutionStack.cpp b/src/language/ast/ASTExecutionStack.cpp index 922e40e32bd8901b226561426b7b42976babae7a..48add07ecfc0739bbec88a15fe38153abc6ad1ba 100644 --- a/src/language/ast/ASTExecutionStack.cpp +++ b/src/language/ast/ASTExecutionStack.cpp @@ -15,13 +15,17 @@ ASTExecutionStack::errorMessageAt(const std::string& message) const auto& stack = ASTExecutionStack::getInstance().m_stack; std::ostringstream error_msg; - auto p = stack[stack.size() - 1]->begin(); - error_msg << rang::style::bold << p.source << ':' << p.line << ':' << p.column << ": " << rang::style::reset - << message << rang::fg::reset << '\n'; + if (stack.size() > 0) { + auto p = stack[stack.size() - 1]->begin(); + error_msg << rang::style::bold << p.source << ':' << p.line << ':' << p.column << ": " << rang::style::reset + << message << rang::fg::reset << '\n'; - if (m_file_input.use_count() > 0) { - error_msg << m_file_input->line_at(p) << '\n' - << std::string(p.column - 1, ' ') << rang::fgB::yellow << '^' << rang::fg::reset << '\n'; + if (m_file_input.use_count() > 0) { + error_msg << m_file_input->line_at(p) << '\n' + << std::string(p.column - 1, ' ') << rang::fgB::yellow << '^' << rang::fg::reset << '\n'; + } + } else { + error_msg << message; } return error_msg.str(); diff --git a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp index 169540df28558e1b25a54acbf751bac4488a9449..ad4ced30f2b20e365ffd65251b99c66a32b34a25 100644 --- a/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/src/language/ast/ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -281,7 +281,7 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData } // LCOV_EXCL_START default: { - throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of vector"); + throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of matrix"); } // LCOV_EXCL_STOP } diff --git a/src/language/modules/CMakeLists.txt b/src/language/modules/CMakeLists.txt index 5f87dcf34f5b8652f976a2f88d851eb381d8385b..1ecbe1db54c85a4282ed1a55a043115f09bc7525 100644 --- a/src/language/modules/CMakeLists.txt +++ b/src/language/modules/CMakeLists.txt @@ -23,6 +23,7 @@ target_link_libraries( ) add_dependencies( + PugsCheckpointing PugsLanguageModules PugsLanguageAlgorithms PugsUtils diff --git a/src/language/modules/CoreModule.cpp b/src/language/modules/CoreModule.cpp index e9be93f3c1cc96af47d7665317d5a364cd54f17e..72b46b30153fe3c6a7291eaef6cedab820231c95 100644 --- a/src/language/modules/CoreModule.cpp +++ b/src/language/modules/CoreModule.cpp @@ -33,6 +33,9 @@ #include <utils/Messenger.hpp> #include <utils/PugsUtils.hpp> #include <utils/RandomEngine.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/Resume.hpp> +#include <utils/checkpointing/ResumingManager.hpp> #include <random> @@ -54,7 +57,7 @@ CoreModule::CoreModule() : BuiltinModule(true) []() -> std::string { const ModuleRepository& repository = - ASTExecutionInfo::current().moduleRepository(); + ASTExecutionInfo::getInstance().moduleRepository(); return repository.getAvailableModules(); } @@ -65,7 +68,7 @@ CoreModule::CoreModule() : BuiltinModule(true) [](const std::string& module_name) -> std::string { const ModuleRepository& repository = - ASTExecutionInfo::current().moduleRepository(); + ASTExecutionInfo::getInstance().moduleRepository(); return repository.getModuleInfo(module_name); } @@ -99,7 +102,7 @@ CoreModule::CoreModule() : BuiltinModule(true) this->_addBuiltinFunction("exit", std::function( [](const int64_t& exit_code) -> void { - const auto& location = ASTBacktrace::getInstance().sourceLocation(); + const auto& location = ASTExecutionStack::getInstance().sourceLocation(); std::cout << "\n** " << rang::fgB::yellow << "exit" << rang::fg::reset << " explicitly called with code " << rang::fgB::cyan << exit_code << rang::fg::reset << "\n from " << rang::style::underline @@ -111,6 +114,33 @@ CoreModule::CoreModule() : BuiltinModule(true) )); + this->_addBuiltinFunction("checkpoint", std::function( + + []() -> void { + if (ResumingManager::getInstance().isResuming()) { + resume(); + ResumingManager::getInstance().setIsResuming(false); + } else { + checkpoint(); + } + } + + )); + + this->_addBuiltinFunction("checkpoint_and_exit", std::function( + + []() -> void { + if (ResumingManager::getInstance().isResuming()) { + resume(); + ResumingManager::getInstance().setIsResuming(false); + } else { + checkpoint(); + throw language::Exit(0); + } + } + + )); + this->_addNameValue("cout", ast_node_data_type_from<std::shared_ptr<const OStream>>, EmbeddedData{std::make_shared<DataHandler<const OStream>>(std::make_shared<OStream>(std::cout))}); diff --git a/src/language/modules/DevUtilsModule.cpp b/src/language/modules/DevUtilsModule.cpp index 7df6e847e157a68e073642dd50ee87383788d98a..a1738294600a5048893691217e01655d308ca6aa 100644 --- a/src/language/modules/DevUtilsModule.cpp +++ b/src/language/modules/DevUtilsModule.cpp @@ -39,7 +39,7 @@ DevUtilsModule::DevUtilsModule() this->_addBuiltinFunction("getAST", std::function( []() -> std::string { - const auto& root_node = ASTExecutionInfo::current().rootNode(); + const auto& root_node = ASTExecutionInfo::getInstance().rootNode(); std::ostringstream os; os << ASTPrinter{root_node}; @@ -52,7 +52,7 @@ DevUtilsModule::DevUtilsModule() this->_addBuiltinFunction("saveASTDot", std::function( [](const std::string& dot_filename) -> void { - const auto& root_node = ASTExecutionInfo::current().rootNode(); + const auto& root_node = ASTExecutionInfo::getInstance().rootNode(); std::ofstream fout(dot_filename); diff --git a/src/language/modules/ModuleRepository.cpp b/src/language/modules/ModuleRepository.cpp index 5bdad8bceff609288e318cce8b3f37304afc69b5..e1e53796f31b117fee9f38c20cfbc6e8ff030349 100644 --- a/src/language/modules/ModuleRepository.cpp +++ b/src/language/modules/ModuleRepository.cpp @@ -108,6 +108,7 @@ ModuleRepository::_populateSymbolTable(const ASTNode& module_node, i_symbol->attributes().setDataType(value_descriptor->type()); i_symbol->attributes().setIsInitialized(); + i_symbol->attributes().setIsModuleVariable(); i_symbol->attributes().value() = value_descriptor->value(); } } diff --git a/src/language/node_processor/ASTNodeListProcessor.hpp b/src/language/node_processor/ASTNodeListProcessor.hpp index 142d33f7d7ce085b57c3a3746f9f559c87e4408e..e9564c9a3a2f890d0f9fe35d70c1f47147d9a7c1 100644 --- a/src/language/node_processor/ASTNodeListProcessor.hpp +++ b/src/language/node_processor/ASTNodeListProcessor.hpp @@ -4,7 +4,10 @@ #include <language/PEGGrammar.hpp> #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class ASTNodeListProcessor final : public INodeProcessor { @@ -21,8 +24,21 @@ class ASTNodeListProcessor final : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - for (auto& child : m_node.children) { - child->execute(exec_policy); + ResumingManager& resuming_manager = ResumingManager::getInstance(); + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + for (size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + i_child < m_node.children.size(); ++i_child) { + m_node.children[i_child]->execute(exec_policy); + } + } else { + for (auto&& child : m_node.children) { + child->execute(exec_policy); + } } if (not(m_node.is_root() or m_node.is_type<language::for_statement_block>())) diff --git a/src/language/node_processor/DoWhileProcessor.hpp b/src/language/node_processor/DoWhileProcessor.hpp index 593fd4bf178ea0734ce4e512fd913c4cc760ee98..cbda2ca266991575c2de12b00980a98381b7d5ff 100644 --- a/src/language/node_processor/DoWhileProcessor.hpp +++ b/src/language/node_processor/DoWhileProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class DoWhileProcessor final : public INodeProcessor { @@ -22,7 +25,17 @@ class DoWhileProcessor final : public INodeProcessor { bool continuation_test = true; ExecutionPolicy exec_until_jump; + ResumingManager& resuming_manager = ResumingManager::getInstance(); do { + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + Assert(i_child == 0); + }; m_node.children[0]->execute(exec_until_jump); if (not exec_until_jump.exec()) { if (exec_until_jump.jumpType() == ExecutionPolicy::JumpType::break_jump) { diff --git a/src/language/node_processor/ForProcessor.hpp b/src/language/node_processor/ForProcessor.hpp index 674c7d05ea4e6f89c4941eddc15540556961ae87..47d50bc5abd9dfdbf22793e484c7d66f0a7c5505 100644 --- a/src/language/node_processor/ForProcessor.hpp +++ b/src/language/node_processor/ForProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class ForProcessor final : public INodeProcessor { @@ -21,8 +24,12 @@ class ForProcessor final : public INodeProcessor execute(ExecutionPolicy& exec_policy) { ExecutionPolicy exec_until_jump; - m_node.children[0]->execute(exec_policy); - while ([&]() { + ResumingManager& resuming_manager = ResumingManager::getInstance(); + + if (not resuming_manager.isResuming()) [[likely]] { + m_node.children[0]->execute(exec_policy); + } + while (resuming_manager.isResuming() or [&]() { return static_cast<bool>(std::visit( [](auto&& value) -> bool { using T = std::decay_t<decltype(value)>; @@ -34,6 +41,15 @@ class ForProcessor final : public INodeProcessor }, m_node.children[1]->execute(exec_policy))); }()) { + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + Assert(i_child == 3); + } m_node.children[3]->execute(exec_until_jump); if (not exec_until_jump.exec()) { if (exec_until_jump.jumpType() == ExecutionPolicy::JumpType::break_jump) { diff --git a/src/language/node_processor/IfProcessor.hpp b/src/language/node_processor/IfProcessor.hpp index 0d7497eaf832d257d12dec349de4fe0c6b033ab9..0c5d356b7cb6e620a27437bc45bffbc9e9a13b57 100644 --- a/src/language/node_processor/IfProcessor.hpp +++ b/src/language/node_processor/IfProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class IfProcessor final : public INodeProcessor { @@ -20,29 +23,41 @@ class IfProcessor final : public INodeProcessor DataVariant execute(ExecutionPolicy& exec_policy) { - const bool is_true = static_cast<bool>(std::visit( // LCOV_EXCL_LINE (false negative) - [](const auto& value) -> bool { - using T = std::decay_t<decltype(value)>; - if constexpr (std::is_arithmetic_v<T>) { - return value; - } else { - return false; // LCOV_EXCL_LINE (unreachable: only there for compilation purpose) - } - }, - m_node.children[0]->execute(exec_policy))); - if (is_true) { - Assert(m_node.children[1] != nullptr); - m_node.children[1]->execute(exec_policy); - if (m_node.children[1]->m_symbol_table != m_node.m_symbol_table) - m_node.children[1]->m_symbol_table->clearValues(); + ResumingManager& resuming_manager = ResumingManager::getInstance(); + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + m_node.children[i_child]->execute(exec_policy); } else { - if (m_node.children.size() == 3) { - // else statement - Assert(m_node.children[2] != nullptr); - m_node.children[2]->execute(exec_policy); - if (m_node.children[2]->m_symbol_table != m_node.m_symbol_table) - m_node.children[2]->m_symbol_table->clearValues(); + const bool is_true = static_cast<bool>(std::visit( // LCOV_EXCL_LINE (false negative) + [](const auto& value) -> bool { + using T = std::decay_t<decltype(value)>; + if constexpr (std::is_arithmetic_v<T>) { + return value; + } else { + return false; // LCOV_EXCL_LINE (unreachable: only there for compilation purpose) + } + }, + m_node.children[0]->execute(exec_policy))); + if (is_true) { + Assert(m_node.children[1] != nullptr); + m_node.children[1]->execute(exec_policy); + if (m_node.children[1]->m_symbol_table != m_node.m_symbol_table) + m_node.children[1]->m_symbol_table->clearValues(); + + } else { + if (m_node.children.size() == 3) { + // else statement + Assert(m_node.children[2] != nullptr); + m_node.children[2]->execute(exec_policy); + if (m_node.children[2]->m_symbol_table != m_node.m_symbol_table) + m_node.children[2]->m_symbol_table->clearValues(); + } } } diff --git a/src/language/node_processor/WhileProcessor.hpp b/src/language/node_processor/WhileProcessor.hpp index 72c8f4186e81cfcc587e6cf55d6f503ee3f0f154..86912e126c8559303d6dae114612b07c39e0bca0 100644 --- a/src/language/node_processor/WhileProcessor.hpp +++ b/src/language/node_processor/WhileProcessor.hpp @@ -3,7 +3,10 @@ #include <language/ast/ASTNode.hpp> #include <language/node_processor/INodeProcessor.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/ResumingManager.hpp> class WhileProcessor final : public INodeProcessor { @@ -21,7 +24,8 @@ class WhileProcessor final : public INodeProcessor execute(ExecutionPolicy& exec_policy) { ExecutionPolicy exec_until_jump; - while ([&]() { + ResumingManager& resuming_manager = ResumingManager::getInstance(); + while (resuming_manager.isResuming() or [&]() { return static_cast<bool>(std::visit( [](const auto& value) -> bool { using T = std::decay_t<decltype(value)>; @@ -33,7 +37,17 @@ class WhileProcessor final : public INodeProcessor }, m_node.children[0]->execute(exec_policy))); }()) { + if (resuming_manager.isResuming()) [[unlikely]] { + const size_t checkpoint_id = resuming_manager.checkpointId(); + + const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); + const ASTCheckpoint& ast_checkpoint = ast_checkpoint_info.getASTCheckpoint(checkpoint_id); + + const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; + Assert(i_child == 1); + } m_node.children[1]->execute(exec_until_jump); + if (not exec_until_jump.exec()) { if (exec_until_jump.jumpType() == ExecutionPolicy::JumpType::break_jump) { break; diff --git a/src/language/utils/ASTCheckpoint.hpp b/src/language/utils/ASTCheckpoint.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2114083da18d97d9d0deb81b0da6d14c429899ae --- /dev/null +++ b/src/language/utils/ASTCheckpoint.hpp @@ -0,0 +1,36 @@ +#ifndef AST_CHECKPOINT_HPP +#define AST_CHECKPOINT_HPP + +#include <vector> + +class ASTNode; +class ASTCheckpoint +{ + private: + const ASTNode* const m_p_node; + const std::vector<size_t> m_ast_location; + + public: + const ASTNode& + node() const + { + return *m_p_node; + } + + const std::vector<size_t>& + getASTLocation() const + { + return m_ast_location; + } + + ASTCheckpoint(const std::vector<size_t>& ast_location, const ASTNode* const p_node) + : m_p_node{p_node}, m_ast_location{ast_location} + {} + + ASTCheckpoint(const ASTCheckpoint&) = default; + ASTCheckpoint(ASTCheckpoint&&) = default; + + ~ASTCheckpoint() = default; +}; + +#endif // AST_CHECKPOINT_HPP diff --git a/src/language/utils/ASTCheckpointsInfo.cpp b/src/language/utils/ASTCheckpointsInfo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..779290bba15c2c6de0f6d5d95cd5b2d0d6d4e823 --- /dev/null +++ b/src/language/utils/ASTCheckpointsInfo.cpp @@ -0,0 +1,53 @@ +#include <language/utils/ASTCheckpointsInfo.hpp> + +#include <language/PEGGrammar.hpp> +#include <language/ast/ASTNode.hpp> + +const ASTCheckpointsInfo* ASTCheckpointsInfo::m_checkpoints_info_instance = nullptr; + +void +ASTCheckpointsInfo::_findASTCheckpoint(std::vector<size_t>& location, const ASTNode& node) +{ + if (node.is_type<language::function_evaluation>()) { + const ASTNode& node_name = *node.children[0]; + if (node_name.is_type<language::name>() and (node_name.m_data_type == ASTNodeDataType::builtin_function_t) and + ((node_name.string() == "checkpoint") or (node_name.string() == "checkpoint_and_exit"))) { + m_ast_checkpoint_list.push_back(ASTCheckpoint{location, &node}); + return; + } + } + + if (node.children.size() > 0) { + location.push_back(0); + for (size_t i_node = 0; i_node < node.children.size(); ++i_node) { + location[location.size() - 1] = i_node; + this->_findASTCheckpoint(location, *node.children[i_node]); + } + location.pop_back(); + } +} + +ASTCheckpointsInfo::ASTCheckpointsInfo(const ASTNode& root_node) +{ + Assert(m_checkpoints_info_instance == nullptr, "Can only define one ASTCheckpointInfo"); + m_checkpoints_info_instance = this; + + Assert(root_node.is_root()); + + std::vector<size_t> location; + this->_findASTCheckpoint(location, root_node); + + Assert(location.size() == 0); +} + +const ASTCheckpointsInfo& +ASTCheckpointsInfo::getInstance() +{ + Assert(m_checkpoints_info_instance != nullptr, "ASTCheckpointInfo is not defined!"); + return *m_checkpoints_info_instance; +} + +ASTCheckpointsInfo::~ASTCheckpointsInfo() +{ + m_checkpoints_info_instance = nullptr; +} diff --git a/src/language/utils/ASTCheckpointsInfo.hpp b/src/language/utils/ASTCheckpointsInfo.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ca10d3a58f50b5d47d4c412eb60cab6d49071206 --- /dev/null +++ b/src/language/utils/ASTCheckpointsInfo.hpp @@ -0,0 +1,53 @@ +#ifndef AST_CHECKPOINTS_INFO_HPP +#define AST_CHECKPOINTS_INFO_HPP + +#include <language/utils/ASTCheckpoint.hpp> +#include <utils/Exceptions.hpp> +#include <utils/PugsAssert.hpp> + +#include <string> +#include <vector> + +class ASTNode; + +class ASTCheckpointsInfo +{ + private: + static const ASTCheckpointsInfo* m_checkpoints_info_instance; + + std::vector<ASTCheckpoint> m_ast_checkpoint_list; + + void _findASTCheckpoint(std::vector<size_t>& location, const ASTNode& node); + + // The only place where the ASTCheckpointsInfo can be built + friend void parser(const std::string& filename); + + ASTCheckpointsInfo(const ASTNode& root_node); + + public: + size_t + getCheckpointId(const ASTNode& node) const + { + for (size_t i = 0; i < m_ast_checkpoint_list.size(); ++i) { + if (&m_ast_checkpoint_list[i].node() == &node) { + return i; + } + } + throw UnexpectedError("Could not find node"); + } + + const ASTCheckpoint& + getASTCheckpoint(size_t checkpoint_id) const + { + Assert(checkpoint_id < m_ast_checkpoint_list.size()); + return m_ast_checkpoint_list[checkpoint_id]; + } + + static const ASTCheckpointsInfo& getInstance(); + + ASTCheckpointsInfo() = delete; + + ~ASTCheckpointsInfo(); +}; + +#endif // AST_CHECKPOINTS_INFO_HPP diff --git a/src/language/utils/ASTExecutionInfo.cpp b/src/language/utils/ASTExecutionInfo.cpp index 6015b6a8f03625d436916224566725eaa7a18e2d..8a3131a6ef392e6f07c7b381096b69074b12ff82 100644 --- a/src/language/utils/ASTExecutionInfo.cpp +++ b/src/language/utils/ASTExecutionInfo.cpp @@ -2,24 +2,24 @@ #include <language/ast/ASTNode.hpp> -const ASTExecutionInfo* ASTExecutionInfo::m_current_execution_info = nullptr; +const ASTExecutionInfo* ASTExecutionInfo::m_execution_info_instance = nullptr; ASTExecutionInfo::ASTExecutionInfo(const ASTNode& root_node, const ModuleRepository& module_repository) : m_root_node{root_node}, m_module_repository{module_repository} { - Assert(m_current_execution_info == nullptr, "Can only define one ASTExecutionInfo"); + Assert(m_execution_info_instance == nullptr, "Can only define one ASTExecutionInfo"); - m_current_execution_info = this; + m_execution_info_instance = this; } const ASTExecutionInfo& -ASTExecutionInfo::current() +ASTExecutionInfo::getInstance() { - Assert(m_current_execution_info != nullptr, "ASTExecutionInfo is not defined!"); - return *m_current_execution_info; + Assert(m_execution_info_instance != nullptr, "ASTExecutionInfo is not defined!"); + return *m_execution_info_instance; } ASTExecutionInfo::~ASTExecutionInfo() { - m_current_execution_info = nullptr; + m_execution_info_instance = nullptr; } diff --git a/src/language/utils/ASTExecutionInfo.hpp b/src/language/utils/ASTExecutionInfo.hpp index da4b817428c2ea00e8c089e14de2049c4faae4af..1f389761ab1f14f6b2343fb195641d346a432cfe 100644 --- a/src/language/utils/ASTExecutionInfo.hpp +++ b/src/language/utils/ASTExecutionInfo.hpp @@ -8,7 +8,7 @@ class ASTNode; class ASTExecutionInfo { private: - static const ASTExecutionInfo* m_current_execution_info; + static const ASTExecutionInfo* m_execution_info_instance; const ASTNode& m_root_node; @@ -34,7 +34,7 @@ class ASTExecutionInfo return m_module_repository; } - static const ASTExecutionInfo& current(); + static const ASTExecutionInfo& getInstance(); ASTExecutionInfo() = delete; diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 33a41962e8f954ef214e0c6a50eefde1d8e29f7f..0455cd053bb57ed70a7c534dec1f172172a79963 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -8,6 +8,7 @@ add_library(PugsLanguageUtils AffectationRegisterForRnxn.cpp AffectationRegisterForString.cpp AffectationRegisterForZ.cpp + ASTCheckpointsInfo.cpp ASTDotPrinter.cpp ASTExecutionInfo.cpp ASTNodeDataType.cpp diff --git a/src/language/utils/SymbolTable.hpp b/src/language/utils/SymbolTable.hpp index 9f7bc8129587b8acd04be2297d604871018cf0a2..38351f7ecdf568e3482b1561152ae5a07276182f 100644 --- a/src/language/utils/SymbolTable.hpp +++ b/src/language/utils/SymbolTable.hpp @@ -24,6 +24,7 @@ class SymbolTable int32_t m_context_id; bool m_is_initialized{false}; + bool m_is_module_variable{false}; // variable that is created by a module (ex: cout,...) ASTNodeDataType m_data_type; DataVariant m_value; @@ -35,7 +36,7 @@ class SymbolTable return m_context_id != -1; } - const int32_t& + int32_t contextId() const { return m_context_id; @@ -53,7 +54,7 @@ class SymbolTable return m_value; } - const bool& + bool isInitialized() const { return m_is_initialized; @@ -65,6 +66,18 @@ class SymbolTable m_is_initialized = true; } + bool + isModuleVariable() const + { + return m_is_module_variable; + } + + void + setIsModuleVariable() + { + m_is_module_variable = true; + } + const ASTNodeDataType& dataType() const { @@ -98,7 +111,7 @@ class SymbolTable return os; } - Attributes& operator=(Attributes&&) = default; + Attributes& operator=(Attributes&&) = default; Attributes& operator=(const Attributes&) = default; Attributes(const TAO_PEGTL_NAMESPACE::position& position, int32_t context_id) @@ -138,7 +151,7 @@ class SymbolTable Symbol(const std::string& name, const Attributes& attributes) : m_name(name), m_attributes(attributes) {} - Symbol& operator=(Symbol&&) = default; + Symbol& operator=(Symbol&&) = default; Symbol& operator=(const Symbol&) = default; Symbol(const Symbol&) = default; @@ -179,7 +192,7 @@ class SymbolTable Context() : m_id{next_context_id++} {} Context& operator=(const Context&) = default; // clazy:exclude=function-args-by-value - Context& operator=(Context&&) = default; + Context& operator=(Context&&) = default; Context(const Context&) = default; Context(Context&&) = default; @@ -198,6 +211,24 @@ class SymbolTable std::shared_ptr<EmbedderTable<TypeDescriptor>> m_type_embedder_table; public: + const std::vector<Symbol>& + symbolList() const + { + return m_symbol_list; + } + + bool + hasParentTable() const + { + return m_parent_table.use_count() > 0; + } + + const std::shared_ptr<SymbolTable> + parentTable() const + { + return m_parent_table; + } + bool hasContext() const { diff --git a/src/main.cpp b/src/main.cpp index 7d0112c80715b3cd0b69ea91b874527134b6a2ed..a373bd2cd867b21b1d931449f12b64239ef03c22 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -9,11 +9,13 @@ #include <utils/GlobalVariableManager.hpp> #include <utils/PugsUtils.hpp> #include <utils/RandomEngine.hpp> +#include <utils/checkpointing/ResumingManager.hpp> int main(int argc, char* argv[]) { ExecutionStatManager::create(); + ResumingManager::create(); ParallelChecker::create(); std::string filename = initialize(argc, argv); @@ -42,6 +44,7 @@ main(int argc, char* argv[]) finalize(); ParallelChecker::destroy(); + ResumingManager::destroy(); int return_code = ExecutionStatManager::getInstance().exitCode(); ExecutionStatManager::destroy(); diff --git a/src/utils/CMakeLists.txt b/src/utils/CMakeLists.txt index b5a383c9c7d05088ef840d10b58fc57daa5cea04..f06fa68714b26ad75dff5b56c4981671a3b7db60 100644 --- a/src/utils/CMakeLists.txt +++ b/src/utils/CMakeLists.txt @@ -1,5 +1,7 @@ # ------------------- Source files -------------------- +add_subdirectory(checkpointing) + add_library( PugsUtils BuildInfo.cpp diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp index 104e70cddc5cc29d927583adde0ffa53afe72e95..f1d4f488ec5a8ef190f0b4f9d5a01f9c67bd537b 100644 --- a/src/utils/PugsUtils.cpp +++ b/src/utils/PugsUtils.cpp @@ -13,6 +13,7 @@ #include <utils/RevisionInfo.hpp> #include <utils/SLEPcWrapper.hpp> #include <utils/SignalManager.hpp> +#include <utils/checkpointing/ResumingManager.hpp> #include <utils/pugs_build_info.hpp> #include <rang.hpp> @@ -98,6 +99,9 @@ initialize(int& argc, char* argv[]) app.add_option("filename", filename, "pugs script file")->check(CLI::ExistingFile)->required(); + bool is_resuming = false; + app.add_flag("--resume", is_resuming, "Resume last checkpoint"); + app.set_version_flag("-v,--version", []() { ConsoleManager::init(true); std::stringstream os; @@ -162,6 +166,9 @@ initialize(int& argc, char* argv[]) CommunicatorManager::setSplitColor(mpi_split_color); } + ResumingManager::getInstance().setIsResuming(is_resuming); + ResumingManager::getInstance().setFilename(filename); + ExecutionStatManager::getInstance().setPrint(print_exec_stat); BacktraceManager::setShow(show_backtrace); ConsoleManager::setShowPreamble(show_preamble); diff --git a/src/utils/checkpointing/CMakeLists.txt b/src/utils/checkpointing/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2a0cdc827bf3bf4c866419fa6ea0d4a15058121 --- /dev/null +++ b/src/utils/checkpointing/CMakeLists.txt @@ -0,0 +1,18 @@ +# ------------------- Source files -------------------- + +add_library( + PugsCheckpointing + Checkpoint.cpp + Resume.cpp + ResumingManager.cpp + ResumingUtils.cpp) + +# Additional dependencies +add_dependencies(PugsCheckpointing + PugsLanguageAST + PugsUtils) + +target_link_libraries( + PugsCheckpointing + ${HIGHFIVE_TARGET} +) diff --git a/src/utils/checkpointing/Checkpoint.cpp b/src/utils/checkpointing/Checkpoint.cpp new file mode 100644 index 0000000000000000000000000000000000000000..16ea2d830766d141cfb8c499e3454bdb581e5996 --- /dev/null +++ b/src/utils/checkpointing/Checkpoint.cpp @@ -0,0 +1,108 @@ +#include <utils/checkpointing/Checkpoint.hpp> + +#include <utils/pugs_config.hpp> + +#ifdef PUGS_HAS_HDF5 + +#include <utils/HighFivePugsUtils.hpp> + +#include <language/ast/ASTExecutionStack.hpp> +#include <language/utils/SymbolTable.hpp> + +#include <iostream> +#endif // PUGS_HAS_HDF5 + +#include <language/utils/ASTCheckpointsInfo.hpp> +#include <utils/Exceptions.hpp> +#include <utils/checkpointing/ResumingManager.hpp> + +#ifdef PUGS_HAS_HDF5 +void +checkpoint() +{ + auto create_props = HighFive::FileCreateProps{}; + create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0)); + + uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber(); + + const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite; + + HighFive::File file("checkpoint.h5", file_openmode, create_props); + + HighFive::Group checkpoint = file.createGroup("checkpoint_" + std::to_string(checkpoint_number)); + + uint64_t checkpoint_id = + ASTCheckpointsInfo::getInstance().getCheckpointId((ASTExecutionStack::getInstance().currentNode())); + + checkpoint.createAttribute("checkpoint_id", checkpoint_id); + checkpoint.createDataSet("data.pgs", ASTExecutionStack::getInstance().fileContent()); + + std::shared_ptr<const SymbolTable> p_symbol_table = ASTExecutionStack::getInstance().currentNode().m_symbol_table; + auto symbol_table_group = checkpoint; + while (p_symbol_table.use_count() > 0) { + symbol_table_group = symbol_table_group.createGroup("symbol table"); + + const SymbolTable& symbol_table = *p_symbol_table; + + const auto& symbol_list = symbol_table.symbolList(); + + for (auto& symbol : symbol_list) { + switch (symbol.attributes().dataType()) { + case ASTNodeDataType::builtin_function_t: + case ASTNodeDataType::function_t: + case ASTNodeDataType::type_name_id_t: { + break; + } + default: { + if ((symbol.attributes().dataType() != ASTNodeDataType::builtin_function_t) and + (not symbol.attributes().isModuleVariable())) { + std::visit( + [&](auto&& data) { + using DataT = std::decay_t<decltype(data)>; + if constexpr (std::is_same_v<DataT, std::monostate>) { + } else if constexpr ((std::is_arithmetic_v<DataT>) or (std::is_same_v<DataT, std::string>) or + (is_tiny_vector_v<DataT>) or (is_tiny_matrix_v<DataT>)) { + symbol_table_group.createAttribute(symbol.name(), data); + } else if constexpr (is_std_vector_v<DataT>) { + using value_type = typename DataT::value_type; + if constexpr ((std::is_arithmetic_v<value_type>) or (std::is_same_v<value_type, std::string>) or + (is_tiny_vector_v<value_type>) or (is_tiny_matrix_v<value_type>)) { + symbol_table_group.createAttribute(symbol.name(), data); + } else { + throw NotImplementedError("datatype is not handled yet"); + } + } else { + throw NotImplementedError("datatype is not handled yet"); + } + }, + symbol.attributes().value()); + } + } + } + } + + p_symbol_table = symbol_table.parentTable(); + } + + if (file.exist("last_checkpoint")) { + file.unlink("last_checkpoint"); + } + file.createHardLink("last_checkpoint", checkpoint); + + if (file.hasAttribute("checkpoint_number")) { + file.deleteAttribute("checkpoint_number"); + } + file.createAttribute("checkpoint_number", checkpoint_number); + + ++checkpoint_number; +} + +#else // PUGS_HAS_HDF5 + +void +checkpoint() +{ + throw NormalError("checkpoint/resume mechanism requires HDF5"); +} + +#endif // PUGS_HAS_HDF5 diff --git a/src/utils/checkpointing/Checkpoint.hpp b/src/utils/checkpointing/Checkpoint.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d5c4e910021128339d1e52d1ac9f471b64094fd3 --- /dev/null +++ b/src/utils/checkpointing/Checkpoint.hpp @@ -0,0 +1,6 @@ +#ifndef CHECKPOINT_HPP +#define CHECKPOINT_HPP + +void checkpoint(); + +#endif // CHECKPOINT_HPP diff --git a/src/utils/checkpointing/Resume.cpp b/src/utils/checkpointing/Resume.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56c3c283424a877e5cf69d8583bbec8d12b1c0e5 --- /dev/null +++ b/src/utils/checkpointing/Resume.cpp @@ -0,0 +1,229 @@ +#include <utils/checkpointing/Resume.hpp> + +#include <utils/pugs_config.hpp> + +#ifdef PUGS_HAS_HDF5 + +#include <utils/HighFivePugsUtils.hpp> + +#include <language/ast/ASTExecutionStack.hpp> +#include <language/utils/SymbolTable.hpp> + +#include <iostream> +#endif // PUGS_HAS_HDF5 + +#include <language/utils/ASTCheckpointsInfo.hpp> +#include <utils/Exceptions.hpp> +#include <utils/checkpointing/ResumingManager.hpp> + +#ifdef PUGS_HAS_HDF5 + +void +resume() +{ + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/last_checkpoint"); + + HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table"); + + const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode(); + auto p_symbol_table = p_node->m_symbol_table; + + ResumingManager& resuming_manager = ResumingManager::getInstance(); + + resuming_manager.checkpointNumber() = file.getAttribute("checkpoint_number").read<uint64_t>(); + + std::cout << " * " + << "Using " << rang::fgB::green << "checkpoint" << rang::fg::reset << " number " + << resuming_manager.checkpointNumber()++ << '\n'; + + std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line " << rang::fgB::yellow + << p_node->begin().line << rang::fg::reset << " [checkpoint id " << rang::fgB::cyan + << resuming_manager.checkpointId() << rang::fg::reset << "]\n"; + + bool finished = true; + do { + finished = true; + + for (auto symbol_name : saved_symbol_table.listAttributeNames()) { + auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin()); + auto& attribute = p_symbol->attributes(); + switch (attribute.dataType()) { + case ASTNodeDataType::bool_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>(); + break; + } + case ASTNodeDataType::unsigned_int_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>(); + break; + } + case ASTNodeDataType::int_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>(); + break; + } + case ASTNodeDataType::double_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double_t>(); + break; + } + case ASTNodeDataType::string_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>(); + break; + } + case ASTNodeDataType::vector_t: { + switch (attribute.dataType().dimension()) { + case 1: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>(); + break; + } + case 2: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>(); + break; + } + case 3: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>(); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension"); + } + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::matrix_t: { + // LCOV_EXCL_START + if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) { + throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension"); + } + // LCOV_EXCL_STOP + switch (attribute.dataType().numberOfRows()) { + case 1: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>(); + break; + } + case 2: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>(); + break; + } + case 3: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>(); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension"); + } + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::tuple_t: { + switch (attribute.dataType().contentType()) { + case ASTNodeDataType::bool_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>(); + break; + } + case ASTNodeDataType::unsigned_int_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>(); + break; + } + case ASTNodeDataType::int_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>(); + break; + } + case ASTNodeDataType::double_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double_t>>(); + break; + } + case ASTNodeDataType::string_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>(); + break; + } + case ASTNodeDataType::vector_t: { + switch (attribute.dataType().contentType().dimension()) { + case 1: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>(); + break; + } + case 2: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>(); + break; + } + case 3: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>(); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension"); + } + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::matrix_t: { + // LCOV_EXCL_START + if (attribute.dataType().contentType().numberOfRows() != + attribute.dataType().contentType().numberOfColumns()) { + throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension"); + } + // LCOV_EXCL_STOP + switch (attribute.dataType().contentType().numberOfRows()) { + case 1: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>(); + break; + } + case 2: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>(); + break; + } + case 3: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>(); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension"); + } + // LCOV_EXCL_STOP + } + break; + } + default: { + throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType())); + } + } + break; + } + default: { + throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType())); + } + } + } + + const bool symbol_table_has_parent = p_symbol_table->hasParentTable(); + const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table"); + + Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent)); + + if (symbol_table_has_parent and saved_symbol_table_has_parent) { + p_symbol_table = p_symbol_table->parentTable(); + saved_symbol_table = saved_symbol_table.getGroup("symbol table"); + + finished = false; + } + + } while (not finished); +} + +#else // PUGS_HAS_HDF5 + +void +resume() +{ + throw NormalError("checkpoint/resume mechanism requires HDF5"); +} + +#endif // PUGS_HAS_HDF5 diff --git a/src/utils/checkpointing/Resume.hpp b/src/utils/checkpointing/Resume.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9512539bc5b552f0d012c075b0f9ff7a1b69862c --- /dev/null +++ b/src/utils/checkpointing/Resume.hpp @@ -0,0 +1,6 @@ +#ifndef RESUME_HPP +#define RESUME_HPP + +void resume(); + +#endif // RESUME_HPP diff --git a/src/utils/checkpointing/ResumingManager.cpp b/src/utils/checkpointing/ResumingManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f4e8d68102980d61cb306e8a4b818b63eafe1e9a --- /dev/null +++ b/src/utils/checkpointing/ResumingManager.cpp @@ -0,0 +1,44 @@ +#include <utils/checkpointing/ResumingManager.hpp> + +#include <utils/HighFivePugsUtils.hpp> + +ResumingManager* ResumingManager::m_instance = nullptr; + +ResumingManager& +ResumingManager::getInstance() +{ + Assert(m_instance != nullptr, "instance was not created"); + return *m_instance; +} + +void +ResumingManager::create() +{ + Assert(m_instance == nullptr, "Resuming manager was already created"); + m_instance = new ResumingManager; +} + +void +ResumingManager::destroy() +{ + Assert(m_instance != nullptr, "Resuming manager was not created"); + delete m_instance; + m_instance = nullptr; +} + +size_t +ResumingManager::checkpointId() +{ +#ifdef PUGS_HAS_HDF5 + if (not m_checkpoint_id) { + HighFive::File file(m_filename, HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/last_checkpoint"); + + m_checkpoint_id = std::make_unique<size_t>(checkpoint.getAttribute("checkpoint_id").read<uint64_t>()); + } +#else // PUGS_HAS_HDF5 + m_checkpoint_id = std::make_unique<uint64_t>(0); +#endif // PUGS_HAS_HDF5 + return *m_checkpoint_id; +} diff --git a/src/utils/checkpointing/ResumingManager.hpp b/src/utils/checkpointing/ResumingManager.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a39c68d020453f6ba7ab875a4c6f183f724294c2 --- /dev/null +++ b/src/utils/checkpointing/ResumingManager.hpp @@ -0,0 +1,72 @@ +#ifndef RESUMING_MANAGER_HPP +#define RESUMING_MANAGER_HPP + +#include <utils/PugsAssert.hpp> + +#include <memory> + +class ResumingManager +{ + private: + bool m_is_resuming = false; + std::string m_filename; + // numbering of the checkpoints (during execution) + uint64_t m_checkpoint_number = 0; + // id of the checkpoint defined in the script file + std::unique_ptr<uint64_t> m_checkpoint_id; + size_t m_current_ast_level = 0; + + ResumingManager() = default; + + ResumingManager(ResumingManager&&) = delete; + ResumingManager(const ResumingManager&) = delete; + + ~ResumingManager() = default; + + static ResumingManager* m_instance; + + public: + static void create(); + static void destroy(); + static ResumingManager& getInstance(); + + uint64_t& + checkpointNumber() + { + return m_checkpoint_number; + } + + uint64_t checkpointId(); + + uint64_t& + currentASTLevel() + { + return m_current_ast_level; + } + + void + setIsResuming(const bool is_resuming) + { + m_is_resuming = is_resuming; + } + + bool + isResuming() const + { + return m_is_resuming; + } + + void + setFilename(const std::string& filename) + { + m_filename = filename; + } + + const std::string& + filename() const + { + return m_filename; + } +}; + +#endif // RESUMING_MANAGER_HPP diff --git a/src/utils/checkpointing/ResumingUtils.cpp b/src/utils/checkpointing/ResumingUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..af811b8079734b28052edf270182d278845ea3af --- /dev/null +++ b/src/utils/checkpointing/ResumingUtils.cpp @@ -0,0 +1,27 @@ +#include <utils/checkpointing/ResumingUtils.hpp> + +#include <utils/Exceptions.hpp> + +#include <utils/pugs_config.hpp> + +#ifdef PUGS_HAS_HDF5 + +#include <utils/HighFivePugsUtils.hpp> + +std::string +resumingDatafile(const std::string& filename) +{ + HighFive::File file(filename, HighFive::File::ReadOnly); +#warning change when checkpoint id can be specified + return file.getGroup("/last_checkpoint").getDataSet("data.pgs").read<std::string>(); +} + +#else // PUGS_HAS_HDF5 + +std::string +resumingDatafile(const std::string& filename) +{ + throw NormalError("Resuming requires HDF5"); +} + +#endif // PUGS_HAS_HDF5 diff --git a/src/utils/checkpointing/ResumingUtils.hpp b/src/utils/checkpointing/ResumingUtils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..46cd27ed7056097d20af24f3c352971af85bc946 --- /dev/null +++ b/src/utils/checkpointing/ResumingUtils.hpp @@ -0,0 +1,8 @@ +#ifndef RESUMING_UTILS_HPP +#define RESUMING_UTILS_HPP + +#include <string> + +std::string resumingDatafile(const std::string& filename); + +#endif // RESUMING_UTILS_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index aaace485deb7ed83e66d80c2e75d9bce2cc9bc21..9f0b29d5f5b4817dca4468b5d90b7d420b47c555 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -234,6 +234,7 @@ target_link_libraries (unit_tests PugsScheme PugsOutput PugsUtils + PugsCheckpointing PugsDev Kokkos::kokkos ${PARMETIS_LIBRARIES} @@ -261,8 +262,9 @@ target_link_libraries (mpi_unit_tests PugsLanguageUtils PugsScheme PugsOutput - PugsDev PugsUtils + PugsCheckpointing + PugsDev PugsAlgebra PugsMesh Kokkos::kokkos diff --git a/tests/mpi_test_main.cpp b/tests/mpi_test_main.cpp index fb0da8569382769cf495dedceb42a91528076dd9..083dce9b142bb7a19e88d8e84255fc754b3ed70e 100644 --- a/tests/mpi_test_main.cpp +++ b/tests/mpi_test_main.cpp @@ -13,6 +13,7 @@ #include <utils/PETScWrapper.hpp> #include <utils/RandomEngine.hpp> #include <utils/Stringify.hpp> +#include <utils/checkpointing/ResumingManager.hpp> #include <utils/pugs_config.hpp> #include <MeshDataBaseForTests.hpp> @@ -93,6 +94,8 @@ main(int argc, char* argv[]) // Disable outputs from tested classes to the standard output std::cout.setstate(std::ios::badbit); + ResumingManager::create(); + SynchronizerManager::create(); RandomEngine::create(); QuadratureManager::create(); @@ -118,6 +121,8 @@ main(int argc, char* argv[]) QuadratureManager::destroy(); RandomEngine::destroy(); SynchronizerManager::destroy(); + + ResumingManager::destroy(); } } diff --git a/tests/test_ASTModulesImporter.cpp b/tests/test_ASTModulesImporter.cpp index 24557c449ca495b92598ea5bd0140b55c381b15c..6156c571ed426632ef80ed5a1600aa8751c20586 100644 --- a/tests/test_ASTModulesImporter.cpp +++ b/tests/test_ASTModulesImporter.cpp @@ -19,7 +19,7 @@ test_ASTExecutionInfo(const ASTNode& root_node, const ModuleRepository& module_r REQUIRE(&root_node == &execution_info.rootNode()); REQUIRE(&module_repository == &execution_info.moduleRepository()); - REQUIRE(&ASTExecutionInfo::current() == &execution_info); + REQUIRE(&ASTExecutionInfo::getInstance() == &execution_info); } #define CHECK_AST(data, expected_output) \ diff --git a/tests/test_main.cpp b/tests/test_main.cpp index d8641d318f243eb624915882a92fc15d57a71346..fe5e8eea36d3973e17691a180bad2fd2f5820dcb 100644 --- a/tests/test_main.cpp +++ b/tests/test_main.cpp @@ -13,6 +13,7 @@ #include <utils/PETScWrapper.hpp> #include <utils/RandomEngine.hpp> #include <utils/SLEPcWrapper.hpp> +#include <utils/checkpointing/ResumingManager.hpp> #include <MeshDataBaseForTests.hpp> @@ -52,6 +53,8 @@ main(int argc, char* argv[]) // Disable outputs from tested classes to the standard output std::cout.setstate(std::ios::badbit); + ResumingManager::create(); + SynchronizerManager::create(); RandomEngine::create(); QuadratureManager::create(); @@ -77,6 +80,8 @@ main(int argc, char* argv[]) QuadratureManager::destroy(); RandomEngine::destroy(); SynchronizerManager::destroy(); + + ResumingManager::destroy(); } }