From 96543d0320702cfc5a492dff52361a73b070cf1e Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Sun, 12 May 2024 20:33:40 +0200
Subject: [PATCH] Add checkpoint/resume for ItemArrayVariant

---
 src/language/modules/MeshModule.cpp         | 13 ++++
 src/utils/checkpointing/CheckpointUtils.cpp | 41 ++++++++++
 src/utils/checkpointing/CheckpointUtils.hpp |  6 ++
 src/utils/checkpointing/ResumeUtils.cpp     | 84 +++++++++++++++++++++
 src/utils/checkpointing/ResumeUtils.hpp     |  1 +
 5 files changed, 145 insertions(+)

diff --git a/src/language/modules/MeshModule.cpp b/src/language/modules/MeshModule.cpp
index 395400f28..cab959a73 100644
--- a/src/language/modules/MeshModule.cpp
+++ b/src/language/modules/MeshModule.cpp
@@ -365,6 +365,19 @@ MeshModule::registerCheckpointResume() const
                          std::function([](const std::string& symbol_name, const HighFive::Group& symbol_table_group)
                                          -> EmbeddedData { return readItemType(symbol_name, symbol_table_group); }));
 
+  CheckpointResumeRepository::instance()
+    .addCheckpointResume(ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>>,
+                         std::function([](const std::string& symbol_name, const EmbeddedData& embedded_data,
+                                          HighFive::File& file, HighFive::Group& checkpoint_group,
+                                          HighFive::Group& symbol_table_group) {
+                           writeItemArrayVariant(symbol_name, embedded_data, file, checkpoint_group,
+                                                 symbol_table_group);
+                         }),
+                         std::function([](const std::string& symbol_name,
+                                          const HighFive::Group& symbol_table_group) -> EmbeddedData {
+                           return readItemArrayVariant(symbol_name, symbol_table_group);
+                         }));
+
   CheckpointResumeRepository::instance()
     .addCheckpointResume(ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>>,
                          std::function([](const std::string& symbol_name, const EmbeddedData& embedded_data,
diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp
index bb370b7fa..5fdeb7793 100644
--- a/src/utils/checkpointing/CheckpointUtils.cpp
+++ b/src/utils/checkpointing/CheckpointUtils.cpp
@@ -8,6 +8,7 @@
 #include <language/utils/DataHandler.hpp>
 #include <language/utils/OFStream.hpp>
 #include <language/utils/OStream.hpp>
+#include <mesh/ItemArrayVariant.hpp>
 #include <mesh/ItemType.hpp>
 #include <mesh/ItemValueVariant.hpp>
 #include <mesh/Mesh.hpp>
@@ -497,6 +498,46 @@ writeIQuadratureDescriptor(const std::string& symbol_name,
   variable_group.createAttribute("quadrature_degree", iquadrature_descriptor.degree());
 }
 
+void
+writeItemArrayVariant(HighFive::Group& variable_group,
+                      std::shared_ptr<const ItemArrayVariant> item_array_variant_v,
+                      HighFive::File& file,
+                      HighFive::Group& checkpoint_group)
+{
+  variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(item_array_variant_v)>));
+
+  std::visit(
+    [&](auto&& item_array) {
+      using ItemArrayT = std::decay_t<decltype(item_array)>;
+
+      variable_group.createAttribute("item_type", ItemArrayT::item_t);
+      using data_type = std::decay_t<typename ItemArrayT::data_type>;
+      variable_group.createAttribute("data_type", dataTypeName(ast_node_data_type_from<data_type>));
+
+      const IConnectivity& connectivity = *item_array.connectivity_ptr();
+      variable_group.createAttribute("connectivity_id", connectivity.id());
+      writeConnectivity(connectivity, file, checkpoint_group);
+
+      write(variable_group, "arrays", item_array);
+    },
+    item_array_variant_v->itemArray());
+}
+
+void
+writeItemArrayVariant(const std::string& symbol_name,
+                      const EmbeddedData& embedded_data,
+                      HighFive::File& file,
+                      HighFive::Group& checkpoint_group,
+                      HighFive::Group& symbol_table_group)
+{
+  HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name);
+
+  std::shared_ptr<const ItemArrayVariant> item_array_variant_p =
+    dynamic_cast<const DataHandler<const ItemArrayVariant>&>(embedded_data.get()).data_ptr();
+
+  writeItemArrayVariant(variable_group, item_array_variant_p, file, checkpoint_group);
+}
+
 void
 writeItemType(const std::string& symbol_name,
               const EmbeddedData& embedded_data,
diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp
index 7edee7b2c..89030bd51 100644
--- a/src/utils/checkpointing/CheckpointUtils.hpp
+++ b/src/utils/checkpointing/CheckpointUtils.hpp
@@ -155,6 +155,12 @@ void writeIQuadratureDescriptor(const std::string& symbol_name,
                                 HighFive::Group& checkpoint_group,
                                 HighFive::Group& symbol_table_group);
 
+void writeItemArrayVariant(const std::string& symbol_name,
+                           const EmbeddedData& embedded_data,
+                           HighFive::File& file,
+                           HighFive::Group& checkpoint_group,
+                           HighFive::Group& symbol_table_group);
+
 void writeItemType(const std::string& symbol_name,
                    const EmbeddedData& embedded_data,
                    HighFive::File& file,
diff --git a/src/utils/checkpointing/ResumeUtils.cpp b/src/utils/checkpointing/ResumeUtils.cpp
index d07262779..15c557d67 100644
--- a/src/utils/checkpointing/ResumeUtils.cpp
+++ b/src/utils/checkpointing/ResumeUtils.cpp
@@ -6,6 +6,7 @@
 #include <language/utils/DataHandler.hpp>
 #include <language/utils/OFStream.hpp>
 #include <language/utils/SymbolTable.hpp>
+#include <mesh/ItemArrayVariant.hpp>
 #include <mesh/ItemValueVariant.hpp>
 #include <mesh/NamedBoundaryDescriptor.hpp>
 #include <mesh/NamedInterfaceDescriptor.hpp>
@@ -277,6 +278,89 @@ readIQuadratureDescriptor(const std::string& symbol_name, const HighFive::Group&
   return {std::make_shared<DataHandler<const IQuadratureDescriptor>>(iquadrature_descrptor)};
 }
 
+template <ItemType item_type>
+EmbeddedData
+readItemArrayVariant(const HighFive::Group& item_array_variant_group)
+{
+  const std::string data_type  = item_array_variant_group.getAttribute("data_type").read<std::string>();
+  const size_t connectivity_id = item_array_variant_group.getAttribute("connectivity_id").read<size_t>();
+
+  const IConnectivity& connectivity = *ResumingData::instance().iConnectivity(connectivity_id);
+
+  std::shared_ptr<ItemArrayVariant> p_item_array;
+
+  if (data_type == dataTypeName(ast_node_data_type_from<bool>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<bool, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<long int>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<long int, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<unsigned long int>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<unsigned long int, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<double>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<double, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<TinyVector<1>>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<TinyVector<1>, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<TinyVector<2>>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<TinyVector<2>, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<TinyVector<3>>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<TinyVector<3>, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<TinyMatrix<1>>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<TinyMatrix<1>, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<TinyMatrix<2>>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<TinyMatrix<2>, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else if (data_type == dataTypeName(ast_node_data_type_from<TinyMatrix<3>>)) {
+    p_item_array = std::make_shared<ItemArrayVariant>(
+      readItemArray<TinyMatrix<3>, item_type>(item_array_variant_group, "arrays", connectivity));
+  } else {
+    throw UnexpectedError("unexpected discrete function data type: " + data_type);
+  }
+  return {std::make_shared<DataHandler<const ItemArrayVariant>>(p_item_array)};
+}
+
+EmbeddedData
+readItemArrayVariant(const HighFive::Group& item_array_variant_group)
+{
+  const ItemType item_type = item_array_variant_group.getAttribute("item_type").read<ItemType>();
+
+  EmbeddedData embedded_data;
+
+  switch (item_type) {
+  case ItemType::cell: {
+    embedded_data = readItemArrayVariant<ItemType::cell>(item_array_variant_group);
+    break;
+  }
+  case ItemType::face: {
+    embedded_data = readItemArrayVariant<ItemType::face>(item_array_variant_group);
+    break;
+  }
+  case ItemType::edge: {
+    embedded_data = readItemArrayVariant<ItemType::edge>(item_array_variant_group);
+    break;
+  }
+  case ItemType::node: {
+    embedded_data = readItemArrayVariant<ItemType::node>(item_array_variant_group);
+    break;
+  }
+  }
+
+  return embedded_data;
+}
+
+EmbeddedData
+readItemArrayVariant(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
+{
+  const HighFive::Group item_array_variant_group = symbol_table_group.getGroup("embedded/" + symbol_name);
+  return readItemArrayVariant(item_array_variant_group);
+}
+
 EmbeddedData
 readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group)
 {
diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp
index 419751f5b..dbdc5611d 100644
--- a/src/utils/checkpointing/ResumeUtils.hpp
+++ b/src/utils/checkpointing/ResumeUtils.hpp
@@ -89,6 +89,7 @@ EmbeddedData readIDiscreteFunctionDescriptor(const std::string& symbol_name, con
 EmbeddedData readIInterfaceDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readINamedDiscreteData(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readIQuadratureDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
+EmbeddedData readItemArrayVariant(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readItemValueVariant(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
 EmbeddedData readIZoneDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group);
-- 
GitLab