Skip to content
Snippets Groups Projects
Commit 2dba22b9 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add tests for Checkpoint

parent 38fe1ef1
No related branches found
No related tags found
1 merge request!199Integrate checkpointing
......@@ -22,6 +22,9 @@ class ASTCheckpointsInfo
// The only place where the ASTCheckpointsInfo can be built
friend void parser(const std::string& filename);
// to allow special manipulations in tests
friend class ASTCheckpointsInfoTester;
ASTCheckpointsInfo(const ASTNode& root_node);
public:
......
......@@ -171,7 +171,11 @@ initialize(int& argc, char* argv[])
}
ResumingManager::getInstance().setIsResuming(is_resuming);
if (is_resuming) {
ResumingManager::getInstance().setFilename(filename);
} else {
ResumingManager::getInstance().setFilename("checkpoint.h5");
}
ExecutionStatManager::getInstance().setPrint(print_exec_stat);
BacktraceManager::setShow(show_backtrace);
......
......@@ -44,7 +44,7 @@ checkpoint()
const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite;
HighFive::File file("checkpoint.h5", file_openmode, create_props, fapl);
HighFive::File file(ResumingManager::getInstance().filename(), file_openmode, create_props, fapl);
std::string checkpoint_name = "checkpoint_" + std::to_string(checkpoint_number);
......@@ -202,10 +202,14 @@ checkpoint()
data[i], file, checkpoint, symbol_table_group);
}
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data type");
// LCOV_EXCL_STOP
}
} else {
// LCOV_EXCL_START
throw UnexpectedError("unexpected data type");
// LCOV_EXCL_STOP
}
},
symbol.attributes().value());
......@@ -228,9 +232,11 @@ checkpoint()
}
file.createHardLink("resuming_checkpoint", checkpoint);
}
// LCOV_EXCL_START
catch (HighFive::Exception& e) {
throw NormalError(e.what());
}
// LCOV_EXCL_STOP
}
#else // PUGS_HAS_HDF5
......
......@@ -4,6 +4,11 @@ include_directories(${PUGS_SOURCE_DIR}/src)
include_directories(${PUGS_BINARY_DIR}/src)
include_directories(${PUGS_SOURCE_DIR}/tests)
# These should eventually integrate parallel tests
set(checkpointing_sequential_TESTS
test_checkpointing_Checkpoint_sequential.cpp
)
add_executable (unit_tests
test_main.cpp
test_AffectationProcessor.cpp
......@@ -155,9 +160,11 @@ add_executable (unit_tests
test_UnaryOperatorMangler.cpp
test_Vector.cpp
test_WhileProcessor.cpp
${checkpointing_sequential_TESTS}
)
set(checkpointing_TESTS
test_checkpointing_Checkpoint.cpp
test_checkpointing_PrintCheckpointInfo.cpp
test_checkpointing_PrintScriptFrom.cpp
test_checkpointing_ResumingManager.cpp
......
#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_predicate.hpp>
#include <utils/pugs_config.hpp>
#ifdef PUGS_HAS_HDF5
#include <dev/ParallelChecker.hpp>
#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTExecutionStack.hpp>
#include <language/ast/ASTModulesImporter.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
#include <language/ast/ASTNodeExpressionBuilder.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>
#include <language/modules/MathModule.hpp>
#include <language/utils/ASTCheckpointsInfo.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <utils/ExecutionStatManager.hpp>
// Test-only helper. ASTCheckpointsInfo declares this class as a friend (its
// constructor is otherwise reserved to parser()), which lets a test build an
// ASTCheckpointsInfo from an AST root node and keep it alive for the lifetime
// of the tester object.
class ASTCheckpointsInfoTester
{
 private:
  // Instance kept alive as long as the tester exists.
  ASTCheckpointsInfo m_ast_checkpoint_info;

 public:
  // `explicit` prevents an ASTNode from silently converting into a tester;
  // the tester must always be constructed on purpose.
  explicit ASTCheckpointsInfoTester(const ASTNode& root_node) : m_ast_checkpoint_info(root_node) {}
  ~ASTCheckpointsInfoTester() = default;
};
// RUN_AST(data): parses the script text `data` (reported under the name
// "test.pgs"), applies the AST passes listed below in order, executes the
// resulting AST and tears everything down again.
//
// Sequence performed by the macro:
//  - create the ExecutionStatManager, ParallelChecker and
//    CheckpointResumeRepository instances;
//  - build the AST, then run the modules importer, symbol table builder,
//    data type builder, declaration-to-affectation converter, the type
//    cleaners and the expression builder;
//  - create the ASTExecutionStack and an ASTCheckpointsInfoTester built from
//    the AST root, then execute the AST;
//  - destroy the execution stack, clear the symbol table values and destroy
//    the three instances in the reverse order of their creation.
//
// The whole body is wrapped in braces, so every local (input, ast,
// exec_policy, ast_cp_info_tester) is scoped to the macro expansion.
// NOTE(review): nothing here catches exceptions, so if execution throws the
// destroy() calls are skipped -- confirm this is acceptable for the tests.
#define RUN_AST(data) \
{ \
ExecutionStatManager::create(); \
ParallelChecker::create(); \
CheckpointResumeRepository::create(); \
\
TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \
auto ast = ASTBuilder::build(input); \
\
ASTModulesImporter{*ast}; \
ASTNodeTypeCleaner<language::import_instruction>{*ast}; \
\
ASTSymbolTableBuilder{*ast}; \
ASTNodeDataTypeBuilder{*ast}; \
\
ASTNodeDeclarationToAffectationConverter{*ast}; \
ASTNodeTypeCleaner<language::var_declaration>{*ast}; \
ASTNodeTypeCleaner<language::fct_declaration>{*ast}; \
\
ASTNodeExpressionBuilder{*ast}; \
ExecutionPolicy exec_policy; \
ASTExecutionStack::create(); \
ASTCheckpointsInfoTester ast_cp_info_tester{*ast}; \
ast->execute(exec_policy); \
ASTExecutionStack::destroy(); \
ast->m_symbol_table->clearValues(); \
\
CheckpointResumeRepository::destroy(); \
ParallelChecker::destroy(); \
ExecutionStatManager::destroy(); \
}
#else // PUGS_HAS_HDF5
#include <utils/checkpointing/Checkpoint.hpp>
#endif // PUGS_HAS_HDF5
// clazy:excludeall=non-pod-global-static
// Runs a small pugs script that calls checkpoint() twice per loop iteration
// over three iterations (checkpoints 0 to 5), then opens the produced HDF5
// file and checks the content reachable from "/resuming_checkpoint".
// Without HDF5 support, only checks that checkpoint() reports the missing
// dependency.
TEST_CASE("checkpointing_Checkpoint", "[utils/checkpointing]")
{
#ifdef PUGS_HAS_HDF5
  std::string tmp_dirname;
  {
    {
      // Only rank 0 creates the temporary directory; its name is then
      // broadcast so that every rank targets the same checkpoint file.
      if (parallel::rank() == 0) {
        tmp_dirname = [&]() -> std::string {
          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
          return std::string{mkdtemp(&temp_filename[0])};
        }();
      }
      parallel::broadcast(tmp_dirname, 0);
    }
    std::filesystem::path path = tmp_dirname;
    const std::string filename = path / "checkpoint.h5";
    ResumingManager::getInstance().setFilename(filename);
  }

  std::string data = R"(
import math;
import mesh;
import scheme;
let f:R*R^3 -> R^3, (a, v) -> a * v;
let alpha:R, alpha = 3.2;
let u:R^3, u = [1,2,3];
let m:mesh, m = cartesianMesh(0, [1,1], (10,10));
let n:(N), n = (1,2,3,4);
for(let i:N, i=0; i<3; ++i) {
checkpoint();
let g:R -> R^3, x -> [x, 1.2*x, 3];
u = f(alpha,u);
checkpoint();
}
)";

  // Mesh and connectivity ids are global counters: remember them so that the
  // expected ids can be expressed relatively and the counters restored once
  // the script has run.
  const size_t initial_mesh_id         = GlobalVariableManager::instance().getMeshId();
  const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();

  RUN_AST(data);

  GlobalVariableManager::instance().setMeshId(initial_mesh_id);
  GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);

  {   // Check checkpoint file
    std::filesystem::path path = tmp_dirname;
    const std::string filename = path / "checkpoint.h5";

    HighFive::File file(filename, HighFive::File::ReadOnly);

    // "resuming_checkpoint" is a hard link to the last written checkpoint.
    HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");
    REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5);
    REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5");

    // Innermost symbol table (loop scope): holds the loop index.
    HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table");
    REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2);

    // Enclosing symbol table: holds the script-level variables.
    HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table");
    REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2);
    REQUIRE(symbol_table1.getAttribute("n").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4});
    // u has been multiplied by alpha three times: 3.2^3 * [1,2,3].
    REQUIRE(l2Norm(symbol_table1.getAttribute("u").read<TinyVector<3>>() - TinyVector<3>{32.768, 65.536, 98.304}) ==
            Catch::Approx(0).margin(1E-12));

    // NOTE(review): the "+ (parallel::size() > 1)" offsets below shift the
    // expected ids by one in parallel runs -- presumably an additional
    // mesh/connectivity is created in that case; confirm.
    HighFive::Group embedded1 = symbol_table1.getGroup("embedded");
    HighFive::Group m         = embedded1.getGroup("m");
    REQUIRE(m.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(m.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1));

    HighFive::Group singleton        = checkpoint.getGroup("singleton");
    HighFive::Group global_variables = singleton.getGroup("global_variables");
    REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() ==
            initial_connectivity_id + 1 + (parallel::size() > 1));
    REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == initial_mesh_id + 1 + (parallel::size() > 1));
    HighFive::Group execution_info = singleton.getGroup("execution_info");
    REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1);

    HighFive::Group connectivity = checkpoint.getGroup("connectivity");
    HighFive::Group connectivity0 =
      connectivity.getGroup(std::to_string(initial_connectivity_id + (parallel::size() > 1)));
    REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1));
    REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured");

    HighFive::Group mesh  = checkpoint.getGroup("mesh");
    HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id + (parallel::size() > 1)));
    REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id + (parallel::size() > 1));
    REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id + (parallel::size() > 1));
    REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal");

    HighFive::Group functions = checkpoint.getGroup("functions");
    HighFive::Group f         = functions.getGroup("f");
    REQUIRE(f.getAttribute("id").read<uint64_t>() == 0);
    REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1);
    HighFive::Group g = functions.getGroup("g");
    REQUIRE(g.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0);
  }

  // Wait until every rank is done with the file before rank 0 removes it.
  parallel::barrier();
  if (parallel::rank() == 0) {
    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
  }

  // Revert to default value: the ResumingManager would otherwise keep
  // pointing to a file inside the directory that has just been removed.
  ResumingManager::getInstance().setFilename("checkpoint.h5");
#else    // PUGS_HAS_HDF5
  REQUIRE_THROWS_WITH(checkpoint(), "error: checkpoint/resume mechanism requires HDF5");
#endif   // PUGS_HAS_HDF5
}
#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_predicate.hpp>
#include <utils/pugs_config.hpp>
#ifdef PUGS_HAS_HDF5
#include <dev/ParallelChecker.hpp>
#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTExecutionStack.hpp>
#include <language/ast/ASTModulesImporter.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
#include <language/ast/ASTNodeExpressionBuilder.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>
#include <language/modules/MathModule.hpp>
#include <language/utils/ASTCheckpointsInfo.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <mesh/DualMeshType.hpp>
#include <utils/ExecutionStatManager.hpp>
#include <utils/checkpointing/DualMeshTypeHFType.hpp>
// Test-only helper. ASTCheckpointsInfo declares this class as a friend (its
// constructor is otherwise reserved to parser()), which lets a test build an
// ASTCheckpointsInfo from an AST root node and keep it alive for the lifetime
// of the tester object.
class ASTCheckpointsInfoTester
{
 private:
  // Instance kept alive as long as the tester exists.
  ASTCheckpointsInfo m_ast_checkpoint_info;

 public:
  // `explicit` prevents an ASTNode from silently converting into a tester;
  // the tester must always be constructed on purpose.
  explicit ASTCheckpointsInfoTester(const ASTNode& root_node) : m_ast_checkpoint_info(root_node) {}
  ~ASTCheckpointsInfoTester() = default;
};
// RUN_AST(data): parses the script text `data` (reported under the name
// "test.pgs"), applies the AST passes listed below in order, executes the
// resulting AST and tears everything down again.
//
// Sequence performed by the macro:
//  - create the ExecutionStatManager, ParallelChecker and
//    CheckpointResumeRepository instances;
//  - build the AST, then run the modules importer, symbol table builder,
//    data type builder, declaration-to-affectation converter, the type
//    cleaners and the expression builder;
//  - create the ASTExecutionStack and an ASTCheckpointsInfoTester built from
//    the AST root, then execute the AST;
//  - destroy the execution stack, clear the symbol table values and destroy
//    the three instances in the reverse order of their creation.
//
// The whole body is wrapped in braces, so every local (input, ast,
// exec_policy, ast_cp_info_tester) is scoped to the macro expansion.
// NOTE(review): nothing here catches exceptions, so if execution throws the
// destroy() calls are skipped -- confirm this is acceptable for the tests.
#define RUN_AST(data) \
{ \
ExecutionStatManager::create(); \
ParallelChecker::create(); \
CheckpointResumeRepository::create(); \
\
TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \
auto ast = ASTBuilder::build(input); \
\
ASTModulesImporter{*ast}; \
ASTNodeTypeCleaner<language::import_instruction>{*ast}; \
\
ASTSymbolTableBuilder{*ast}; \
ASTNodeDataTypeBuilder{*ast}; \
\
ASTNodeDeclarationToAffectationConverter{*ast}; \
ASTNodeTypeCleaner<language::var_declaration>{*ast}; \
ASTNodeTypeCleaner<language::fct_declaration>{*ast}; \
\
ASTNodeExpressionBuilder{*ast}; \
ExecutionPolicy exec_policy; \
ASTExecutionStack::create(); \
ASTCheckpointsInfoTester ast_cp_info_tester{*ast}; \
ast->execute(exec_policy); \
ASTExecutionStack::destroy(); \
ast->m_symbol_table->clearValues(); \
\
CheckpointResumeRepository::destroy(); \
ParallelChecker::destroy(); \
ExecutionStatManager::destroy(); \
}
#else // PUGS_HAS_HDF5
#include <utils/checkpointing/Checkpoint.hpp>
#endif // PUGS_HAS_HDF5
// clazy:excludeall=non-pod-global-static
// Variant of the checkpoint test whose script also builds dual meshes
// (diamond and median). The expected ids carry no parallel offset, so the
// assertions only hold for a sequential run. The script calls checkpoint()
// twice per loop iteration over three iterations (checkpoints 0 to 5); the
// HDF5 content reachable from "/resuming_checkpoint" is then checked.
// Without HDF5 support, only checks that checkpoint() reports the missing
// dependency.
TEST_CASE("checkpointing_Checkpoint_sequential", "[utils/checkpointing]")
{
#ifdef PUGS_HAS_HDF5
  std::string tmp_dirname;
  {
    {
      // Only rank 0 creates the temporary directory; its name is then
      // broadcast so that every rank targets the same checkpoint file.
      if (parallel::rank() == 0) {
        tmp_dirname = [&]() -> std::string {
          std::string dir_template = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
          // mkdtemp() returns a null pointer on failure; constructing a
          // std::string from it would be undefined behavior, so fail the
          // test explicitly instead.
          const char* created_dir = mkdtemp(&dir_template[0]);
          REQUIRE(created_dir != nullptr);
          return std::string{created_dir};
        }();
      }
      parallel::broadcast(tmp_dirname, 0);
    }
    std::filesystem::path path = tmp_dirname;
    const std::string filename = path / "checkpoint.h5";
    ResumingManager::getInstance().setFilename(filename);
  }

  std::string data = R"(
import math;
import mesh;
import scheme;
let f:R*R^3 -> R^3, (a, v) -> a * v;
let alpha:R, alpha = 3.2;
let u:R^3, u = [1,2,3];
let m:mesh, m = cartesianMesh(0, [1,1], (10,10));
let n:(N), n = (1,2,3,4);
let duals:(mesh), duals = (diamondDual(m), medianDual(m));
for(let i:N, i=0; i<3; ++i) {
checkpoint();
let g:R -> R^3, x -> [x, 1.2*x, 3];
u = f(alpha,u);
checkpoint();
}
)";

  // Mesh and connectivity ids are global counters: remember them so that the
  // expected ids can be expressed relatively and the counters restored once
  // the script has run.
  const size_t initial_mesh_id         = GlobalVariableManager::instance().getMeshId();
  const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();

  RUN_AST(data);

  GlobalVariableManager::instance().setMeshId(initial_mesh_id);
  GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);

  {   // Check checkpoint file
    std::filesystem::path path = tmp_dirname;
    const std::string filename = path / "checkpoint.h5";

    HighFive::File file(filename, HighFive::File::ReadOnly);

    // "resuming_checkpoint" is a hard link to the last written checkpoint.
    HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");
    REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5);
    REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5");

    // Innermost symbol table (loop scope): holds the loop index.
    HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table");
    REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2);

    // Enclosing symbol table: holds the script-level variables.
    HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table");
    REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2);
    REQUIRE(symbol_table1.getAttribute("n").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4});
    // u has been multiplied by alpha three times: 3.2^3 * [1,2,3].
    REQUIRE(l2Norm(symbol_table1.getAttribute("u").read<TinyVector<3>>() - TinyVector<3>{32.768, 65.536, 98.304}) ==
            Catch::Approx(0).margin(1E-12));

    // Embedded values: the tuple of dual meshes (diamond first, then median,
    // as listed in the script) and the primal mesh.
    HighFive::Group embedded1 = symbol_table1.getGroup("embedded");
    HighFive::Group duals     = embedded1.getGroup("duals");
    REQUIRE(duals.getAttribute("type").read<std::string>() == "(mesh)");
    HighFive::Group duals_0 = duals.getGroup("0");
    REQUIRE(duals_0.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_0.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
    HighFive::Group duals_1 = duals.getGroup("1");
    REQUIRE(duals_1.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);
    HighFive::Group m = embedded1.getGroup("m");
    REQUIRE(m.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(m.getAttribute("id").read<uint64_t>() == initial_mesh_id);

    // Three meshes and three connectivities were created by the script, so
    // both global counters are expected to have advanced by 3.
    HighFive::Group singleton        = checkpoint.getGroup("singleton");
    HighFive::Group global_variables = singleton.getGroup("global_variables");
    REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == initial_connectivity_id + 3);
    REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == initial_mesh_id + 3);
    HighFive::Group execution_info = singleton.getGroup("execution_info");
    REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1);

    HighFive::Group connectivity  = checkpoint.getGroup("connectivity");
    HighFive::Group connectivity0 = connectivity.getGroup(std::to_string(initial_connectivity_id));
    REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured");
    HighFive::Group connectivity1 = connectivity.getGroup(std::to_string(initial_connectivity_id + 1));
    REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 1);
    REQUIRE(connectivity1.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);
    HighFive::Group connectivity2 = connectivity.getGroup(std::to_string(initial_connectivity_id + 2));
    REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 2);
    REQUIRE(connectivity2.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    HighFive::Group mesh  = checkpoint.getGroup("mesh");
    HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id));
    REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal");
    HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + 1));
    REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
    REQUIRE(mesh1.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh1.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);
    HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + 2));
    REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);
    REQUIRE(mesh2.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh2.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    HighFive::Group functions = checkpoint.getGroup("functions");
    HighFive::Group f         = functions.getGroup("f");
    REQUIRE(f.getAttribute("id").read<uint64_t>() == 0);
    REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1);
    HighFive::Group g = functions.getGroup("g");
    REQUIRE(g.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0);
  }

  // Wait until every rank is done with the file before rank 0 removes it.
  parallel::barrier();
  if (parallel::rank() == 0) {
    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
  }

  // Revert to default value
  ResumingManager::getInstance().setFilename("checkpoint.h5");
#else    // PUGS_HAS_HDF5
  REQUIRE_THROWS_WITH(checkpoint(), "error: checkpoint/resume mechanism requires HDF5");
#endif   // PUGS_HAS_HDF5
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment