From 41deef32149d9b08ceceb58ff41f10e003b19cea Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 6 Nov 2024 22:57:17 +0100 Subject: [PATCH] Add tests for language's node processors for checkpointing --- .../node_processor/ASTNodeListProcessor.hpp | 3 + .../node_processor/DoWhileProcessor.hpp | 3 + src/language/node_processor/ForProcessor.hpp | 3 + src/language/node_processor/IfProcessor.hpp | 4 +- .../node_processor/WhileProcessor.hpp | 3 + tests/CMakeLists.txt | 1 + tests/test_BuiltinFunctionProcessor.cpp | 9 + tests/test_ConcatExpressionProcessor.cpp | 8 + tests/test_DoWhileProcessor.cpp | 7 + tests/test_ForProcessor.cpp | 7 + tests/test_FunctionProcessor.cpp | 9 + tests/test_IfProcessor.cpp | 7 + tests/test_WhileProcessor.cpp | 7 + tests/test_checkpointing_Resume.cpp | 1020 +++++++++++++++++ 14 files changed, 1090 insertions(+), 1 deletion(-) create mode 100644 tests/test_checkpointing_Resume.cpp diff --git a/src/language/node_processor/ASTNodeListProcessor.hpp b/src/language/node_processor/ASTNodeListProcessor.hpp index 23b193275..4e34dcbe9 100644 --- a/src/language/node_processor/ASTNodeListProcessor.hpp +++ b/src/language/node_processor/ASTNodeListProcessor.hpp @@ -8,6 +8,7 @@ #include <language/utils/SymbolTable.hpp> #include <utils/checkpointing/Checkpoint.hpp> #include <utils/checkpointing/ResumingManager.hpp> +#include <utils/pugs_config.hpp> class ASTNodeListProcessor final : public INodeProcessor { @@ -26,6 +27,7 @@ class ASTNodeListProcessor final : public INodeProcessor { ResumingManager& resuming_manager = ResumingManager::getInstance(); if (resuming_manager.isResuming()) { +#ifdef PUGS_HAS_HDF5 const size_t checkpoint_id = resuming_manager.checkpointId(); const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); @@ -35,6 +37,7 @@ class ASTNodeListProcessor final : public INodeProcessor i_child < m_node.children.size(); ++i_child) { m_node.children[i_child]->execute(exec_policy); } +#endif // PUGS_HAS_HDF5 } else { for (auto&& child : m_node.children) { child->execute(exec_policy); diff --git a/src/language/node_processor/DoWhileProcessor.hpp b/src/language/node_processor/DoWhileProcessor.hpp index 76fdfa22b..352229edf 100644 --- a/src/language/node_processor/DoWhileProcessor.hpp +++ b/src/language/node_processor/DoWhileProcessor.hpp @@ -7,6 +7,7 @@ #include <language/utils/SymbolTable.hpp> #include <utils/checkpointing/Checkpoint.hpp> #include <utils/checkpointing/ResumingManager.hpp> +#include <utils/pugs_config.hpp> class DoWhileProcessor final : public INodeProcessor { @@ -28,6 +29,7 @@ class DoWhileProcessor final : public INodeProcessor ResumingManager& resuming_manager = ResumingManager::getInstance(); do { if (resuming_manager.isResuming()) { +#ifdef PUGS_HAS_HDF5 const size_t checkpoint_id = resuming_manager.checkpointId(); const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); @@ -35,6 +37,7 @@ class DoWhileProcessor final : public INodeProcessor const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; Assert(i_child == 0); +#endif // PUGS_HAS_HDF5 }; m_node.children[0]->execute(exec_until_jump); if (not exec_until_jump.exec()) { diff --git a/src/language/node_processor/ForProcessor.hpp b/src/language/node_processor/ForProcessor.hpp index c68893c9f..caf202bdb 100644 --- a/src/language/node_processor/ForProcessor.hpp +++ b/src/language/node_processor/ForProcessor.hpp @@ -7,6 +7,7 @@ #include <language/utils/SymbolTable.hpp> #include <utils/checkpointing/Checkpoint.hpp> #include <utils/checkpointing/ResumingManager.hpp> +#include <utils/pugs_config.hpp> class ForProcessor final : public INodeProcessor { @@ -42,6 +43,7 @@ class ForProcessor final : public INodeProcessor m_node.children[1]->execute(exec_policy))); }()) { if (resuming_manager.isResuming()) { +#ifdef PUGS_HAS_HDF5 const size_t checkpoint_id = resuming_manager.checkpointId(); const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); @@ -49,6 +51,7 @@ class ForProcessor final : public INodeProcessor const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; Assert(i_child == 3); +#endif // PUGS_HAS_HDF5 } m_node.children[3]->execute(exec_until_jump); if (not exec_until_jump.exec()) { diff --git a/src/language/node_processor/IfProcessor.hpp b/src/language/node_processor/IfProcessor.hpp index 7b49f46fa..bb41dc165 100644 --- a/src/language/node_processor/IfProcessor.hpp +++ b/src/language/node_processor/IfProcessor.hpp @@ -7,6 +7,7 @@ #include <language/utils/SymbolTable.hpp> #include <utils/checkpointing/Checkpoint.hpp> #include <utils/checkpointing/ResumingManager.hpp> +#include <utils/pugs_config.hpp> class IfProcessor final : public INodeProcessor { @@ -25,6 +26,7 @@ class IfProcessor final : public INodeProcessor { ResumingManager& resuming_manager = ResumingManager::getInstance(); if (resuming_manager.isResuming()) { +#ifdef PUGS_HAS_HDF5 const size_t checkpoint_id = resuming_manager.checkpointId(); const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); @@ -32,7 +34,7 @@ class IfProcessor final : public INodeProcessor const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; m_node.children[i_child]->execute(exec_policy); - +#endif // PUGS_HAS_HDF5 } else { const bool is_true = static_cast<bool>(std::visit( // LCOV_EXCL_LINE (false negative) [](const auto& value) -> bool { diff --git a/src/language/node_processor/WhileProcessor.hpp b/src/language/node_processor/WhileProcessor.hpp index 994bcdd24..d36360573 100644 --- a/src/language/node_processor/WhileProcessor.hpp +++ b/src/language/node_processor/WhileProcessor.hpp @@ -7,6 +7,7 @@ #include <language/utils/SymbolTable.hpp> #include <utils/checkpointing/Checkpoint.hpp> #include <utils/checkpointing/ResumingManager.hpp> +#include <utils/pugs_config.hpp> class WhileProcessor final : public INodeProcessor { @@ -38,6 +39,7 @@ class WhileProcessor final : public INodeProcessor m_node.children[0]->execute(exec_policy))); }()) { if (resuming_manager.isResuming()) { +#ifdef PUGS_HAS_HDF5 const size_t checkpoint_id = resuming_manager.checkpointId(); const ASTCheckpointsInfo& ast_checkpoint_info = ASTCheckpointsInfo::getInstance(); @@ -45,6 +47,7 @@ class WhileProcessor final : public INodeProcessor const size_t i_child = ast_checkpoint.getASTLocation()[resuming_manager.currentASTLevel()++]; Assert(i_child == 1); +#endif // PUGS_HAS_HDF5 } m_node.children[1]->execute(exec_until_jump); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bfde50850..0a6d69d39 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -172,6 +172,7 @@ add_executable (unit_tests test_checkpointing_Checkpoint.cpp test_checkpointing_PrintCheckpointInfo.cpp test_checkpointing_PrintScriptFrom.cpp + test_checkpointing_Resume.cpp test_checkpointing_ResumingManager.cpp test_checkpointing_ResumingUtils.cpp test_checkpointing_SetResumeFrom.cpp diff --git a/tests/test_BuiltinFunctionProcessor.cpp b/tests/test_BuiltinFunctionProcessor.cpp index 258a3a69f..398308097 100644 --- a/tests/test_BuiltinFunctionProcessor.cpp +++ b/tests/test_BuiltinFunctionProcessor.cpp @@ -10,6 +10,7 @@ #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> #include <language/modules/MathModule.hpp> +#include <language/node_processor/BuiltinFunctionProcessor.hpp> #include <test_BuiltinFunctionRegister.hpp> @@ -638,4 +639,12 @@ let x:R, x = tuple_ZtoR(3); CHECK_BUILTIN_FUNCTION_EVALUATION_RESULT(data, "x", 0.5 * 3); } } + + SECTION("expression type") + { + ASTNode node; + REQUIRE(BuiltinFunctionExpressionProcessor{nullptr}.type() == + INodeProcessor::Type::builtin_function_expression_processor); + REQUIRE(BuiltinFunctionProcessor{node}.type() == INodeProcessor::Type::builtin_function_processor); + } } diff --git a/tests/test_ConcatExpressionProcessor.cpp b/tests/test_ConcatExpressionProcessor.cpp index ca02f8030..ed2780397 100644 --- a/tests/test_ConcatExpressionProcessor.cpp +++ b/tests/test_ConcatExpressionProcessor.cpp @@ -10,6 +10,7 @@ #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/node_processor/ConcatExpressionProcessor.hpp> #include <language/utils/ASTPrinter.hpp> #include <utils/Demangle.hpp> #include <utils/Stringify.hpp> @@ -205,4 +206,11 @@ TEST_CASE("ConcatExpressionProcessor", "[language]") CHECK_CONCAT_EXPRESSION_RESULT(R"(let x:R^3x3, x = [[1,2,3],[4,5,6],[7,8,9]]; let s:string, s = "_foo"; s = x+s;)", "s", os.str()); } + + SECTION("expression type") + { + ASTNode node; + REQUIRE(ConcatExpressionProcessor<std::string, std::string>{node}.type() == + INodeProcessor::Type::concat_expression_processor); + } } diff --git a/tests/test_DoWhileProcessor.cpp b/tests/test_DoWhileProcessor.cpp index 01c170247..fffd4a7f4 100644 --- a/tests/test_DoWhileProcessor.cpp +++ b/tests/test_DoWhileProcessor.cpp @@ -10,6 +10,7 @@ #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/node_processor/DoWhileProcessor.hpp> #include <language/utils/ASTPrinter.hpp> #include <utils/Demangle.hpp> @@ -147,4 +148,10 @@ do { CHECK_DO_WHILE_PROCESSOR_THROWS_WITH(data, "invalid implicit conversion: Z -> B"); } } + + SECTION("expression type") + { + ASTNode node; + REQUIRE(DoWhileProcessor{node}.type() == INodeProcessor::Type::do_while_processor); + } } diff --git a/tests/test_ForProcessor.cpp b/tests/test_ForProcessor.cpp index b830dfd63..addc9c0b5 100644 --- a/tests/test_ForProcessor.cpp +++ b/tests/test_ForProcessor.cpp @@ -10,6 +10,7 @@ #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/node_processor/ForProcessor.hpp> #include <language/utils/ASTPrinter.hpp> #include <utils/Demangle.hpp> @@ -122,4 +123,10 @@ for(let l:N, l=0; l; ++l) { CHECK_FOR_PROCESSOR_THROWS_WITH(data, "invalid implicit conversion: N -> B"); } } + + SECTION("expression type") + { + ASTNode node; + REQUIRE(ForProcessor{node}.type() == INodeProcessor::Type::for_processor); + } } diff --git a/tests/test_FunctionProcessor.cpp b/tests/test_FunctionProcessor.cpp index ede3bd01b..aa8c75e99 100644 --- a/tests/test_FunctionProcessor.cpp +++ b/tests/test_FunctionProcessor.cpp @@ -9,6 +9,7 @@ #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/node_processor/FunctionProcessor.hpp> #include <utils/Demangle.hpp> #include <utils/Stringify.hpp> @@ -1954,4 +1955,12 @@ f(1, 2); } } } + + SECTION("expression type") + { + ASTNode node; + REQUIRE((FunctionProcessor{node, SymbolTable::Context{}}.type() == INodeProcessor::Type::function_processor)); + REQUIRE((FunctionExpressionProcessor<double, double>{node}.type() == + INodeProcessor::Type::function_expression_processor)); + } } diff --git a/tests/test_IfProcessor.cpp b/tests/test_IfProcessor.cpp index 38b687c79..0d71675b4 100644 --- a/tests/test_IfProcessor.cpp +++ b/tests/test_IfProcessor.cpp @@ -9,6 +9,7 @@ #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/node_processor/IfProcessor.hpp> #include <utils/Demangle.hpp> #include <pegtl/string_input.hpp> @@ -157,4 +158,10 @@ if (1.2) { CHECK_IF_PROCESSOR_THROWS_WITH(data, "invalid implicit conversion: R -> B"); } } + + SECTION("expression type") + { + ASTNode node; + REQUIRE(IfProcessor{node}.type() == INodeProcessor::Type::if_processor); + } } diff --git a/tests/test_WhileProcessor.cpp b/tests/test_WhileProcessor.cpp index 7641f9ab2..8b4e34527 100644 --- a/tests/test_WhileProcessor.cpp +++ b/tests/test_WhileProcessor.cpp @@ -10,6 +10,7 @@ #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/node_processor/WhileProcessor.hpp> #include <utils/Demangle.hpp> #include <pegtl/string_input.hpp> @@ -142,4 +143,10 @@ while(1) { CHECK_WHILE_PROCESSOR_THROWS_WITH(data, "invalid implicit conversion: Z -> B"); } } + + SECTION("expression type") + { + ASTNode node; + REQUIRE(WhileProcessor{node}.type() == INodeProcessor::Type::while_processor); + } } diff --git a/tests/test_checkpointing_Resume.cpp b/tests/test_checkpointing_Resume.cpp new file mode 100644 index 000000000..82ea9202b --- /dev/null +++ b/tests/test_checkpointing_Resume.cpp @@ -0,0 +1,1020 @@ +#include <catch2/catch_approx.hpp> +#include <catch2/catch_test_macros.hpp> +#include <catch2/matchers/catch_matchers_predicate.hpp> + +#include <utils/checkpointing/Resume.hpp> +#include <utils/pugs_config.hpp> + +#ifdef PUGS_HAS_HDF5 + +#include <MeshDataBaseForTests.hpp> +#include <dev/ParallelChecker.hpp> +#include <language/ast/ASTBuilder.hpp> +#include <language/ast/ASTExecutionStack.hpp> +#include <language/ast/ASTModulesImporter.hpp> +#include <language/ast/ASTNodeDataTypeBuilder.hpp> +#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp> +#include <language/ast/ASTNodeExpressionBuilder.hpp> +#include <language/ast/ASTNodeTypeCleaner.hpp> +#include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/modules/MathModule.hpp> +#include <language/utils/ASTCheckpointsInfo.hpp> +#include <language/utils/CheckpointResumeRepository.hpp> +#include <utils/ExecutionStatManager.hpp> +#include <utils/checkpointing/SetResumeFrom.hpp> + +class ASTCheckpointsInfoTester +{ + private: + ASTCheckpointsInfo m_ast_checkpoint_info; + + public: + ASTCheckpointsInfoTester(const ASTNode& root_node) : m_ast_checkpoint_info(root_node) {} + ~ASTCheckpointsInfoTester() = default; +}; + +#define RUN_AST(data) \ + { \ + ExecutionStatManager::create(); \ + ParallelChecker::create(); \ + CheckpointResumeRepository::create(); \ + \ + TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \ + auto ast = ASTBuilder::build(input); \ + \ + ASTModulesImporter{*ast}; \ + ASTNodeTypeCleaner<language::import_instruction>{*ast}; \ + \ + ASTSymbolTableBuilder{*ast}; \ + ASTNodeDataTypeBuilder{*ast}; \ + \ + ASTNodeDeclarationToAffectationConverter{*ast}; \ + ASTNodeTypeCleaner<language::var_declaration>{*ast}; \ + ASTNodeTypeCleaner<language::fct_declaration>{*ast}; \ + \ + ASTNodeExpressionBuilder{*ast}; \ + ExecutionPolicy exec_policy; \ + ASTExecutionStack::create(); \ + ASTCheckpointsInfoTester ast_cp_info_tester{*ast}; \ + ast->execute(exec_policy); \ + ASTExecutionStack::destroy(); \ + ast->m_symbol_table->clearValues(); \ + \ + CheckpointResumeRepository::destroy(); \ + ParallelChecker::destroy(); \ + ExecutionStatManager::destroy(); \ + ast->m_symbol_table->clearValues(); \ + } + +#endif // PUGS_HAS_HDF5 + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("checkpointing_Resume", "[utils/checkpointing]") +{ +#ifdef PUGS_HAS_HDF5 + + SECTION("general") + { + auto frobeniusNorm = [](const auto& A) { + using A_T = std::decay_t<decltype(A)>; + static_assert(is_tiny_matrix_v<A_T>); + return std::sqrt(trace(transpose(A) * A)); + }; + + using R1 = TinyVector<1>; + using R2 = TinyVector<2>; + using R3 = TinyVector<3>; + + using R11 = TinyMatrix<1>; + using R22 = TinyMatrix<2>; + using R33 = TinyMatrix<3>; + + MeshDataBaseForTests::destroy(); + GlobalVariableManager::instance().setMeshId(0); + GlobalVariableManager::instance().setConnectivityId(0); + + ResumingManager::destroy(); + ResumingManager::create(); + + std::string tmp_dirname; + { + { + if (parallel::rank() == 0) { + tmp_dirname = [&]() -> std::string { + std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX"; + return std::string{mkdtemp(&temp_filename[0])}; + }(); + } + parallel::broadcast(tmp_dirname, 0); + } + std::filesystem::path path = tmp_dirname; + const std::string filename = path / "checkpoint.h5"; + + ResumingManager::getInstance().setFilename(filename); + } + + const std::string filename = ResumingManager::getInstance().filename(); + + std::string data = R"( +import math; +import mesh; +import scheme; + +let f:R*R^3 -> R^3, (a, v) -> a * v; + +let alpha:R, alpha = 3.2; +let u1:R^1, u1 = [0.3]; +let u2:R^2, u2 = [0.3, 1.2]; +let u3:R^3, u3 = [1, 2, 3]; + +let A1:R^1x1, A1 = [[0.7]]; +let A2:R^2x2, A2 = [[1.4, 2.1], [0.6, 3]]; +let A3:R^3x3, A3 = [[1.1, 2.2, 3.3], [0.1, 0.2, 0.3], [1.6, 1.2, 1.4]]; + +let m2d:mesh, m2d = cartesianMesh(0, [1,1], (10,10)); + +let b_tuple:(B), b_tuple = (true, false, true); +let n_tuple:(N), n_tuple = (1, 2, 3, 4); +let z_tuple:(Z), z_tuple = (1, -2, 3, -4); +let r_tuple:(R), r_tuple = (1.2, -2.4, 3.1, -4.3); +let s_tuple:(string), s_tuple = ("foo", "bar"); + +let r1_tuple:(R^1), r1_tuple = ([1], [2]); +let r2_tuple:(R^2), r2_tuple = ([1.2, 3], [2.3, 4], [3.2, 1.4]); +let r3_tuple:(R^3), r3_tuple = ([1.2, 0.2, 3], [2.3, -1, 4], [3.2, 2.1, 1.4]); + +let r11_tuple:(R^1x1), r11_tuple = ([[1.3]], [[2.4]]); +let r22_tuple:(R^2x2), r22_tuple = ([[1.2, 3], [2.3, 4]], [[3.2, 1.4], [1.3, 5.2]]); +let r33_tuple:(R^3x3), r33_tuple = ([[1.2, 0.2, 3], [2.3, -1, 4], [3.2, 2.1, 1.4]]); + +let m1d:mesh, m1d = cartesianMesh([0], [1], 10); + +let m3d:mesh, m3d = cartesianMesh(0, [1,1,1], (6,6,6)); + +let b:B, b = false; +let z:Z, z = -2; +let s:string, s = "foobar"; + +for(let i:N, i=0; i<3; ++i) { + checkpoint(); + + s = "foobar_"+i; + z += 1; + b = not b; + + let g:R -> R^3, x -> [x, 1.2*x, 3]; + u3 = f(alpha,u3); + + checkpoint(); +} +)"; + + const size_t initial_mesh_id = GlobalVariableManager::instance().getMeshId(); + const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId(); + + RUN_AST(data); + + GlobalVariableManager::instance().setMeshId(initial_mesh_id); + GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5"); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2); + + HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table"); + REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2); + REQUIRE(l2Norm(symbol_table1.getAttribute("u1").read<R1>() - R1{0.3}) == Catch::Approx(0).margin(1E-12)); + REQUIRE(l2Norm(symbol_table1.getAttribute("u2").read<R2>() - R2{0.3, 1.2}) == Catch::Approx(0).margin(1E-12)); + REQUIRE(l2Norm(symbol_table1.getAttribute("u3").read<R3>() - R3{32.768, 65.536, 98.304}) == + Catch::Approx(0).margin(1E-12)); + REQUIRE(symbol_table1.getAttribute("b").read<bool>() == true); + REQUIRE(symbol_table1.getAttribute("z").read<int64_t>() == 1); + REQUIRE(symbol_table1.getAttribute("s").read<std::string>() == "foobar_2"); + REQUIRE(symbol_table1.getAttribute("A1").read<R11>()(0, 0) == Catch::Approx(0.7).margin(1E-12)); + REQUIRE(frobeniusNorm(symbol_table1.getAttribute("A2").read<R22>() - R22{1.4, 2.1, 0.6, 3}) == + Catch::Approx(0).margin(1E-12)); + REQUIRE(frobeniusNorm(symbol_table1.getAttribute("A3").read<R33>() - + R33{1.1, 2.2, 3.3, 0.1, 0.2, 0.3, 1.6, 1.2, 1.4}) == Catch::Approx(0).margin(1E-12)); + REQUIRE(symbol_table1.getAttribute("b_tuple").read<std::vector<bool>>() == std::vector<bool>{true, false, true}); + REQUIRE(symbol_table1.getAttribute("n_tuple").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4}); + REQUIRE(symbol_table1.getAttribute("z_tuple").read<std::vector<int64_t>>() == std::vector<int64_t>{1, -2, 3, -4}); + REQUIRE(symbol_table1.getAttribute("r_tuple").read<std::vector<double>>() == + std::vector<double>{1.2, -2.4, 3.1, -4.3}); + REQUIRE(symbol_table1.getAttribute("s_tuple").read<std::vector<std::string>>() == + std::vector<std::string>{"foo", "bar"}); + + REQUIRE(symbol_table1.getAttribute("r1_tuple").read<std::vector<R1>>() == std::vector{R1{1}, R1{2}}); + REQUIRE(symbol_table1.getAttribute("r2_tuple").read<std::vector<R2>>() == + std::vector{R2{1.2, 3}, R2{2.3, 4}, R2{3.2, 1.4}}); + REQUIRE(symbol_table1.getAttribute("r3_tuple").read<std::vector<R3>>() == + std::vector{R3{1.2, 0.2, 3}, R3{2.3, -1, 4}, R3{3.2, 2.1, 1.4}}); + + REQUIRE(symbol_table1.getAttribute("r11_tuple").read<std::vector<R11>>() == std::vector{R11{1.3}, R11{2.4}}); + REQUIRE(symbol_table1.getAttribute("r22_tuple").read<std::vector<R22>>() == + std::vector{R22{1.2, 3, 2.3, 4}, R22{3.2, 1.4, 1.3, 5.2}}); + REQUIRE(symbol_table1.getAttribute("r33_tuple").read<std::vector<R33>>() == + std::vector{R33{1.2, 0.2, 3, 2.3, -1, 4, 3.2, 2.1, 1.4}}); + + HighFive::Group embedded1 = symbol_table1.getGroup("embedded"); + + HighFive::Group m1d = embedded1.getGroup("m1d"); + REQUIRE(m1d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m1d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (1 + 2 * (parallel::size() > 1))); + + HighFive::Group m2d = embedded1.getGroup("m2d"); + REQUIRE(m2d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m2d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1)); + + HighFive::Group m3d = embedded1.getGroup("m3d"); + REQUIRE(m3d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m3d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (2 + 3 * (parallel::size() > 1))); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group global_variables = singleton.getGroup("global_variables"); + REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == + initial_connectivity_id + 3 * (1 + (parallel::size() > 1))); + REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == + initial_mesh_id + 3 * (1 + (parallel::size() > 1))); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group connectivity = checkpoint.getGroup("connectivity"); + HighFive::Group connectivity0 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (parallel::size() > 1))); + REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2); + REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1)); + REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group connectivity1 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (1 + 2 * (parallel::size() > 1)))); + REQUIRE(connectivity1.getAttribute("dimension").read<uint64_t>() == 1); + REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == + initial_connectivity_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group connectivity2 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (2 + 3 * (parallel::size() > 1)))); + REQUIRE(connectivity2.getAttribute("dimension").read<uint64_t>() == 3); + REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == + initial_connectivity_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group mesh = checkpoint.getGroup("mesh"); + HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id + (parallel::size() > 1))); + REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1)); + REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2); + REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1)); + REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + (1 + 2 * (parallel::size() > 1)))); + REQUIRE(mesh1.getAttribute("connectivity").read<uint64_t>() == + initial_connectivity_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(mesh1.getAttribute("dimension").read<uint64_t>() == 1); + REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(mesh1.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + (2 + 3 * (parallel::size() > 1)))); + REQUIRE(mesh2.getAttribute("connectivity").read<uint64_t>() == + initial_connectivity_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(mesh2.getAttribute("dimension").read<uint64_t>() == 3); + REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(mesh2.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group functions = checkpoint.getGroup("functions"); + HighFive::Group f = functions.getGroup("f"); + REQUIRE(f.getAttribute("id").read<uint64_t>() == 0); + REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1); + HighFive::Group g = functions.getGroup("g"); + REQUIRE(g.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0); + } + + parallel::barrier(); + + setResumeFrom(filename, 3); + + parallel::barrier(); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 3); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_3"); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 1); + + HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table"); + REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2); + REQUIRE(l2Norm(symbol_table1.getAttribute("u1").read<R1>() - R1{0.3}) == Catch::Approx(0).margin(1E-12)); + REQUIRE(l2Norm(symbol_table1.getAttribute("u2").read<R2>() - R2{0.3, 1.2}) == Catch::Approx(0).margin(1E-12)); + REQUIRE(l2Norm(symbol_table1.getAttribute("u3").read<R3>() - R3{10.24, 20.48, 30.72}) == + Catch::Approx(0).margin(1E-12)); + REQUIRE(symbol_table1.getAttribute("b").read<bool>() == false); + REQUIRE(symbol_table1.getAttribute("z").read<int64_t>() == 0); + REQUIRE(symbol_table1.getAttribute("s").read<std::string>() == "foobar_1"); + REQUIRE(symbol_table1.getAttribute("A1").read<R11>()(0, 0) == Catch::Approx(0.7).margin(1E-12)); + REQUIRE(frobeniusNorm(symbol_table1.getAttribute("A2").read<R22>() - R22{1.4, 2.1, 0.6, 3}) == + Catch::Approx(0).margin(1E-12)); + REQUIRE(frobeniusNorm(symbol_table1.getAttribute("A3").read<R33>() - + R33{1.1, 2.2, 3.3, 0.1, 0.2, 0.3, 1.6, 1.2, 1.4}) == Catch::Approx(0).margin(1E-12)); + + REQUIRE(symbol_table1.getAttribute("b_tuple").read<std::vector<bool>>() == std::vector<bool>{true, false, true}); + REQUIRE(symbol_table1.getAttribute("n_tuple").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4}); + REQUIRE(symbol_table1.getAttribute("z_tuple").read<std::vector<int64_t>>() == std::vector<int64_t>{1, -2, 3, -4}); + REQUIRE(symbol_table1.getAttribute("r_tuple").read<std::vector<double>>() == + std::vector<double>{1.2, -2.4, 3.1, -4.3}); + REQUIRE(symbol_table1.getAttribute("s_tuple").read<std::vector<std::string>>() == + std::vector<std::string>{"foo", "bar"}); + + REQUIRE(symbol_table1.getAttribute("r1_tuple").read<std::vector<R1>>() == std::vector{R1{1}, R1{2}}); + REQUIRE(symbol_table1.getAttribute("r2_tuple").read<std::vector<R2>>() == + std::vector{R2{1.2, 3}, R2{2.3, 4}, R2{3.2, 1.4}}); + REQUIRE(symbol_table1.getAttribute("r3_tuple").read<std::vector<R3>>() == + std::vector{R3{1.2, 0.2, 3}, R3{2.3, -1, 4}, R3{3.2, 2.1, 1.4}}); + + REQUIRE(symbol_table1.getAttribute("r11_tuple").read<std::vector<R11>>() == std::vector{R11{1.3}, R11{2.4}}); + REQUIRE(symbol_table1.getAttribute("r22_tuple").read<std::vector<R22>>() == + std::vector{R22{1.2, 3, 2.3, 4}, R22{3.2, 1.4, 1.3, 5.2}}); + REQUIRE(symbol_table1.getAttribute("r33_tuple").read<std::vector<R33>>() == + std::vector{R33{1.2, 0.2, 3, 2.3, -1, 4, 3.2, 2.1, 1.4}}); + + HighFive::Group embedded1 = symbol_table1.getGroup("embedded"); + + HighFive::Group m1d = embedded1.getGroup("m1d"); + REQUIRE(m1d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m1d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (1 + 2 * (parallel::size() > 1))); + + HighFive::Group m2d = embedded1.getGroup("m2d"); + REQUIRE(m2d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m2d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1)); + + HighFive::Group m3d = embedded1.getGroup("m3d"); + REQUIRE(m3d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m3d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (2 + 3 * (parallel::size() > 1))); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group global_variables = singleton.getGroup("global_variables"); + REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == + initial_connectivity_id + 3 * (1 + (parallel::size() > 1))); + REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == + initial_mesh_id + 3 * (1 + (parallel::size() > 1))); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group connectivity = checkpoint.getGroup("connectivity"); + HighFive::Group connectivity0 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (parallel::size() > 1))); + REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2); + REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1)); + REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group connectivity1 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (1 + 2 * (parallel::size() > 1)))); + REQUIRE(connectivity1.getAttribute("dimension").read<uint64_t>() == 1); + REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == + initial_connectivity_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group connectivity2 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (2 + 3 * (parallel::size() > 1)))); + REQUIRE(connectivity2.getAttribute("dimension").read<uint64_t>() == 3); + REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == + initial_connectivity_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group mesh = checkpoint.getGroup("mesh"); + HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id + (parallel::size() > 1))); + REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1)); + REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2); + REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1)); + REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + (1 + 2 * (parallel::size() > 1)))); + REQUIRE(mesh1.getAttribute("connectivity").read<uint64_t>() == + initial_connectivity_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(mesh1.getAttribute("dimension").read<uint64_t>() == 1); + REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(mesh1.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + (2 + 3 * (parallel::size() > 1)))); + REQUIRE(mesh2.getAttribute("connectivity").read<uint64_t>() == + initial_connectivity_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(mesh2.getAttribute("dimension").read<uint64_t>() == 3); + REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(mesh2.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group functions = checkpoint.getGroup("functions"); + HighFive::Group f = functions.getGroup("f"); + REQUIRE(f.getAttribute("id").read<uint64_t>() == 0); + REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1); + HighFive::Group g = functions.getGroup("g"); + REQUIRE(g.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0); + } + + ResumingManager::destroy(); + ResumingManager::create(); + ResumingManager::getInstance().setFilename(filename); + ResumingManager::getInstance().setIsResuming(true); + GlobalVariableManager::instance().setMeshId(initial_mesh_id); + GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id); + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5"); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2); + + HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table"); + REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2); + REQUIRE(l2Norm(symbol_table1.getAttribute("u1").read<R1>() - R1{0.3}) == Catch::Approx(0).margin(1E-12)); + REQUIRE(l2Norm(symbol_table1.getAttribute("u2").read<R2>() - R2{0.3, 1.2}) == Catch::Approx(0).margin(1E-12)); + REQUIRE(l2Norm(symbol_table1.getAttribute("u3").read<R3>() - R3{32.768, 65.536, 98.304}) == + Catch::Approx(0).margin(1E-12)); + REQUIRE(symbol_table1.getAttribute("b").read<bool>() == true); + REQUIRE(symbol_table1.getAttribute("z").read<int64_t>() == 1); + REQUIRE(symbol_table1.getAttribute("s").read<std::string>() == "foobar_2"); + REQUIRE(symbol_table1.getAttribute("A1").read<R11>()(0, 0) == Catch::Approx(0.7).margin(1E-12)); + REQUIRE(frobeniusNorm(symbol_table1.getAttribute("A2").read<R22>() - R22{1.4, 2.1, 0.6, 3}) == + Catch::Approx(0).margin(1E-12)); + REQUIRE(frobeniusNorm(symbol_table1.getAttribute("A3").read<R33>() - + R33{1.1, 2.2, 3.3, 0.1, 0.2, 0.3, 1.6, 1.2, 1.4}) == Catch::Approx(0).margin(1E-12)); + + REQUIRE(symbol_table1.getAttribute("b_tuple").read<std::vector<bool>>() == std::vector<bool>{true, false, true}); + REQUIRE(symbol_table1.getAttribute("n_tuple").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4}); + REQUIRE(symbol_table1.getAttribute("z_tuple").read<std::vector<int64_t>>() == std::vector<int64_t>{1, -2, 3, -4}); + REQUIRE(symbol_table1.getAttribute("r_tuple").read<std::vector<double>>() == + std::vector<double>{1.2, -2.4, 3.1, -4.3}); + REQUIRE(symbol_table1.getAttribute("s_tuple").read<std::vector<std::string>>() == + std::vector<std::string>{"foo", "bar"}); + + REQUIRE(symbol_table1.getAttribute("r1_tuple").read<std::vector<R1>>() == std::vector{R1{1}, R1{2}}); + REQUIRE(symbol_table1.getAttribute("r2_tuple").read<std::vector<R2>>() == + std::vector{R2{1.2, 3}, R2{2.3, 4}, R2{3.2, 1.4}}); + REQUIRE(symbol_table1.getAttribute("r3_tuple").read<std::vector<R3>>() == + std::vector{R3{1.2, 0.2, 3}, R3{2.3, -1, 4}, R3{3.2, 2.1, 1.4}}); + + REQUIRE(symbol_table1.getAttribute("r11_tuple").read<std::vector<R11>>() == std::vector{R11{1.3}, R11{2.4}}); + REQUIRE(symbol_table1.getAttribute("r22_tuple").read<std::vector<R22>>() == + std::vector{R22{1.2, 3, 2.3, 4}, R22{3.2, 1.4, 1.3, 5.2}}); + REQUIRE(symbol_table1.getAttribute("r33_tuple").read<std::vector<R33>>() == + std::vector{R33{1.2, 0.2, 3, 2.3, -1, 4, 3.2, 2.1, 1.4}}); + + HighFive::Group embedded1 = symbol_table1.getGroup("embedded"); + + HighFive::Group m1d = embedded1.getGroup("m1d"); + REQUIRE(m1d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m1d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (1 + 2 * (parallel::size() > 1))); + + HighFive::Group m2d = embedded1.getGroup("m2d"); + REQUIRE(m2d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m2d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1)); + + HighFive::Group m3d = embedded1.getGroup("m3d"); + REQUIRE(m3d.getAttribute("type").read<std::string>() == "mesh"); + REQUIRE(m3d.getAttribute("id").read<uint64_t>() == initial_mesh_id + (2 + 3 * (parallel::size() > 1))); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group global_variables = singleton.getGroup("global_variables"); + REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == + initial_connectivity_id + 3 * (1 + (parallel::size() > 1))); + REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == + initial_mesh_id + 3 * (1 + (parallel::size() > 1))); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 2); + + HighFive::Group connectivity = checkpoint.getGroup("connectivity"); + HighFive::Group connectivity0 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (parallel::size() > 1))); + REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2); + REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1)); + REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group connectivity1 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (1 + 2 * (parallel::size() > 1)))); + REQUIRE(connectivity1.getAttribute("dimension").read<uint64_t>() == 1); + REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == + initial_connectivity_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group connectivity2 = + connectivity.getGroup(std::to_string(initial_connectivity_id + (2 + 3 * (parallel::size() > 1)))); + REQUIRE(connectivity2.getAttribute("dimension").read<uint64_t>() == 3); + REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == + initial_connectivity_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "unstructured"); + + HighFive::Group mesh = checkpoint.getGroup("mesh"); + HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id + (parallel::size() > 1))); + REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1)); + REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2); + REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1)); + REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + (1 + 2 * (parallel::size() > 1)))); + REQUIRE(mesh1.getAttribute("connectivity").read<uint64_t>() == + initial_connectivity_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(mesh1.getAttribute("dimension").read<uint64_t>() == 1); + REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + (1 + 2 * (parallel::size() > 1))); + REQUIRE(mesh1.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + (2 + 3 * (parallel::size() > 1)))); + REQUIRE(mesh2.getAttribute("connectivity").read<uint64_t>() == + initial_connectivity_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(mesh2.getAttribute("dimension").read<uint64_t>() == 3); + REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + (2 + 3 * (parallel::size() > 1))); + REQUIRE(mesh2.getAttribute("type").read<std::string>() == "polygonal"); + + HighFive::Group functions = checkpoint.getGroup("functions"); + HighFive::Group f = functions.getGroup("f"); + REQUIRE(f.getAttribute("id").read<uint64_t>() == 0); + REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1); + HighFive::Group g = functions.getGroup("g"); + REQUIRE(g.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0); + } + + parallel::barrier(); + if (parallel::rank() == 0) { + std::filesystem::remove_all(std::filesystem::path{tmp_dirname}); + } + + // Revert to default value + ResumingManager::getInstance().setFilename("checkpoint.h5"); + MeshDataBaseForTests::create(); + } + + SECTION("simple if") + { + ResumingManager::destroy(); + ResumingManager::create(); + + std::string tmp_dirname; + { + { + if (parallel::rank() == 0) { + tmp_dirname = [&]() -> std::string { + std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX"; + return std::string{mkdtemp(&temp_filename[0])}; + }(); + } + parallel::broadcast(tmp_dirname, 0); + } + std::filesystem::path path = tmp_dirname; + const std::string filename = path / "checkpoint.h5"; + + ResumingManager::getInstance().setFilename(filename); + } + + const std::string filename = ResumingManager::getInstance().filename(); + + std::string data = R"( +let i:N, i = 3; + +if (true) { + checkpoint(); +} +i = 7; +checkpoint(); +)"; + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + + setResumeFrom(filename, 0); + + parallel::barrier(); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_0"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 3); + } + + ResumingManager::destroy(); + ResumingManager::create(); + ResumingManager::getInstance().setFilename(filename); + ResumingManager::getInstance().setIsResuming(true); + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 2); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + if (parallel::rank() == 0) { + std::filesystem::remove_all(std::filesystem::path{tmp_dirname}); + } + + // Revert to default value + ResumingManager::getInstance().setFilename("checkpoint.h5"); + } + + SECTION("simple else") + { + ResumingManager::destroy(); + ResumingManager::create(); + + std::string tmp_dirname; + { + { + if (parallel::rank() == 0) { + tmp_dirname = [&]() -> std::string { + std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX"; + return std::string{mkdtemp(&temp_filename[0])}; + }(); + } + parallel::broadcast(tmp_dirname, 0); + } + std::filesystem::path path = tmp_dirname; + const std::string filename = path / "checkpoint.h5"; + + ResumingManager::getInstance().setFilename(filename); + } + + const std::string filename = ResumingManager::getInstance().filename(); + + std::string data = R"( +let i:N, i = 3; + +if (false) { +} else { + checkpoint(); +} +i = 7; +checkpoint(); +)"; + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + + setResumeFrom(filename, 0); + + parallel::barrier(); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_0"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 3); + } + + ResumingManager::destroy(); + ResumingManager::create(); + ResumingManager::getInstance().setFilename(filename); + ResumingManager::getInstance().setIsResuming(true); + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 2); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + if (parallel::rank() == 0) { + std::filesystem::remove_all(std::filesystem::path{tmp_dirname}); + } + + // Revert to default value + ResumingManager::getInstance().setFilename("checkpoint.h5"); + } + + SECTION("simple do while") + { + ResumingManager::destroy(); + ResumingManager::create(); + + std::string tmp_dirname; + { + { + if (parallel::rank() == 0) { + tmp_dirname = [&]() -> std::string { + std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX"; + return std::string{mkdtemp(&temp_filename[0])}; + }(); + } + parallel::broadcast(tmp_dirname, 0); + } + std::filesystem::path path = tmp_dirname; + const std::string filename = path / "checkpoint.h5"; + + ResumingManager::getInstance().setFilename(filename); + } + + const std::string filename = ResumingManager::getInstance().filename(); + + std::string data = R"( +let i:N, i = 3; + +do { + checkpoint(); +} while(false); + +i = 7; +checkpoint(); +)"; + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + + setResumeFrom(filename, 0); + + parallel::barrier(); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_0"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table"); + REQUIRE(symbol_table1.getAttribute("i").read<uint64_t>() == 3); + } + + ResumingManager::destroy(); + ResumingManager::create(); + ResumingManager::getInstance().setFilename(filename); + ResumingManager::getInstance().setIsResuming(true); + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 2); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + if (parallel::rank() == 0) { + std::filesystem::remove_all(std::filesystem::path{tmp_dirname}); + } + + // Revert to default value + ResumingManager::getInstance().setFilename("checkpoint.h5"); + } + + SECTION("simple while") + { + ResumingManager::destroy(); + ResumingManager::create(); + + std::string tmp_dirname; + { + { + if (parallel::rank() == 0) { + tmp_dirname = [&]() -> std::string { + std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX"; + return std::string{mkdtemp(&temp_filename[0])}; + }(); + } + parallel::broadcast(tmp_dirname, 0); + } + std::filesystem::path path = tmp_dirname; + const std::string filename = path / "checkpoint.h5"; + + ResumingManager::getInstance().setFilename(filename); + } + + const std::string filename = ResumingManager::getInstance().filename(); + + std::string data = R"( +let i:N, i = 3; + +while (true) { + checkpoint(); + break; +}; + +i = 7; +checkpoint(); +)"; + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + + setResumeFrom(filename, 0); + + parallel::barrier(); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 0); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_0"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table"); + HighFive::Group symbol_table2 = symbol_table1.getGroup("symbol table"); + REQUIRE(symbol_table2.getAttribute("i").read<uint64_t>() == 3); + } + + ResumingManager::destroy(); + ResumingManager::create(); + ResumingManager::getInstance().setFilename(filename); + ResumingManager::getInstance().setIsResuming(true); + + RUN_AST(data); + + { // Check checkpoint file + + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 1); + REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_1"); + + HighFive::Group singleton = checkpoint.getGroup("singleton"); + HighFive::Group execution_info = singleton.getGroup("execution_info"); + REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 2); + + HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table"); + REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 7); + } + + parallel::barrier(); + if (parallel::rank() == 0) { + std::filesystem::remove_all(std::filesystem::path{tmp_dirname}); + } + + // Revert to default value + ResumingManager::getInstance().setFilename("checkpoint.h5"); + } +#else // PUGS_HAS_HDF5 + REQUIRE_THROWS_WITH(resume(), "error: checkpoint/resume mechanism requires HDF5"); +#endif // PUGS_HAS_HDF5 +} -- GitLab