Select Git revision
SchemeModule.cpp
-
Stéphane Del Pino authored
Now modules can/must register the operators they provide
Stéphane Del Pino authored: Now modules can/must register the operators they provide
test_checkpointing_Checkpoint_sequential.cpp 10.39 KiB
#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_predicate.hpp>
#include <utils/pugs_config.hpp>
#ifdef PUGS_HAS_HDF5
#include <dev/ParallelChecker.hpp>
#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTExecutionStack.hpp>
#include <language/ast/ASTModulesImporter.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
#include <language/ast/ASTNodeExpressionBuilder.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>
#include <language/modules/MathModule.hpp>
#include <language/utils/ASTCheckpointsInfo.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <mesh/DualMeshType.hpp>
#include <utils/ExecutionStatManager.hpp>
#include <utils/checkpointing/DualMeshTypeHFType.hpp>
// Test-only holder: builds an ASTCheckpointsInfo from the AST root node and
// keeps it alive for as long as the tester object itself is in scope.
//
// NOTE(review): presumably this wrapper exists because ASTCheckpointsInfo
// cannot be constructed directly from test code (e.g. it befriends a class of
// this exact name) -- confirm against ASTCheckpointsInfo.hpp before renaming.
class ASTCheckpointsInfoTester
{
 private:
  // Checkpoint information built from the root node given to the constructor.
  ASTCheckpointsInfo m_ast_checkpoint_info;

 public:
  // `explicit` prevents an ASTNode from being silently converted into a
  // tester; construction must always be spelled out, as RUN_AST does.
  explicit ASTCheckpointsInfoTester(const ASTNode& root_node) : m_ast_checkpoint_info(root_node) {}

  ~ASTCheckpointsInfoTester() = default;
};
// RUN_AST(data): parses the pugs script held in `data` and executes it.
//
// Steps, in the order the macro performs them:
//   1. calls create() on ExecutionStatManager, ParallelChecker and
//      CheckpointResumeRepository;
//   2. wraps `data` in a PEGTL string_input labelled "test.pgs" and builds the
//      AST with ASTBuilder::build;
//   3. runs the AST passes: modules import, removal of import instructions,
//      symbol table construction, data type construction, conversion of
//      declarations to affectations, removal of variable/function declaration
//      nodes, and expression building;
//   4. creates the ASTExecutionStack, constructs an ASTCheckpointsInfoTester on
//      the AST root, and executes the AST with a default-constructed
//      ExecutionPolicy;
//   5. destroys the ASTExecutionStack, clears the symbol table values, then
//      calls destroy() on the three objects of step 1 in reverse creation order.
//
// The expansion is a bare brace-enclosed block, so every local (`input`, `ast`,
// `exec_policy`, `ast_cp_info_tester`) is scoped to the macro invocation.
//
// NOTE(review): clearValues() is invoked twice on the symbol table (once after
// the execution stack is destroyed and once at the very end); the second call
// looks redundant -- confirm before removing it.
//
// Do not add `//` comments inside the macro body: with the trailing
// backslashes they would comment out the following lines as well.
#define RUN_AST(data) \
{ \
ExecutionStatManager::create(); \
ParallelChecker::create(); \
CheckpointResumeRepository::create(); \
\
TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \
auto ast = ASTBuilder::build(input); \
\
ASTModulesImporter{*ast}; \
ASTNodeTypeCleaner<language::import_instruction>{*ast}; \
\
ASTSymbolTableBuilder{*ast}; \
ASTNodeDataTypeBuilder{*ast}; \
\
ASTNodeDeclarationToAffectationConverter{*ast}; \
ASTNodeTypeCleaner<language::var_declaration>{*ast}; \
ASTNodeTypeCleaner<language::fct_declaration>{*ast}; \
\
ASTNodeExpressionBuilder{*ast}; \
ExecutionPolicy exec_policy; \
ASTExecutionStack::create(); \
ASTCheckpointsInfoTester ast_cp_info_tester{*ast}; \
ast->execute(exec_policy); \
ASTExecutionStack::destroy(); \
ast->m_symbol_table->clearValues(); \
\
CheckpointResumeRepository::destroy(); \
ParallelChecker::destroy(); \
ExecutionStatManager::destroy(); \
ast->m_symbol_table->clearValues(); \
}
#else // PUGS_HAS_HDF5
#include <utils/checkpointing/Checkpoint.hpp>
#endif // PUGS_HAS_HDF5
// clazy:excludeall=non-pod-global-static
// Executes a small pugs script that calls checkpoint() inside a loop, then
// opens the resulting HDF5 file and checks the content of the
// "/resuming_checkpoint" group (symbol tables, embedded meshes, singletons,
// connectivities, meshes and functions).
//
// Without HDF5 support the test only checks that checkpoint() throws the
// expected error message.
TEST_CASE("checkpointing_Checkpoint_sequential", "[utils/checkpointing]")
{
#ifdef PUGS_HAS_HDF5
  // Start from a freshly created ResumingManager.
  ResumingManager::destroy();
  ResumingManager::create();

  // Rank 0 creates a unique temporary directory; its name is then broadcast so
  // that every rank refers to the same checkpoint file.
  std::string tmp_dirname;
  if (parallel::rank() == 0) {
    std::string dirname_template = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";

    // mkdtemp returns a null pointer on failure; building a std::string from it
    // would be undefined behaviour, so fail the test with a clear message.
    const char* const created_dirname = mkdtemp(dirname_template.data());
    if (created_dirname == nullptr) {
      FAIL("unable to create a temporary directory from template " << dirname_template);
    }
    tmp_dirname = created_dirname;
  }
  parallel::broadcast(tmp_dirname, 0);

  // Single definition of the checkpoint file path, used both to configure the
  // ResumingManager and to reopen the file for the checks below.
  const std::string checkpoint_filename = std::filesystem::path{tmp_dirname} / "checkpoint.h5";
  ResumingManager::getInstance().setFilename(checkpoint_filename);

  // The loop runs three times with two checkpoint() calls per iteration.
  std::string data = R"(
import math;
import mesh;
import scheme;
let f:R*R^3 -> R^3, (a, v) -> a * v;
let alpha:R, alpha = 3.2;
let u:R^3, u = [1,2,3];
let m:mesh, m = cartesianMesh(0, [1,1], (10,10));
let n:(N), n = (1,2,3,4);
let duals:(mesh), duals = (diamondDual(m), medianDual(m));
for(let i:N, i=0; i<3; ++i) {
checkpoint();
let g:R -> R^3, x -> [x, 1.2*x, 3];
u = f(alpha,u);
checkpoint();
}
)";

  // Remember the id counters so that the expected ids can be expressed relative
  // to them and so that they can be restored once the script has run.
  const size_t initial_mesh_id         = GlobalVariableManager::instance().getMeshId();
  const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();

  RUN_AST(data);

  GlobalVariableManager::instance().setMeshId(initial_mesh_id);
  GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);

  {   // Check checkpoint file
    // `file` is scoped to this block so that it is released before the
    // temporary directory is removed below.
    HighFive::File file(checkpoint_filename, HighFive::File::ReadOnly);

    HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");

    REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1);
    // Expected to be the last of the six checkpoint() calls (numbers 0 to 5).
    REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5);
    REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5");

    // Innermost symbol table: the loop scope, captured during the last iteration.
    HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table");
    REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2);

    // Enclosing symbol table: the top-level variables of the script.
    HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table");
    REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2);
    REQUIRE(symbol_table1.getAttribute("n").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4});
    // u has been multiplied by alpha three times: 3.2^3 * [1,2,3].
    REQUIRE(l2Norm(symbol_table1.getAttribute("u").read<TinyVector<3>>() - TinyVector<3>{32.768, 65.536, 98.304}) ==
            Catch::Approx(0).margin(1E-12));

    // Embedded (non-basic) values of the top-level scope: the primal mesh gets
    // the initial id, the two dual meshes the next two ids.
    HighFive::Group embedded1 = symbol_table1.getGroup("embedded");

    HighFive::Group duals = embedded1.getGroup("duals");
    REQUIRE(duals.getAttribute("type").read<std::string>() == "(mesh)");

    HighFive::Group duals_0 = duals.getGroup("0");
    REQUIRE(duals_0.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_0.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);

    HighFive::Group duals_1 = duals.getGroup("1");
    REQUIRE(duals_1.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);

    HighFive::Group m = embedded1.getGroup("m");
    REQUIRE(m.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(m.getAttribute("id").read<uint64_t>() == initial_mesh_id);

    // Singletons: three meshes/connectivities were created, so the stored
    // counters are the initial values plus three.
    HighFive::Group singleton = checkpoint.getGroup("singleton");

    HighFive::Group global_variables = singleton.getGroup("global_variables");
    REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == initial_connectivity_id + 3);
    REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == initial_mesh_id + 3);

    HighFive::Group execution_info = singleton.getGroup("execution_info");
    REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1);

    // Connectivities: one primal and its diamond and median duals.
    HighFive::Group connectivity = checkpoint.getGroup("connectivity");

    HighFive::Group connectivity0 = connectivity.getGroup(std::to_string(initial_connectivity_id));
    REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured");

    HighFive::Group connectivity1 = connectivity.getGroup(std::to_string(initial_connectivity_id + 1));
    REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 1);
    REQUIRE(connectivity1.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);

    HighFive::Group connectivity2 = connectivity.getGroup(std::to_string(initial_connectivity_id + 2));
    REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 2);
    REQUIRE(connectivity2.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    // Meshes: same layout as the connectivities.
    HighFive::Group mesh = checkpoint.getGroup("mesh");

    HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id));
    REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal");

    HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + 1));
    REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
    REQUIRE(mesh1.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh1.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);

    HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + 2));
    REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);
    REQUIRE(mesh2.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh2.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    // User functions declared by the script.
    HighFive::Group functions = checkpoint.getGroup("functions");

    HighFive::Group f = functions.getGroup("f");
    REQUIRE(f.getAttribute("id").read<uint64_t>() == 0);
    REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1);

    HighFive::Group g = functions.getGroup("g");
    REQUIRE(g.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0);
  }

  // Wait for every rank to be done with the file before rank 0 deletes it.
  parallel::barrier();
  if (parallel::rank() == 0) {
    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
  }

  // Revert to default value
  ResumingManager::getInstance().setFilename("checkpoint.h5");
#else    // PUGS_HAS_HDF5
  REQUIRE_THROWS_WITH(checkpoint(), "error: checkpoint/resume mechanism requires HDF5");
#endif   // PUGS_HAS_HDF5
}