#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_predicate.hpp>

#include <utils/pugs_config.hpp>

#ifdef PUGS_HAS_HDF5

#include <dev/ParallelChecker.hpp>
#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTExecutionStack.hpp>
#include <language/ast/ASTModulesImporter.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
#include <language/ast/ASTNodeExpressionBuilder.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>
#include <language/modules/MathModule.hpp>
#include <language/utils/ASTCheckpointsInfo.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <mesh/DualMeshType.hpp>
#include <utils/ExecutionStatManager.hpp>
#include <utils/checkpointing/DualMeshTypeHFType.hpp>

#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <string>
#include <vector>

// Test helper that owns an ASTCheckpointsInfo built from an AST root node.
// The ASTCheckpointsInfo object lives exactly as long as the tester does.
// NOTE(review): presumably this wrapper exists because ASTCheckpointsInfo is
// only constructible by it (e.g. via friendship) — TODO confirm in
// ASTCheckpointsInfo.hpp.
class ASTCheckpointsInfoTester
{
 private:
  ASTCheckpointsInfo m_ast_checkpoint_info;

 public:
  explicit ASTCheckpointsInfoTester(const ASTNode& root_node) : m_ast_checkpoint_info(root_node) {}
};

// Builds the AST of the pugs script `data` and executes it.
// - ExecutionStatManager, ParallelChecker and CheckpointResumeRepository are
//   created for the duration of the run.
// - An ASTCheckpointsInfoTester is kept alive while the AST executes.
// - Symbol-table values are cleared right after execution, before the
//   singletons are destroyed.
#define RUN_AST(data)                                            \
  {                                                              \
    ExecutionStatManager::create();                              \
    ParallelChecker::create();                                   \
    CheckpointResumeRepository::create();                        \
                                                                 \
    TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"};   \
    auto ast = ASTBuilder::build(input);                         \
                                                                 \
    ASTModulesImporter{*ast};                                    \
    ASTNodeTypeCleaner<language::import_instruction>{*ast};      \
                                                                 \
    ASTSymbolTableBuilder{*ast};                                 \
    ASTNodeDataTypeBuilder{*ast};                                \
                                                                 \
    ASTNodeDeclarationToAffectationConverter{*ast};              \
    ASTNodeTypeCleaner<language::var_declaration>{*ast};         \
    ASTNodeTypeCleaner<language::fct_declaration>{*ast};         \
                                                                 \
    ASTNodeExpressionBuilder{*ast};                              \
    ExecutionPolicy exec_policy;                                 \
    ASTExecutionStack::create();                                 \
    ASTCheckpointsInfoTester ast_cp_info_tester{*ast};           \
    ast->execute(exec_policy);                                   \
    ASTExecutionStack::destroy();                                \
    ast->m_symbol_table->clearValues();                          \
                                                                 \
    CheckpointResumeRepository::destroy();                       \
    ParallelChecker::destroy();                                  \
    ExecutionStatManager::destroy();                             \
  }

#else   // PUGS_HAS_HDF5

#include <utils/checkpointing/Checkpoint.hpp>

#endif   // PUGS_HAS_HDF5

// clazy:excludeall=non-pod-global-static

// Runs a pugs script that calls checkpoint() inside a loop.
// With HDF5: reads back the resulting checkpoint file and checks the stored
// symbol tables, singletons, connectivities, meshes and functions.
// Without HDF5: checkpoint() must throw.
TEST_CASE("checkpointing_Checkpoint_sequential", "[utils/checkpointing]")
{
#ifdef PUGS_HAS_HDF5

  ResumingManager::destroy();
  ResumingManager::create();

  // Rank 0 creates a unique temporary directory and broadcasts its name.
  std::string tmp_dirname;
  if (parallel::rank() == 0) {
    std::string dir_template = (std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX").string();
    // mkdtemp returns nullptr on failure: leave tmp_dirname empty rather than
    // constructing a std::string from a null pointer (undefined behavior)
    if (const char* created_dir = mkdtemp(dir_template.data()); created_dir != nullptr) {
      tmp_dirname = created_dir;
    }
  }
  parallel::broadcast(tmp_dirname, 0);
  // Checked after the broadcast so that every rank fails consistently
  REQUIRE(not tmp_dirname.empty());

  const std::string checkpoint_filename = (std::filesystem::path{tmp_dirname} / "checkpoint.h5").string();
  ResumingManager::getInstance().setFilename(checkpoint_filename);

  const std::string data = R"( import math; import mesh; import scheme; let f:R*R^3 -> R^3, (a, v) -> a * v; let alpha:R, alpha = 3.2; let u:R^3, u = [1,2,3]; let m:mesh, m = cartesianMesh(0, [1,1], (10,10)); let n:(N), n = (1,2,3,4); let duals:(mesh), duals = (diamondDual(m), medianDual(m)); for(let i:N, i=0; i<3; ++i) { checkpoint(); let g:R -> R^3, x -> [x, 1.2*x, 3]; u = f(alpha,u); checkpoint(); } )";

  const size_t initial_mesh_id         = GlobalVariableManager::instance().getMeshId();
  const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();

  RUN_AST(data);

  // Reset the global id counters to their values from before the run
  GlobalVariableManager::instance().setMeshId(initial_mesh_id);
  GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);

  {   // Check checkpoint file
    HighFive::File file(checkpoint_filename, HighFive::File::ReadOnly);

    // The loop body runs 3 times with two checkpoint() calls each; the file
    // is expected to hold the last of those six checkpoints
    HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");
    REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5);
    REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5");

    HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table");
    REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2);

    HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table");
    REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2);
    REQUIRE(symbol_table1.getAttribute("n").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4});
    // u = 3.2^3 * (1, 2, 3) after three applications of f(alpha, u)
    REQUIRE(l2Norm(symbol_table1.getAttribute("u").read<TinyVector<3>>() - TinyVector<3>{32.768, 65.536, 98.304}) ==
            Catch::Approx(0).margin(1E-12));

    HighFive::Group embedded1 = symbol_table1.getGroup("embedded");

    HighFive::Group duals = embedded1.getGroup("duals");
    REQUIRE(duals.getAttribute("type").read<std::string>() == "(mesh)");

    HighFive::Group duals_0 = duals.getGroup("0");
    REQUIRE(duals_0.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_0.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);

    HighFive::Group duals_1 = duals.getGroup("1");
    REQUIRE(duals_1.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);

    HighFive::Group m = embedded1.getGroup("m");
    REQUIRE(m.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(m.getAttribute("id").read<uint64_t>() == initial_mesh_id);

    // The script creates three meshes (m and its two duals), each with its own connectivity
    HighFive::Group singleton        = checkpoint.getGroup("singleton");
    HighFive::Group global_variables = singleton.getGroup("global_variables");
    REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == initial_connectivity_id + 3);
    REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == initial_mesh_id + 3);
    HighFive::Group execution_info = singleton.getGroup("execution_info");
    REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1);

    HighFive::Group connectivity  = checkpoint.getGroup("connectivity");
    HighFive::Group connectivity0 = connectivity.getGroup(std::to_string(initial_connectivity_id));
    REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured");

    HighFive::Group connectivity1 = connectivity.getGroup(std::to_string(initial_connectivity_id + 1));
    REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 1);
    REQUIRE(connectivity1.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);

    HighFive::Group connectivity2 = connectivity.getGroup(std::to_string(initial_connectivity_id + 2));
    REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 2);
    REQUIRE(connectivity2.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    HighFive::Group mesh  = checkpoint.getGroup("mesh");
    HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id));
    REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal");

    HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + 1));
    REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
    REQUIRE(mesh1.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh1.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);

    HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + 2));
    REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);
    REQUIRE(mesh2.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh2.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    HighFive::Group functions = checkpoint.getGroup("functions");
    HighFive::Group f         = functions.getGroup("f");
    REQUIRE(f.getAttribute("id").read<uint64_t>() == 0);
    REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1);
    HighFive::Group g = functions.getGroup("g");
    REQUIRE(g.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0);
  }

  // Every rank must be done with the file before rank 0 deletes the directory
  parallel::barrier();
  if (parallel::rank() == 0) {
    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
  }

  // Revert to default value
  ResumingManager::getInstance().setFilename("checkpoint.h5");

#else   // PUGS_HAS_HDF5

  REQUIRE_THROWS_WITH(checkpoint(), "error: checkpoint/resume mechanism requires HDF5");

#endif   // PUGS_HAS_HDF5
}