Skip to content
Snippets Groups Projects
Select Git revision
  • f63bc3456a099c763fd830cee6c3cb6a434276b4
  • develop default protected
  • origin/stage/bouguettaia
  • feature/kinetic-schemes
  • feature/reconstruction
  • feature/local-dt-fsi
  • feature/composite-scheme-sources
  • feature/composite-scheme-other-fluxes
  • feature/serraille
  • feature/variational-hydro
  • feature/composite-scheme
  • hyperplastic
  • feature/polynomials
  • feature/gks
  • feature/implicit-solver-o2
  • feature/coupling_module
  • feature/implicit-solver
  • feature/merge-local-dt-fsi
  • master protected
  • feature/escobar-smoother
  • feature/hypoelasticity-clean
  • v0.5.0 protected
  • v0.4.1 protected
  • v0.4.0 protected
  • v0.3.0 protected
  • v0.2.0 protected
  • v0.1.0 protected
  • Kidder
  • v0.0.4 protected
  • v0.0.3 protected
  • v0.0.2 protected
  • v0 protected
  • v0.0.1 protected
33 results

test_checkpointing_Checkpoint_sequential.cpp

Blame
  • test_checkpointing_Checkpoint_sequential.cpp 10.39 KiB
    #include <catch2/catch_approx.hpp>
    #include <catch2/catch_test_macros.hpp>
    #include <catch2/matchers/catch_matchers_predicate.hpp>
    
    #include <utils/pugs_config.hpp>
    
    #ifdef PUGS_HAS_HDF5
    
    #include <dev/ParallelChecker.hpp>
    #include <language/ast/ASTBuilder.hpp>
    #include <language/ast/ASTExecutionStack.hpp>
    #include <language/ast/ASTModulesImporter.hpp>
    #include <language/ast/ASTNodeDataTypeBuilder.hpp>
    #include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
    #include <language/ast/ASTNodeExpressionBuilder.hpp>
    #include <language/ast/ASTNodeTypeCleaner.hpp>
    #include <language/ast/ASTSymbolTableBuilder.hpp>
    #include <language/modules/MathModule.hpp>
    #include <language/utils/ASTCheckpointsInfo.hpp>
    #include <language/utils/CheckpointResumeRepository.hpp>
    #include <mesh/DualMeshType.hpp>
    #include <utils/ExecutionStatManager.hpp>
    #include <utils/checkpointing/DualMeshTypeHFType.hpp>
    
    // Test helper: owns an ASTCheckpointsInfo built from the given AST root for as long
    // as the tester object lives (RUN_AST keeps one alive while the AST is executed).
    class ASTCheckpointsInfoTester
    {
     private:
      ASTCheckpointsInfo m_ast_checkpoint_info;
    
     public:
      // explicit: an ASTNode must not silently convert into a tester object.
      explicit ASTCheckpointsInfoTester(const ASTNode& root_node) : m_ast_checkpoint_info(root_node) {}
      ~ASTCheckpointsInfoTester() = default;
    };
    
    // Runs a pugs script end-to-end inside the test:
    //  - creates the ExecutionStatManager, ParallelChecker and CheckpointResumeRepository
    //    singletons,
    //  - parses `data`, imports modules, builds symbol tables / data types / expression
    //    nodes, and executes the AST while an ASTCheckpointsInfo (via
    //    ASTCheckpointsInfoTester) and the ASTExecutionStack are alive,
    //  - clears the symbol-table values (presumably so held values are released while
    //    the singletons still exist -- NOTE(review): confirm), then destroys the
    //    singletons in reverse order of creation.
    // The do { } while (0) wrapper makes `RUN_AST(x);` expand to exactly one statement.
    #define RUN_AST(data)                                          \
      do {                                                         \
        ExecutionStatManager::create();                            \
        ParallelChecker::create();                                 \
        CheckpointResumeRepository::create();                      \
                                                                   \
        TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \
        auto ast = ASTBuilder::build(input);                       \
                                                                   \
        ASTModulesImporter{*ast};                                  \
        ASTNodeTypeCleaner<language::import_instruction>{*ast};    \
                                                                   \
        ASTSymbolTableBuilder{*ast};                               \
        ASTNodeDataTypeBuilder{*ast};                              \
                                                                   \
        ASTNodeDeclarationToAffectationConverter{*ast};            \
        ASTNodeTypeCleaner<language::var_declaration>{*ast};       \
        ASTNodeTypeCleaner<language::fct_declaration>{*ast};       \
                                                                   \
        ASTNodeExpressionBuilder{*ast};                            \
        ExecutionPolicy exec_policy;                               \
        ASTExecutionStack::create();                               \
        ASTCheckpointsInfoTester ast_cp_info_tester{*ast};         \
        ast->execute(exec_policy);                                 \
        ASTExecutionStack::destroy();                              \
        ast->m_symbol_table->clearValues();                        \
                                                                   \
        CheckpointResumeRepository::destroy();                     \
        ParallelChecker::destroy();                                \
        ExecutionStatManager::destroy();                           \
      } while (0)
    
    #else   // PUGS_HAS_HDF5
    
    #include <utils/checkpointing/Checkpoint.hpp>
    
    #endif   // PUGS_HAS_HDF5
    
    // clazy:excludeall=non-pod-global-static
    
    // Runs a script that calls checkpoint() twice per loop iteration (3 iterations), then
    // re-opens the HDF5 checkpoint file and checks the contents of the resuming checkpoint
    // (checkpoint_number 5): symbol tables, embedded meshes, global variables, connectivities,
    // meshes and functions. Without HDF5, checkpoint() is expected to throw.
    TEST_CASE("checkpointing_Checkpoint_sequential", "[utils/checkpointing]")
    {
    #ifdef PUGS_HAS_HDF5
    
      ResumingManager::destroy();
      ResumingManager::create();
    
      // Create a unique scratch directory on rank 0 and share its name with every rank.
      // mkdtemp() returns nullptr on failure and building a std::string from nullptr is
      // undefined behavior, so the name is left empty in that case and checked after the
      // broadcast on all ranks (no rank is left waiting in a collective call).
      std::string tmp_dirname;
      if (parallel::rank() == 0) {
        std::string dir_template = (std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX").string();
        if (const char* created_dir = mkdtemp(dir_template.data()); created_dir != nullptr) {
          tmp_dirname = created_dir;
        }
      }
      parallel::broadcast(tmp_dirname, 0);
      REQUIRE(not tmp_dirname.empty());
    
      // Checkpoint file written by checkpoint() and re-read below for verification.
      const std::string filename = (std::filesystem::path{tmp_dirname} / "checkpoint.h5").string();
      ResumingManager::getInstance().setFilename(filename);
    
      std::string data = R"(
    import math;
    import mesh;
    import scheme;
    
    let f:R*R^3 -> R^3, (a, v) -> a * v;
    
    let alpha:R, alpha = 3.2;
    let u:R^3, u = [1,2,3];
    
    let m:mesh, m = cartesianMesh(0, [1,1], (10,10));
    
    let n:(N), n = (1,2,3,4);
    
    let duals:(mesh), duals = (diamondDual(m), medianDual(m));
    
    for(let i:N, i=0; i<3; ++i) {
      checkpoint();
    
      let g:R -> R^3, x -> [x, 1.2*x, 3];
      u = f(alpha,u);
    
      checkpoint();
    }
    )";
    
      // Mesh/connectivity ids are global counters: remember them so the expected ids can
      // be expressed relative to them, and restore them after the run.
      const size_t initial_mesh_id         = GlobalVariableManager::instance().getMeshId();
      const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();
    
      RUN_AST(data);
    
      GlobalVariableManager::instance().setMeshId(initial_mesh_id);
      GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);
    
      {   // Check checkpoint file
        HighFive::File file(filename, HighFive::File::ReadOnly);
    
        HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");
    
        REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1);
        REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5);
        REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5");
    
        HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table");
        REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2);
    
        HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table");
        REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2);
        REQUIRE(symbol_table1.getAttribute("n").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4});
        REQUIRE(l2Norm(symbol_table1.getAttribute("u").read<TinyVector<3>>() - TinyVector<3>{32.768, 65.536, 98.304}) ==
                Catch::Approx(0).margin(1E-12));
    
        HighFive::Group embedded1 = symbol_table1.getGroup("embedded");
    
        HighFive::Group duals = embedded1.getGroup("duals");
        REQUIRE(duals.getAttribute("type").read<std::string>() == "(mesh)");
        HighFive::Group duals_0 = duals.getGroup("0");
        REQUIRE(duals_0.getAttribute("type").read<std::string>() == "mesh");
        REQUIRE(duals_0.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
        HighFive::Group duals_1 = duals.getGroup("1");
        REQUIRE(duals_1.getAttribute("type").read<std::string>() == "mesh");
        REQUIRE(duals_1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);
    
        HighFive::Group m = embedded1.getGroup("m");
        REQUIRE(m.getAttribute("type").read<std::string>() == "mesh");
        REQUIRE(m.getAttribute("id").read<uint64_t>() == initial_mesh_id);
    
        HighFive::Group singleton        = checkpoint.getGroup("singleton");
        HighFive::Group global_variables = singleton.getGroup("global_variables");
        REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == initial_connectivity_id + 3);
        REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == initial_mesh_id + 3);
        HighFive::Group execution_info = singleton.getGroup("execution_info");
        REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1);
    
        HighFive::Group connectivity  = checkpoint.getGroup("connectivity");
        HighFive::Group connectivity0 = connectivity.getGroup(std::to_string(initial_connectivity_id));
        REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2);
        REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id);
        REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured");
    
        HighFive::Group connectivity1 = connectivity.getGroup(std::to_string(initial_connectivity_id + 1));
        REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 1);
        REQUIRE(connectivity1.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
        REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "dual_connectivity");
        REQUIRE(connectivity1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);
    
        HighFive::Group connectivity2 = connectivity.getGroup(std::to_string(initial_connectivity_id + 2));
        REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 2);
        REQUIRE(connectivity2.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
        REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "dual_connectivity");
        REQUIRE(connectivity2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);
    
        HighFive::Group mesh  = checkpoint.getGroup("mesh");
        HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id));
        REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id);
        REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2);
        REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id);
        REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal");
    
        HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + 1));
        REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
        REQUIRE(mesh1.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
        REQUIRE(mesh1.getAttribute("type").read<std::string>() == "dual_mesh");
        REQUIRE(mesh1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);
    
        HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + 2));
        REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);
        REQUIRE(mesh2.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
        REQUIRE(mesh2.getAttribute("type").read<std::string>() == "dual_mesh");
        REQUIRE(mesh2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);
    
        HighFive::Group functions = checkpoint.getGroup("functions");
        HighFive::Group f         = functions.getGroup("f");
        REQUIRE(f.getAttribute("id").read<uint64_t>() == 0);
        REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1);
        HighFive::Group g = functions.getGroup("g");
        REQUIRE(g.getAttribute("id").read<uint64_t>() == 1);
        REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0);
      }
    
      // Wait for every rank to finish reading before rank 0 removes the scratch directory.
      parallel::barrier();
      if (parallel::rank() == 0) {
        std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
      }
    
      // Revert to default value
      ResumingManager::getInstance().setFilename("checkpoint.h5");
    
    #else    // PUGS_HAS_HDF5
      REQUIRE_THROWS_WITH(checkpoint(), "error: checkpoint/resume mechanism requires HDF5");
    #endif   // PUGS_HAS_HDF5
    }