From d9165a03c9ef0fe7d52f17f6a33e99342f2429a7 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Tue, 16 Apr 2024 08:49:09 +0200
Subject: [PATCH] [ci-skip] Add checkpoint/resume for
 IDiscreteFunctionDescriptor

---
 src/language/modules/SchemeModule.cpp         |  16 +-
 src/utils/checkpointing/CheckpointUtils.cpp   | 166 ++++++++++--------
 src/utils/checkpointing/CheckpointUtils.hpp   |  28 +--
 .../DiscreteFunctionTypeHFType.hpp            |  14 ++
 src/utils/checkpointing/ResumeUtils.cpp       |  56 ++++--
 src/utils/checkpointing/ResumeUtils.hpp       |   5 +-
 6 files changed, 185 insertions(+), 100 deletions(-)
 create mode 100644 src/utils/checkpointing/DiscreteFunctionTypeHFType.hpp

diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp
index 14d45998b..f05fcafa6 100644
--- a/src/language/modules/SchemeModule.cpp
+++ b/src/language/modules/SchemeModule.cpp
@@ -8,6 +8,7 @@
 #include <language/modules/UnaryOperatorRegisterForVh.hpp>
 #include <language/utils/BinaryOperatorProcessorBuilder.hpp>
 #include <language/utils/BuiltinFunctionEmbedder.hpp>
+#include <language/utils/CheckpointResumeRepository.hpp>
 #include <language/utils/TypeDescriptor.hpp>
 #include <mesh/Connectivity.hpp>
 #include <mesh/IBoundaryDescriptor.hpp>
@@ -44,6 +45,8 @@
 #include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
 #include <scheme/VariableBCDescriptor.hpp>
 #include <utils/Socket.hpp>
+#include <utils/checkpointing/CheckpointUtils.hpp>
+#include <utils/checkpointing/ResumeUtils.hpp>
 
 #include <language/modules/MeshModule.hpp>
 #include <language/modules/SocketModule.hpp>
@@ -686,5 +689,16 @@ SchemeModule::registerOperators() const
 void
 SchemeModule::registerCheckpointResume() const
 {
-  throw NotImplementedError("registerCheckpointResume()");
+  CheckpointResumeRepository::instance()
+    .addCheckpointResume(ast_node_data_type_from<std::shared_ptr<const IDiscreteFunctionDescriptor>>,
+                         std::function([](const std::string& symbol_name, const EmbeddedData& embedded_data,
+                                          HighFive::File& file, HighFive::Group& checkpoint_group,
+                                          HighFive::Group& symbol_table_group) {
+                           writeIDiscreteFunctionDescriptor(symbol_name, embedded_data, file, checkpoint_group,
+                                                            symbol_table_group);
+                         }),
+                         std::function([](const std::string& symbol_name,
+                                          const HighFive::Group& symbol_table_group) -> EmbeddedData {
+                           return readIDiscreteFunctionDescriptor(symbol_name, symbol_table_group);
+                         }));
 }
diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp
index b686e89b0..e6e503d1a 100644
--- a/src/utils/checkpointing/CheckpointUtils.cpp
+++ b/src/utils/checkpointing/CheckpointUtils.cpp
@@ -1,23 +1,26 @@
 #include <utils/checkpointing/CheckpointUtils.hpp>
 
 #include <language/modules/MeshModuleTypes.hpp>
+#include <language/modules/SchemeModuleTypes.hpp>
 #include <language/utils/ASTNodeDataTypeTraits.hpp>
 #include <language/utils/DataHandler.hpp>
-#include <mesh/MeshVariant.hpp>
-#include <utils/checkpointing/IBoundaryDescriptorHFType.hpp>
-#include <utils/checkpointing/IInterfaceDescriptorHFType.hpp>
-#include <utils/checkpointing/IZoneDescriptorHFType.hpp>
-#include <utils/checkpointing/ItemTypeHFType.hpp>
-#include <utils/checkpointing/RefItemListHFType.hpp>
-
 #include <mesh/ItemType.hpp>
 #include <mesh/Mesh.hpp>
+#include <mesh/MeshVariant.hpp>
 #include <mesh/NamedBoundaryDescriptor.hpp>
 #include <mesh/NamedInterfaceDescriptor.hpp>
 #include <mesh/NamedZoneDescriptor.hpp>
 #include <mesh/NumberedBoundaryDescriptor.hpp>
 #include <mesh/NumberedInterfaceDescriptor.hpp>
 #include <mesh/NumberedZoneDescriptor.hpp>
+#include <scheme/DiscreteFunctionP0.hpp>
+#include <scheme/DiscreteFunctionP0Vector.hpp>
+#include <utils/checkpointing/DiscreteFunctionTypeHFType.hpp>
+#include <utils/checkpointing/IBoundaryDescriptorHFType.hpp>
+#include <utils/checkpointing/IInterfaceDescriptorHFType.hpp>
+#include <utils/checkpointing/IZoneDescriptorHFType.hpp>
+#include <utils/checkpointing/ItemTypeHFType.hpp>
+#include <utils/checkpointing/RefItemListHFType.hpp>
 
 template <ItemType item_type, size_t Dimension>
 void
@@ -135,61 +138,6 @@ writeConnectivity(const Connectivity<Dimension>& connectivity, HighFive::File& f
   }
 }
 
-void
-writeMesh(const std::string& symbol_name,
-          const EmbeddedData& embedded_data,
-          HighFive::File& file,
-          HighFive::Group& checkpoint_group,
-          HighFive::Group& symbol_table_group)
-{
-  HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name);
-
-  std::shared_ptr<const MeshVariant> mesh_v =
-    dynamic_cast<const DataHandler<const MeshVariant>&>(embedded_data.get()).data_ptr();
-
-  variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(mesh_v)>));
-  variable_group.createAttribute("id", mesh_v->id());
-
-  std::string mesh_group_name = "mesh/" + std::to_string(mesh_v->id());
-  if (not checkpoint_group.exist(mesh_group_name)) {
-    bool linked = false;
-    for (auto group_name : file.listObjectNames()) {
-      if (file.exist(group_name + "/" + mesh_group_name)) {
-        checkpoint_group.createHardLink(mesh_group_name, file.getGroup(group_name + "/" + mesh_group_name));
-        linked = true;
-        break;
-      }
-    }
-
-    if (not linked) {
-      HighFive::Group mesh_group = checkpoint_group.createGroup(mesh_group_name);
-      mesh_group.createAttribute("connectivity", mesh_v->connectivity().id());
-      std::visit(
-        [&](auto&& mesh) {
-          using MeshType = mesh_type_t<decltype(mesh)>;
-          if constexpr (is_polygonal_mesh_v<MeshType>) {
-            mesh_group.createAttribute("id", mesh->id());
-            mesh_group.createAttribute("type", std::string{"polygonal"});
-            mesh_group.createAttribute("dimension", mesh->dimension());
-            write(mesh_group, "xr", mesh->xr());
-          } else {
-            throw UnexpectedError("unexpected mesh type");
-          }
-        },
-        mesh_v->variant());
-    }
-  }
-
-  std::visit(
-    [&](auto&& mesh) {
-      using MeshType = mesh_type_t<decltype(mesh)>;
-      if constexpr (is_polygonal_mesh_v<MeshType>) {
-        writeConnectivity(mesh->connectivity(), file, checkpoint_group);
-      }
-    },
-    mesh_v->variant());
-}
-
 void
 writeIBoundaryDescriptor(const std::string& symbol_name,
                          const EmbeddedData& embedded_data,
@@ -223,6 +171,25 @@ writeIBoundaryDescriptor(const std::string& symbol_name,
   }
 }
 
+void
+writeIDiscreteFunctionDescriptor(const std::string& symbol_name,
+                                 const EmbeddedData& embedded_data,
+                                 HighFive::File&,
+                                 HighFive::Group&,
+                                 HighFive::Group& symbol_table_group)
+{
+  HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name);
+
+  std::shared_ptr<const IDiscreteFunctionDescriptor> idiscrete_function_desriptor_p =
+    dynamic_cast<const DataHandler<const IDiscreteFunctionDescriptor>&>(embedded_data.get()).data_ptr();
+
+  const IDiscreteFunctionDescriptor& idiscrete_function_descriptor = *idiscrete_function_desriptor_p;
+
+  variable_group.createAttribute("type",
+                                 dataTypeName(ast_node_data_type_from<decltype(idiscrete_function_desriptor_p)>));
+  variable_group.createAttribute("discrete_function_type", idiscrete_function_descriptor.type());
+}
+
 void
 writeIInterfaceDescriptor(const std::string& symbol_name,
                           const EmbeddedData& embedded_data,
@@ -256,6 +223,24 @@ writeIInterfaceDescriptor(const std::string& symbol_name,
   }
 }
 
+void
+writeItemType(const std::string& symbol_name,
+              const EmbeddedData& embedded_data,
+              HighFive::File&,
+              HighFive::Group&,
+              HighFive::Group& symbol_table_group)
+{
+  HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name);
+
+  std::shared_ptr<const ItemType> item_type_p =
+    dynamic_cast<const DataHandler<const ItemType>&>(embedded_data.get()).data_ptr();
+
+  const ItemType& item_type = *item_type_p;
+
+  variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(item_type_p)>));
+  variable_group.createAttribute("item_type", item_type);
+}
+
 void
 writeIZoneDescriptor(const std::string& symbol_name,
                      const EmbeddedData& embedded_data,
@@ -289,19 +274,56 @@ writeIZoneDescriptor(const std::string& symbol_name,
 }
 
 void
-writeItemType(const std::string& symbol_name,
-              const EmbeddedData& embedded_data,
-              HighFive::File&,
-              HighFive::Group&,
-              HighFive::Group& symbol_table_group)
+writeMesh(const std::string& symbol_name,
+          const EmbeddedData& embedded_data,
+          HighFive::File& file,
+          HighFive::Group& checkpoint_group,
+          HighFive::Group& symbol_table_group)
 {
   HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name);
 
-  std::shared_ptr<const ItemType> item_type_p =
-    dynamic_cast<const DataHandler<const ItemType>&>(embedded_data.get()).data_ptr();
+  std::shared_ptr<const MeshVariant> mesh_v =
+    dynamic_cast<const DataHandler<const MeshVariant>&>(embedded_data.get()).data_ptr();
 
-  const ItemType& item_type = *item_type_p;
+  variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(mesh_v)>));
+  variable_group.createAttribute("id", mesh_v->id());
 
-  variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(item_type_p)>));
-  variable_group.createAttribute("item_type", item_type);
+  std::string mesh_group_name = "mesh/" + std::to_string(mesh_v->id());
+  if (not checkpoint_group.exist(mesh_group_name)) {
+    bool linked = false;
+    for (auto group_name : file.listObjectNames()) {
+      if (file.exist(group_name + "/" + mesh_group_name)) {
+        checkpoint_group.createHardLink(mesh_group_name, file.getGroup(group_name + "/" + mesh_group_name));
+        linked = true;
+        break;
+      }
+    }
+
+    if (not linked) {
+      HighFive::Group mesh_group = checkpoint_group.createGroup(mesh_group_name);
+      mesh_group.createAttribute("connectivity", mesh_v->connectivity().id());
+      std::visit(
+        [&](auto&& mesh) {
+          using MeshType = mesh_type_t<decltype(mesh)>;
+          if constexpr (is_polygonal_mesh_v<MeshType>) {
+            mesh_group.createAttribute("id", mesh->id());
+            mesh_group.createAttribute("type", std::string{"polygonal"});
+            mesh_group.createAttribute("dimension", mesh->dimension());
+            write(mesh_group, "xr", mesh->xr());
+          } else {
+            throw UnexpectedError("unexpected mesh type");
+          }
+        },
+        mesh_v->variant());
+    }
+  }
+
+  std::visit(
+    [&](auto&& mesh) {
+      using MeshType = mesh_type_t<decltype(mesh)>;
+      if constexpr (is_polygonal_mesh_v<MeshType>) {
+        writeConnectivity(mesh->connectivity(), file, checkpoint_group);
+      }
+    },
+    mesh_v->variant());
 }
diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp
index dc4afe88f..e5d91a20e 100644
--- a/src/utils/checkpointing/CheckpointUtils.hpp
+++ b/src/utils/checkpointing/CheckpointUtils.hpp
@@ -42,18 +42,24 @@ write(HighFive::Group& group,
   write(group, name, item_value.arrayView());
 }
 
-void writeMesh(const std::string& symbol_name,
-               const EmbeddedData& embedded_data,
-               HighFive::File& file,
-               HighFive::Group& checkpoint_group,
-               HighFive::Group& symbol_table_group);
-
 void writeIBoundaryDescriptor(const std::string& symbol_name,
                               const EmbeddedData& embedded_data,
                               HighFive::File& file,
                               HighFive::Group& checkpoint_group,
                               HighFive::Group& symbol_table_group);
 
+void writeIDiscreteFunctionDescriptor(const std::string& symbol_name,
+                                      const EmbeddedData& embedded_data,
+                                      HighFive::File& file,
+                                      HighFive::Group& checkpoint_group,
+                                      HighFive::Group& symbol_table_group);
+
+void writeItemType(const std::string& symbol_name,
+                   const EmbeddedData& embedded_data,
+                   HighFive::File& file,
+                   HighFive::Group& checkpoint_group,
+                   HighFive::Group& symbol_table_group);
+
 void writeIInterfaceDescriptor(const std::string& symbol_name,
                                const EmbeddedData& embedded_data,
                                HighFive::File& file,
@@ -66,10 +72,10 @@ void writeIZoneDescriptor(const std::string& symbol_name,
                           HighFive::Group& checkpoint_group,
                           HighFive::Group& symbol_table_group);
 
-void writeItemType(const std::string& symbol_name,
-                   const EmbeddedData& embedded_data,
-                   HighFive::File& file,
-                   HighFive::Group& checkpoint_group,
-                   HighFive::Group& symbol_table_group);
+void writeMesh(const std::string& symbol_name,
+               const EmbeddedData& embedded_data,
+               HighFive::File& file,
+               HighFive::Group& checkpoint_group,
+               HighFive::Group& symbol_table_group);
 
 #endif   // CHECKPOINT_UTILS_HPP
diff --git a/src/utils/checkpointing/DiscreteFunctionTypeHFType.hpp b/src/utils/checkpointing/DiscreteFunctionTypeHFType.hpp
new file mode 100644
index 000000000..ae7e1cc90
--- /dev/null
+++ b/src/utils/checkpointing/DiscreteFunctionTypeHFType.hpp
@@ -0,0 +1,14 @@
+#ifndef DISCRETE_FUNCTION_TYPE_HF_TYPE_HPP
+#define DISCRETE_FUNCTION_TYPE_HF_TYPE_HPP
+
+#include <scheme/DiscreteFunctionType.hpp>
+#include <utils/checkpointing/CheckpointUtils.hpp>
+
+HighFive::EnumType<DiscreteFunctionType> PUGS_INLINE
+create_enum_discrete_function_type()
+{
+  return {{"P0", DiscreteFunctionType::P0}, {"P0Vector", DiscreteFunctionType::P0Vector}};
+}
+HIGHFIVE_REGISTER_TYPE(DiscreteFunctionType, create_enum_discrete_function_type);
+
+#endif   // DISCRETE_FUNCTION_TYPE_HF_TYPE_HPP
diff --git a/src/utils/checkpointing/ResumeUtils.cpp b/src/utils/checkpointing/ResumeUtils.cpp
index e5483347b..7dc506aef 100644
--- a/src/utils/checkpointing/ResumeUtils.cpp
+++ b/src/utils/checkpointing/ResumeUtils.cpp
@@ -8,22 +8,15 @@
 #include <mesh/NumberedBoundaryDescriptor.hpp>
 #include <mesh/NumberedInterfaceDescriptor.hpp>
 #include <mesh/NumberedZoneDescriptor.hpp>
+#include <scheme/DiscreteFunctionDescriptorP0.hpp>
+#include <scheme/DiscreteFunctionDescriptorP0Vector.hpp>
+#include <utils/checkpointing/DiscreteFunctionTypeHFType.hpp>
 #include <utils/checkpointing/IBoundaryDescriptorHFType.hpp>
 #include <utils/checkpointing/IInterfaceDescriptorHFType.hpp>
 #include <utils/checkpointing/IZoneDescriptorHFType.hpp>
 #include <utils/checkpointing/ItemTypeHFType.hpp>
 #include <utils/checkpointing/ResumingData.hpp>
 
-EmbeddedData
-readMesh(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
-{
-  const HighFive::Group mesh_group = symbol_table_group.getGroup("embedded/" + symbol_name);
-
-  const size_t mesh_id = mesh_group.getAttribute("id").read<uint64_t>();
-
-  return EmbeddedData{std::make_shared<DataHandler<const MeshVariant>>(ResumingData::instance().meshVariant(mesh_id))};
-}
-
 EmbeddedData
 readIBoundaryDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
 {
@@ -53,6 +46,31 @@ readIBoundaryDescriptor(const std::string& symbol_name, const HighFive::Group& s
   return embedded_data;
 }
 
+EmbeddedData
+readIDiscreteFunctionDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
+{
+  const HighFive::Group idiscrete_function_decriptor_group = symbol_table_group.getGroup("embedded/" + symbol_name);
+  const DiscreteFunctionType discrete_function_type =
+    idiscrete_function_decriptor_group.getAttribute("discrete_function_type").read<DiscreteFunctionType>();
+
+  EmbeddedData embedded_data;
+
+  switch (discrete_function_type) {
+  case DiscreteFunctionType::P0: {
+    embedded_data = {std::make_shared<DataHandler<const IDiscreteFunctionDescriptor>>(
+      std::make_shared<const DiscreteFunctionDescriptorP0>())};
+    break;
+  }
+  case DiscreteFunctionType::P0Vector: {
+    embedded_data = {std::make_shared<DataHandler<const IDiscreteFunctionDescriptor>>(
+      std::make_shared<const DiscreteFunctionDescriptorP0Vector>())};
+    break;
+  }
+  }
+
+  return embedded_data;
+}
+
 EmbeddedData
 readIInterfaceDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
 {
@@ -82,6 +100,15 @@ readIInterfaceDescriptor(const std::string& symbol_name, const HighFive::Group&
   return embedded_data;
 }
 
+EmbeddedData
+readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
+{
+  const HighFive::Group item_type_group = symbol_table_group.getGroup("embedded/" + symbol_name);
+  const ItemType item_type              = item_type_group.getAttribute("item_type").read<ItemType>();
+
+  return {std::make_shared<DataHandler<const ItemType>>(std::make_shared<const ItemType>(item_type))};
+}
+
 EmbeddedData
 readIZoneDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
 {
@@ -112,10 +139,11 @@ readIZoneDescriptor(const std::string& symbol_name, const HighFive::Group& symbo
 }
 
 EmbeddedData
-readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
+readMesh(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
 {
-  const HighFive::Group item_type_group = symbol_table_group.getGroup("embedded/" + symbol_name);
-  const ItemType item_type              = item_type_group.getAttribute("item_type").read<ItemType>();
+  const HighFive::Group mesh_group = symbol_table_group.getGroup("embedded/" + symbol_name);
 
-  return {std::make_shared<DataHandler<const ItemType>>(std::make_shared<const ItemType>(item_type))};
+  const size_t mesh_id = mesh_group.getAttribute("id").read<uint64_t>();
+
+  return EmbeddedData{std::make_shared<DataHandler<const MeshVariant>>(ResumingData::instance().meshVariant(mesh_id))};
 }
diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp
index 2c0b0d486..c8797804e 100644
--- a/src/utils/checkpointing/ResumeUtils.hpp
+++ b/src/utils/checkpointing/ResumeUtils.hpp
@@ -28,10 +28,11 @@ read(const HighFive::Group& group, const std::string& name)
   return array;
 }
 
-EmbeddedData readMesh(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readIBoundaryDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
+EmbeddedData readIDiscreteFunctionDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readIInterfaceDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
-EmbeddedData readIZoneDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
+EmbeddedData readIZoneDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
+EmbeddedData readMesh(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 
 #endif   // RESUME_UTILS_HPP
-- 
GitLab