From 491683852667c49b4c111aacac6ffab8c72e8d68 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 24 May 2024 00:40:08 +0200
Subject: [PATCH] Add checkpoint/resume handling of parallel checker

---
 src/dev/ParallelChecker.hpp                   | 19 ++++++++++++++++++-
 src/utils/checkpointing/Checkpoint.cpp        |  7 +++++++
 .../checkpointing/ParallelCheckerHFType.hpp   | 16 ++++++++++++++++
 src/utils/checkpointing/Resume.cpp            |  7 +++++++
 4 files changed, 48 insertions(+), 1 deletion(-)
 create mode 100644 src/utils/checkpointing/ParallelCheckerHFType.hpp

diff --git a/src/dev/ParallelChecker.hpp b/src/dev/ParallelChecker.hpp
index 08b6de991..e37729292 100644
--- a/src/dev/ParallelChecker.hpp
+++ b/src/dev/ParallelChecker.hpp
@@ -13,6 +13,7 @@
 #include <utils/Filesystem.hpp>
 #include <utils/Messenger.hpp>
 #include <utils/SourceLocation.hpp>
+#include <utils/checkpointing/ResumingManager.hpp>
 
 #include <fstream>
 
@@ -1248,6 +1249,22 @@ class ParallelChecker
     return *m_instance;
   }
 
+  size_t
+  tag() const
+  {
+    return m_tag;
+  }
+
+  void
+  setTag(size_t tag)
+  {
+    if (ResumingManager::getInstance().isResuming()) {
+      m_tag = tag;
+    } else {
+      throw UnexpectedError("Cannot modify parallel checker tag if not resuming");
+    }
+  }
+
   Mode
   mode() const
   {
@@ -1257,7 +1274,7 @@ class ParallelChecker
   void
   setMode(const Mode& mode)
   {
-    if (m_tag != 0) {
+    if ((m_tag != 0) and not ResumingManager::getInstance().isResuming()) {
       throw UnexpectedError("Cannot modify parallel checker mode if it was already used");
     }
 
diff --git a/src/utils/checkpointing/Checkpoint.cpp b/src/utils/checkpointing/Checkpoint.cpp
index 245871c0d..70166a861 100644
--- a/src/utils/checkpointing/Checkpoint.cpp
+++ b/src/utils/checkpointing/Checkpoint.cpp
@@ -21,6 +21,7 @@
 #ifdef PUGS_HAS_HDF5
 
 #include <algebra/LinearSolverOptions.hpp>
+#include <dev/ParallelChecker.hpp>
 #include <language/utils/ASTNodeDataTypeTraits.hpp>
 #include <language/utils/CheckpointResumeRepository.hpp>
 #include <language/utils/DataHandler.hpp>
@@ -29,6 +30,7 @@
 #include <utils/RandomEngine.hpp>
 
 #include <utils/checkpointing/LinearSolverOptionsHFType.hpp>
+#include <utils/checkpointing/ParallelCheckerHFType.hpp>
 
 void
 checkpoint()
@@ -83,6 +85,11 @@ checkpoint()
       execution_info_group.createAttribute("cumulative_total_cpu_time",
                                            ExecutionStatManager::getInstance().getCumulativeTotalCPUTime());
     }
+    {
+      HighFive::Group parallel_checker_group = checkpoint.createGroup("singleton/parallel_checker");
+      parallel_checker_group.createAttribute("tag", ParallelChecker::instance().tag());
+      parallel_checker_group.createAttribute("mode", ParallelChecker::instance().mode());
+    }
     {
       HighFive::Group linear_solver_options_default_group =
         checkpoint.createGroup("singleton/linear_solver_options_default");
diff --git a/src/utils/checkpointing/ParallelCheckerHFType.hpp b/src/utils/checkpointing/ParallelCheckerHFType.hpp
new file mode 100644
index 000000000..27cd6412b
--- /dev/null
+++ b/src/utils/checkpointing/ParallelCheckerHFType.hpp
@@ -0,0 +1,16 @@
+#ifndef PARALLEL_CHECKER_HF_TYPE_HPP
+#define PARALLEL_CHECKER_HF_TYPE_HPP
+
+#include <dev/ParallelChecker.hpp>
+#include <utils/checkpointing/CheckpointUtils.hpp>
+
+HighFive::EnumType<ParallelChecker::Mode> PUGS_INLINE
+create_enum_ParallelChecker_mode()
+{
+  return {{"automatic", ParallelChecker::Mode::automatic},
+          {"read", ParallelChecker::Mode::read},
+          {"write", ParallelChecker::Mode::write}};
+}
+HIGHFIVE_REGISTER_TYPE(ParallelChecker::Mode, create_enum_ParallelChecker_mode)
+
+#endif   // PARALLEL_CHECKER_HF_TYPE_HPP
diff --git a/src/utils/checkpointing/Resume.cpp b/src/utils/checkpointing/Resume.cpp
index 27f8e3db1..757458fc4 100644
--- a/src/utils/checkpointing/Resume.cpp
+++ b/src/utils/checkpointing/Resume.cpp
@@ -26,6 +26,7 @@
 #include <utils/checkpointing/ResumingManager.hpp>
 
 #include <utils/checkpointing/LinearSolverOptionsHFType.hpp>
+#include <utils/checkpointing/ParallelCheckerHFType.hpp>
 
 #include <language/utils/CheckpointResumeRepository.hpp>
 
@@ -69,6 +70,12 @@ resume()
       ExecutionStatManager::getInstance().setPreviousCumulativeElapseTime(cumulative_elapse_time);
       ExecutionStatManager::getInstance().setPreviousCumulativeTotalCPUTime(cumulative_total_cpu_time);
     }
+    {
+      HighFive::Group random_seed_group = checkpoint.getGroup("singleton/parallel_checker");
+      // Ordering is important! Must set mode before changing the tag (changing mode is not allowed if tag!=0)
+      ParallelChecker::instance().setMode(random_seed_group.getAttribute("mode").read<ParallelChecker::Mode>());
+      ParallelChecker::instance().setTag(random_seed_group.getAttribute("tag").read<size_t>());
+    }
     {
       HighFive::Group linear_solver_options_default_group =
         checkpoint.getGroup("singleton/linear_solver_options_default");
-- 
GitLab