From 0f5f99db9e69d22a8ed7d45787c14dfe3ea8b67d Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Thu, 9 May 2024 20:48:16 +0200
Subject: [PATCH] Add checpoint/resume for DiscreteFunctionP0Vector

---
 src/utils/checkpointing/CheckpointUtils.cpp |  3 +-
 src/utils/checkpointing/CheckpointUtils.hpp | 53 +++++++++++++++++++++
 src/utils/checkpointing/ResumeUtils.hpp     | 21 ++++++--
 3 files changed, 71 insertions(+), 6 deletions(-)

diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp
index c67514c10..4d77a25fa 100644
--- a/src/utils/checkpointing/CheckpointUtils.cpp
+++ b/src/utils/checkpointing/CheckpointUtils.cpp
@@ -499,8 +499,7 @@ writeDiscreteFunctionVariant(const std::string& symbol_name,
       } else if constexpr (is_discrete_function_P0_vector_v<DFType>) {
         using data_type = std::decay_t<typename DFType::data_type>;
         variable_group.createAttribute("data_type", dataTypeName(ast_node_data_type_from<data_type>));
-        throw NotImplementedError("P0Vector");
-        //        write(variable_group, "values", discrete_function.cellArrays());
+        write(variable_group, "values", discrete_function.cellArrays());
       }
     },
     discrete_function_v->discreteFunction());
diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp
index 984ff037f..491d95386 100644
--- a/src/utils/checkpointing/CheckpointUtils.hpp
+++ b/src/utils/checkpointing/CheckpointUtils.hpp
@@ -5,6 +5,7 @@
 
 #include <language/utils/SymbolTable.hpp>
 #include <mesh/CellType.hpp>
+#include <mesh/ItemArray.hpp>
 #include <mesh/ItemValue.hpp>
 #include <utils/Messenger.hpp>
 
@@ -59,6 +60,49 @@ write(HighFive::Group& group, const std::string& name, const Array<DataType>& ar
   dataset.createAttribute("size_per_rank", size_vector);
 }
 
+template <typename DataType>
+PUGS_INLINE void
+write(HighFive::Group& group, const std::string& name, const Table<DataType>& table)
+{
+  const size_t number_of_columns = parallel::allReduceMax(table.numberOfColumns());
+  if ((table.numberOfColumns() != number_of_columns) and (table.numberOfRows() > 0)) {
+    throw UnexpectedError("table must have same number of columns in parallel");
+  }
+
+  auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; };
+
+  Array<size_t> number_of_rows_per_rank = parallel::allGather(table.numberOfRows());
+  size_t global_size                    = sum(number_of_rows_per_rank) * number_of_columns;
+
+  size_t current_offset = 0;
+  for (size_t i = 0; i < parallel::rank(); ++i) {
+    current_offset += number_of_rows_per_rank[i] * table.numberOfColumns();
+  }
+  std::vector<size_t> offset{current_offset, 0ul};
+  std::vector<size_t> count{table.numberOfRows() * table.numberOfColumns()};
+
+  using data_type = std::remove_const_t<DataType>;
+  HighFive::DataSetCreateProps properties;
+  properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, global_size)}));
+  properties.add(HighFive::Shuffle());
+  properties.add(HighFive::Deflate(3));
+
+  auto xfer_props = HighFive::DataTransferProps{};
+  xfer_props.add(HighFive::UseCollectiveIO{});
+
+  HighFive::DataSet dataset =
+    group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties);
+  dataset.select(offset, count).template write_raw<data_type>(get_address(table), xfer_props);
+
+  std::vector<size_t> number_of_rows_per_rank_vector;
+  for (size_t i = 0; i < number_of_rows_per_rank.size(); ++i) {
+    number_of_rows_per_rank_vector.push_back(number_of_rows_per_rank[i]);
+  }
+
+  dataset.createAttribute("number_of_rows_per_rank", number_of_rows_per_rank_vector);
+  dataset.createAttribute("number_of_columns", number_of_columns);
+}
+
 template <typename DataType, ItemType item_type, typename ConnectivityPtr>
 PUGS_INLINE void
 write(HighFive::Group& group,
@@ -68,6 +112,15 @@ write(HighFive::Group& group,
   write(group, name, item_value.arrayView());
 }
 
+template <typename DataType, ItemType item_type, typename ConnectivityPtr>
+PUGS_INLINE void
+write(HighFive::Group& group,
+      const std::string& name,
+      const ItemArray<DataType, item_type, ConnectivityPtr>& item_array)
+{
+  write(group, name, item_array.tableView());
+}
+
 void writeDiscreteFunctionVariant(const std::string& symbol_name,
                                   const EmbeddedData& embedded_data,
                                   HighFive::File& file,
diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp
index 5fcf4eb2e..2e4fb05d6 100644
--- a/src/utils/checkpointing/ResumeUtils.hpp
+++ b/src/utils/checkpointing/ResumeUtils.hpp
@@ -50,19 +50,32 @@ template <typename DataType>
 PUGS_INLINE Table<DataType>
 readTable(const HighFive::Group& group, const std::string& name)
 {
+  auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; };
+
   using data_type = std::remove_const_t<DataType>;
 
   auto dataset = group.getDataSet(name);
-  Table<DataType> table(dataset.getDimensions()[0], dataset.getDimensions()[1]);
+
+  const size_t number_of_columns = dataset.getAttribute("number_of_columns").read<size_t>();
+  const std::vector<size_t> number_of_rows_per_rank =
+    dataset.getAttribute("number_of_rows_per_rank").read<std::vector<size_t>>();
+
+  std::vector<size_t> offset{0, 0ul};
+  for (size_t i = 0; i < parallel::rank(); ++i) {
+    offset[0] += number_of_rows_per_rank[i] * number_of_columns;
+  }
+  std::vector<size_t> count{number_of_rows_per_rank[parallel::rank()]};
+
+  Table<DataType> table(number_of_rows_per_rank[parallel::rank()], number_of_columns);
   if constexpr (std::is_same_v<CellType, data_type>) {
     using base_type = std::underlying_type_t<CellType>;
-    dataset.read_raw(reinterpret_cast<base_type*>(&(table(0, 0))));
+    dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(table)));
   } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or
                        (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) {
     using base_type = typename data_type::base_type;
-    dataset.read_raw<base_type>(reinterpret_cast<base_type*>(&(table(0, 0))));
+    dataset.select(offset, count).read_raw<base_type>(reinterpret_cast<base_type*>(get_address(table)));
   } else {
-    dataset.read_raw<data_type>(&(table(0, 0)));
+    dataset.select(offset, count).read_raw<data_type>(get_address(table));
   }
 
   return table;
-- 
GitLab