From 198e751a26abbeff812e19beaec057cecf03c1e8 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Thu, 17 Oct 2024 00:55:34 +0200
Subject: [PATCH] Add tests for SetResumeFrom checkpointing

---
 src/utils/checkpointing/SetResumeFrom.cpp  |  11 +-
 src/utils/checkpointing/SetResumeFrom.hpp  |   3 +-
 tests/CMakeLists.txt                       |   1 +
 tests/test_checkpointing_SetResumeFrom.cpp | 112 +++++++++++++++++++++
 4 files changed, 122 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_checkpointing_SetResumeFrom.cpp

diff --git a/src/utils/checkpointing/SetResumeFrom.cpp b/src/utils/checkpointing/SetResumeFrom.cpp
index 2d1f5e5ca..81502da2f 100644
--- a/src/utils/checkpointing/SetResumeFrom.cpp
+++ b/src/utils/checkpointing/SetResumeFrom.cpp
@@ -11,9 +11,10 @@
 #include <utils/HighFivePugsUtils.hpp>
 
 void
-setResumeFrom(const std::string& filename, const uint64_t& checkpoint_number)
+setResumeFrom(const std::string& filename, const uint64_t& checkpoint_number, std::ostream& os)
 {
   try {
+    HighFive::SilenceHDF5 m_silence_hdf5{true};
     HighFive::File file(filename, HighFive::File::ReadWrite);
     const std::string checkpoint_name = "checkpoint_" + std::to_string(checkpoint_number);
 
@@ -29,18 +30,20 @@ setResumeFrom(const std::string& filename, const uint64_t& checkpoint_number)
       file.unlink("resuming_checkpoint");
     }
     file.createHardLink("resuming_checkpoint", checkpoint);
-    std::cout << "Resuming checkpoint " << rang::style::bold << "successfully" << rang::style::reset << " set to "
-              << rang::fgB::yellow << checkpoint_number << rang::fg::reset << '\n';
+    os << "Resuming checkpoint " << rang::style::bold << "successfully" << rang::style::reset << " set to "
+       << rang::fgB::yellow << checkpoint_number << rang::fg::reset << '\n';
   }
+  // LCOV_EXCL_START
   catch (HighFive::Exception& e) {
     throw NormalError(e.what());
   }
+  // LCOV_EXCL_STOP
 }
 
 #else   // PUGS_HAS_HDF5
 
 void
-setResumeFrom(const std::string&, const uint64_t&)
+setResumeFrom(const std::string&, const uint64_t&, std::ostream&)
 {
   std::cerr << rang::fgB::red << "error: " << rang::fg::reset << "setting resuming checkpoint requires HDF5\n";
 }
diff --git a/src/utils/checkpointing/SetResumeFrom.hpp b/src/utils/checkpointing/SetResumeFrom.hpp
index c8e44e6e4..842c63893 100644
--- a/src/utils/checkpointing/SetResumeFrom.hpp
+++ b/src/utils/checkpointing/SetResumeFrom.hpp
@@ -2,8 +2,9 @@
 #define SET_RESUME_FROM_HPP
 
 #include <cstdint>
+#include <iostream>
 #include <string>
 
-void setResumeFrom(const std::string& filename, const uint64_t& checkpoint_number);
+void setResumeFrom(const std::string& filename, const uint64_t& checkpoint_number, std::ostream& os = std::cout);
 
 #endif   // SET_RESUME_FROM_HPP
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index cc7c59a0c..1015ec895 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -160,6 +160,7 @@ add_executable (unit_tests
   set(checkpointing_TESTS
     test_checkpointing_PrintScriptFrom.cpp
     test_checkpointing_ResumingUtils.cpp
+    test_checkpointing_SetResumeFrom.cpp
   )
 
 if(PUGS_HAS_HDF5)
diff --git a/tests/test_checkpointing_SetResumeFrom.cpp b/tests/test_checkpointing_SetResumeFrom.cpp
new file mode 100644
index 000000000..eb945250f
--- /dev/null
+++ b/tests/test_checkpointing_SetResumeFrom.cpp
@@ -0,0 +1,112 @@
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_all.hpp>
+
+#include <utils/HighFivePugsUtils.hpp>
+#include <utils/Messenger.hpp>
+#include <utils/checkpointing/SetResumeFrom.hpp>
+
+#include <filesystem>
+
+// clazy:excludeall=non-pod-global-static
+
+TEST_CASE("SetResumeFrom", "[utils/checkpointing]")
+{
+#ifdef PUGS_HAS_HDF5
+
+  std::string tmp_dirname;
+  {
+    {
+      if (parallel::rank() == 0) {
+        tmp_dirname = [&]() -> std::string {
+          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
+          return std::string{mkdtemp(&temp_filename[0])};
+        }();
+      }
+      parallel::broadcast(tmp_dirname, 0);
+    }
+    std::filesystem::path path = tmp_dirname;
+    const std::string filename = path / "checkpoint.h5";
+
+    const std::string data_file0 = R"(Un tiens vaut mieux que deux tu l'auras,
+Un tiens vaut mieux que deux tu l'auras,...)";
+    const std::string data_file1 = R"(All work and no play makes Jack a dull boy,
+All work and no play makes Jack a dull boy,...)";
+    const std::string data_file2 = R"(solo trabajo y nada de juego hacen de Jack un chico aburrido,
+solo trabajo y nada de juego hacen de Jack un chico aburrido,...)";
+
+    {
+      HighFive::FileAccessProps fapl;
+      fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+      fapl.add(HighFive::MPIOCollectiveMetadata{});
+      HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl);
+
+      file.createGroup("/checkpoint_0").createAttribute("data.pgs", data_file0);
+      file.createGroup("/checkpoint_1").createAttribute("data.pgs", data_file1);
+      file.createGroup("/checkpoint_2").createAttribute("data.pgs", data_file2);
+    }
+
+    {
+      std::ostringstream os;
+      setResumeFrom(filename, 0, os);
+      REQUIRE(os.str() == "Resuming checkpoint successfully set to 0\n");
+
+      HighFive::FileAccessProps fapl;
+      fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+      fapl.add(HighFive::MPIOCollectiveMetadata{});
+      HighFive::File file = HighFive::File(filename, HighFive::File::ReadOnly, fapl);
+      REQUIRE(file.getGroup("/resuming_checkpoint").getAttribute("data.pgs").read<std::string>() == data_file0);
+    }
+
+    {
+      std::ostringstream os;
+      setResumeFrom(filename, 1, os);
+      REQUIRE(os.str() == "Resuming checkpoint successfully set to 1\n");
+
+      HighFive::FileAccessProps fapl;
+      fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+      fapl.add(HighFive::MPIOCollectiveMetadata{});
+      HighFive::File file = HighFive::File(filename, HighFive::File::ReadOnly, fapl);
+      REQUIRE(file.getGroup("/resuming_checkpoint").getAttribute("data.pgs").read<std::string>() == data_file1);
+    }
+
+    {
+      std::ostringstream os;
+      setResumeFrom(filename, 2, os);
+      REQUIRE(os.str() == "Resuming checkpoint successfully set to 2\n");
+
+      HighFive::FileAccessProps fapl;
+      fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+      fapl.add(HighFive::MPIOCollectiveMetadata{});
+      HighFive::File file = HighFive::File(filename, HighFive::File::ReadOnly, fapl);
+      REQUIRE(file.getGroup("/resuming_checkpoint").getAttribute("data.pgs").read<std::string>() == data_file2);
+    }
+
+    {
+      std::ostringstream error_msg;
+      error_msg << "error: cannot find checkpoint " << 12 << " in " << filename;
+      REQUIRE_THROWS_WITH(setResumeFrom(filename, 12), error_msg.str());
+    }
+  }
+
+  parallel::barrier();
+  if (parallel::rank() == 0) {
+    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
+  }
+
+#else   // PUGS_HAS_HDF5
+
+  if (parallel::rank() == 0) {
+    std::cerr.setstate(std::ios::badbit);
+  }
+
+  std::ostringstream os;
+  REQUIRE_NOTHROW(setResumeFrom("foo.h5", 0, os));
+
+  if (parallel::rank() == 0) {
+    std::cerr.clear();
+  }
+
+  REQUIRE(os.str() == "");
+
+#endif   // PUGS_HAS_HDF5
+}
-- 
GitLab