From af9034891f18e5dcfffc6724449b2bc2627c7037 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 6 Sep 2024 21:06:55 +0200
Subject: [PATCH] Add tests for ItemValueVariant checkpointing

---
 tests/CMakeLists.txt                          |   1 +
 tests/test_checkpointing_ItemArrayVariant.cpp | 352 ++++++++++++++++++
 2 files changed, 353 insertions(+)
 create mode 100644 tests/test_checkpointing_ItemArrayVariant.cpp

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 26a5dd744..7af1682ae 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -164,6 +164,7 @@ if(PUGS_HAS_HDF5)
     test_checkpointing_Connectivity.cpp
     test_checkpointing_HFTypes.cpp
     test_checkpointing_ItemArray.cpp
+    test_checkpointing_ItemArrayVariant.cpp
     test_checkpointing_ItemValue.cpp
     test_checkpointing_ItemValueVariant.cpp
     test_checkpointing_OStream.cpp
diff --git a/tests/test_checkpointing_ItemArrayVariant.cpp b/tests/test_checkpointing_ItemArrayVariant.cpp
new file mode 100644
index 000000000..06a8d8dac
--- /dev/null
+++ b/tests/test_checkpointing_ItemArrayVariant.cpp
@@ -0,0 +1,352 @@
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_all.hpp>
+
+#include <utils/Messenger.hpp>
+
+#include <language/utils/DataHandler.hpp>
+#include <language/utils/EmbeddedData.hpp>
+#include <mesh/ItemArrayVariant.hpp>
+#include <mesh/Mesh.hpp>
+#include <utils/GlobalVariableManager.hpp>
+#include <utils/checkpointing/ReadItemArrayVariant.hpp>
+#include <utils/checkpointing/ResumingData.hpp>
+#include <utils/checkpointing/WriteItemArrayVariant.hpp>
+
+#include <MeshDataBaseForTests.hpp>
+#include <checkpointing_Connectivity_utilities.hpp>
+
+#include <filesystem>
+
+// clazy:excludeall=non-pod-global-static
+
+namespace test_only
+{
+
+template <typename DataType, ItemType item_type>
+PUGS_INLINE void
+check_is_same(const ItemArray<DataType, item_type>& reference, const EmbeddedData& e_read_data)
+{
+  auto same_table = [](const auto& a, const auto& b) -> bool {
+    bool same = true;
+    if ((a.numberOfRows() == b.numberOfRows()) and (a.numberOfColumns() == b.numberOfColumns())) {
+      for (size_t i = 0; i < a.numberOfRows(); ++i) {
+        for (size_t j = 0; j < a.numberOfColumns(); ++j) {
+          same &= (a(i, j) == b(i, j));
+        }
+      }
+    } else {
+      same = false;
+    }
+
+    if (not same) {
+      throw UnexpectedError("a!=b");
+    }
+
+    return parallel::allReduceAnd(same);
+  };
+
+  REQUIRE_NOTHROW(dynamic_cast<const DataHandler<const ItemArrayVariant>&>(e_read_data.get()));
+
+  std::shared_ptr<const ItemArrayVariant> p_new_data_v =
+    dynamic_cast<const DataHandler<const ItemArrayVariant>&>(e_read_data.get()).data_ptr();
+
+  using ItemTypeT = ItemArray<const DataType, item_type>;
+
+  ItemTypeT read_data = p_new_data_v->get<ItemTypeT>();
+
+  switch (reference.connectivity_ptr()->dimension()) {
+  case 1: {
+    REQUIRE(test_only::isSameConnectivity(dynamic_cast<const Connectivity<1>&>(*reference.connectivity_ptr()),
+                                          dynamic_cast<const Connectivity<1>&>(*read_data.connectivity_ptr())));
+    break;
+  }
+  case 2: {
+    REQUIRE(test_only::isSameConnectivity(dynamic_cast<const Connectivity<2>&>(*reference.connectivity_ptr()),
+                                          dynamic_cast<const Connectivity<2>&>(*read_data.connectivity_ptr())));
+    break;
+  }
+  case 3: {
+    REQUIRE(test_only::isSameConnectivity(dynamic_cast<const Connectivity<3>&>(*reference.connectivity_ptr()),
+                                          dynamic_cast<const Connectivity<3>&>(*read_data.connectivity_ptr())));
+    break;
+  }
+  default: {
+    throw UnexpectedError("invalid connectivity dimension");
+  }
+  }
+
+  REQUIRE(same_table(reference.tableView(), read_data.tableView()));
+}
+
+}   // namespace test_only
+
+TEST_CASE("checkpointing_ItemArrayVariant", "[utils/checkpointing]")
+{
+  std::string tmp_dirname;
+  {
+    {
+      if (parallel::rank() == 0) {
+        tmp_dirname = [&]() -> std::string {
+          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
+          return std::string{mkdtemp(&temp_filename[0])};
+        }();
+      }
+      parallel::broadcast(tmp_dirname, 0);
+    }
+    std::filesystem::path path = tmp_dirname;
+    const std::string filename = path / "checkpoint.h5";
+
+    HighFive::FileAccessProps fapl;
+    fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+    fapl.add(HighFive::MPIOCollectiveMetadata{});
+    HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl);
+
+    const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();
+
+    SECTION("Connectivity")
+    {
+      using R1   = TinyVector<1>;
+      using R2   = TinyVector<2>;
+      using R3   = TinyVector<3>;
+      using R1x1 = TinyMatrix<1>;
+      using R2x2 = TinyMatrix<2>;
+      using R3x3 = TinyMatrix<3>;
+
+      HighFive::Group checkpoint_group   = file.createGroup("checkpoint");
+      HighFive::Group symbol_table_group = checkpoint_group.createGroup("symbol_table");
+
+      auto mesh_1d = MeshDataBaseForTests::get().unordered1DMesh()->get<Mesh<1>>();
+
+      CellArray<bool> cell_B_1d{mesh_1d->connectivity(), 3};
+      for (CellId cell_id = 0; cell_id < mesh_1d->numberOfCells(); ++cell_id) {
+        for (size_t i = 0; i < cell_B_1d.sizeOfArrays(); ++i) {
+          cell_B_1d[cell_id][i] = (std::rand() / (RAND_MAX / mesh_1d->numberOfCells())) % 2;
+        }
+      }
+
+      CellArray<uint64_t> cell_N_1d{mesh_1d->connectivity(), 2};
+      for (CellId cell_id = 0; cell_id < mesh_1d->numberOfCells(); ++cell_id) {
+        for (size_t i = 0; i < cell_N_1d.sizeOfArrays(); ++i) {
+          cell_N_1d[cell_id][i] = (std::rand() / (RAND_MAX / mesh_1d->numberOfCells()));
+        }
+      }
+
+      NodeArray<int64_t> node_Z_1d{mesh_1d->connectivity(), 1};
+      for (NodeId node_id = 0; node_id < mesh_1d->numberOfNodes(); ++node_id) {
+        for (size_t i = 0; i < node_Z_1d.sizeOfArrays(); ++i) {
+          node_Z_1d[node_id][i] = 100 * (std::rand() - RAND_MAX / 2.) / (RAND_MAX / mesh_1d->numberOfNodes());
+        }
+      }
+
+      NodeArray<double> node_R_1d{mesh_1d->connectivity(), 4};
+      for (NodeId node_id = 0; node_id < mesh_1d->numberOfNodes(); ++node_id) {
+        for (size_t i = 0; i < node_R_1d.sizeOfArrays(); ++i) {
+          node_R_1d[node_id][i] = std::rand() / (1. * RAND_MAX / mesh_1d->numberOfNodes());
+        }
+      }
+
+      CellArray<R1> cell_R1_1d{mesh_1d->connectivity(), 3};
+      for (CellId cell_id = 0; cell_id < mesh_1d->numberOfCells(); ++cell_id) {
+        for (size_t i = 0; i < cell_R1_1d.sizeOfArrays(); ++i) {
+          cell_R1_1d[cell_id][i] = R1{std::rand() / (1. * RAND_MAX / mesh_1d->numberOfCells())};
+        }
+      }
+
+      NodeArray<R2> node_R2_1d{mesh_1d->connectivity(), 2};
+      for (NodeId node_id = 0; node_id < mesh_1d->numberOfNodes(); ++node_id) {
+        for (size_t i = 0; i < node_R2_1d.sizeOfArrays(); ++i) {
+          node_R2_1d[node_id][i] = R2{std::rand() / (1. * RAND_MAX / mesh_1d->numberOfNodes()),
+                                      std::rand() / (1. * RAND_MAX / mesh_1d->numberOfNodes())};
+        }
+      }
+
+      auto mesh_2d = MeshDataBaseForTests::get().hybrid2DMesh()->get<Mesh<2>>();
+
+      FaceArray<R3> face_R3_2d{mesh_2d->connectivity(), 3};
+      for (FaceId face_id = 0; face_id < mesh_2d->numberOfFaces(); ++face_id) {
+        for (size_t i = 0; i < face_R3_2d.sizeOfArrays(); ++i) {
+          face_R3_2d[face_id][i] = R3{std::rand() / (1. * RAND_MAX / mesh_2d->numberOfFaces()),
+                                      std::rand() / (1. * RAND_MAX / mesh_2d->numberOfFaces()),
+                                      std::rand() / (1. * RAND_MAX / mesh_2d->numberOfFaces())};
+        }
+      }
+
+      NodeArray<R2x2> node_R2x2_2d{mesh_2d->connectivity(), 4};
+      for (NodeId node_id = 0; node_id < mesh_2d->numberOfNodes(); ++node_id) {
+        for (size_t i = 0; i < node_R2x2_2d.sizeOfArrays(); ++i) {
+          node_R2x2_2d[node_id][i] = R2x2{std::rand() / (1. * RAND_MAX / mesh_2d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_2d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_2d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_2d->numberOfNodes())};
+        }
+      }
+
+      auto mesh_3d = MeshDataBaseForTests::get().hybrid3DMesh()->get<Mesh<3>>();
+
+      EdgeArray<R3> edge_R3_3d{mesh_3d->connectivity(), 2};
+      for (EdgeId edge_id = 0; edge_id < mesh_3d->numberOfEdges(); ++edge_id) {
+        for (size_t i = 0; i < edge_R3_3d.sizeOfArrays(); ++i) {
+          edge_R3_3d[edge_id][i] = R3{std::rand() / (1. * RAND_MAX / mesh_3d->numberOfEdges()),
+                                      std::rand() / (1. * RAND_MAX / mesh_3d->numberOfEdges()),
+                                      std::rand() / (1. * RAND_MAX / mesh_3d->numberOfEdges())};
+        }
+      }
+
+      FaceArray<R1x1> face_R1x1_3d{mesh_3d->connectivity(), 1};
+      for (FaceId face_id = 0; face_id < mesh_3d->numberOfFaces(); ++face_id) {
+        for (size_t i = 0; i < face_R1x1_3d.sizeOfArrays(); ++i) {
+          face_R1x1_3d[face_id][i] = R1x1{std::rand() / (1. * RAND_MAX / mesh_3d->numberOfFaces())};
+        }
+      }
+
+      NodeArray<R3x3> node_R3x3_3d{mesh_3d->connectivity(), 2};
+      for (NodeId node_id = 0; node_id < mesh_3d->numberOfNodes(); ++node_id) {
+        for (size_t i = 0; i < node_R3x3_3d.sizeOfArrays(); ++i) {
+          node_R3x3_3d[node_id][i] = R3x3{std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes()),
+                                          std::rand() / (1. * RAND_MAX / mesh_3d->numberOfNodes())};
+        }
+      }
+
+      {   // Write
+        using DataHandlerT = DataHandler<const ItemArrayVariant>;
+
+        auto new_connectivity_1d = test_only::duplicateConnectivity(mesh_1d->connectivity());
+
+        CellArray<const bool> cell_B_1d_new{*new_connectivity_1d, cell_B_1d.tableView()};
+        NodeArray<const int64_t> node_Z_1d_new{*new_connectivity_1d, node_Z_1d.tableView()};
+        CellArray<const uint64_t> cell_N_1d_new{*new_connectivity_1d, cell_N_1d.tableView()};
+        NodeArray<const double> node_R_1d_new{*new_connectivity_1d, node_R_1d.tableView()};
+        CellArray<const R1> cell_R1_1d_new{*new_connectivity_1d, cell_R1_1d.tableView()};
+        NodeArray<const R2> node_R2_1d_new{*new_connectivity_1d, node_R2_1d.tableView()};
+
+        checkpointing::writeItemArrayVariant("cell_B_1d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(cell_B_1d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("cell_N_1d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(cell_N_1d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("node_Z_1d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(node_Z_1d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("node_R_1d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(node_R_1d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("cell_R1_1d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(cell_R1_1d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("node_R2_1d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(node_R2_1d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        auto new_connectivity_2d = test_only::duplicateConnectivity(mesh_2d->connectivity());
+
+        FaceArray<const R3> face_R3_2d_new{*new_connectivity_2d, face_R3_2d.tableView()};
+        NodeArray<const R2x2> node_R2x2_2d_new{*new_connectivity_2d, node_R2x2_2d.tableView()};
+
+        checkpointing::writeItemArrayVariant("face_R3_2d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(face_R3_2d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("node_R2x2_2d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(node_R2x2_2d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        auto new_connectivity_3d = test_only::duplicateConnectivity(mesh_3d->connectivity());
+
+        EdgeArray<const R3> edge_R3_3d_new{*new_connectivity_3d, edge_R3_3d.tableView()};
+        FaceArray<const R1x1> face_R1x1_3d_new{*new_connectivity_3d, face_R1x1_3d.tableView()};
+        NodeArray<const R3x3> node_R3x3_3d_new{*new_connectivity_3d, node_R3x3_3d.tableView()};
+
+        checkpointing::writeItemArrayVariant("edge_R3_3d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(edge_R3_3d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("face_R1x1_3d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(face_R1x1_3d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeItemArrayVariant("node_R3x3_3d",
+                                             EmbeddedData{std::make_shared<DataHandlerT>(
+                                               std::make_shared<const ItemArrayVariant>(node_R3x3_3d_new))},
+                                             file, checkpoint_group, symbol_table_group);
+
+        HighFive::Group global_variables_group = checkpoint_group.createGroup("singleton/global_variables");
+        global_variables_group.createAttribute("connectivity_id",
+                                               GlobalVariableManager::instance().getConnectivityId());
+        global_variables_group.createAttribute("mesh_id", GlobalVariableManager::instance().getMeshId());
+      }
+
+      // reset to reuse after resuming
+      GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);
+
+      file.flush();
+
+      checkpointing::ResumingData::create();
+      checkpointing::ResumingData::instance().readData(checkpoint_group, nullptr);
+
+      GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);
+      {   // Read
+        auto e_cell_B_1d = checkpointing::readItemArrayVariant("cell_B_1d", symbol_table_group);
+        test_only::check_is_same(cell_B_1d, e_cell_B_1d);
+
+        auto e_cell_N_1d = checkpointing::readItemArrayVariant("cell_N_1d", symbol_table_group);
+        test_only::check_is_same(cell_N_1d, e_cell_N_1d);
+
+        auto e_node_Z_1d = checkpointing::readItemArrayVariant("node_Z_1d", symbol_table_group);
+        test_only::check_is_same(node_Z_1d, e_node_Z_1d);
+
+        auto e_node_R_1d = checkpointing::readItemArrayVariant("node_R_1d", symbol_table_group);
+        test_only::check_is_same(node_R_1d, e_node_R_1d);
+
+        auto e_cell_R1_1d = checkpointing::readItemArrayVariant("cell_R1_1d", symbol_table_group);
+        test_only::check_is_same(cell_R1_1d, e_cell_R1_1d);
+
+        auto e_node_R2_1d = checkpointing::readItemArrayVariant("node_R2_1d", symbol_table_group);
+        test_only::check_is_same(node_R2_1d, e_node_R2_1d);
+
+        auto e_face_R3_2d = checkpointing::readItemArrayVariant("face_R3_2d", symbol_table_group);
+        test_only::check_is_same(face_R3_2d, e_face_R3_2d);
+
+        auto e_node_R2x2_2d = checkpointing::readItemArrayVariant("node_R2x2_2d", symbol_table_group);
+        test_only::check_is_same(node_R2x2_2d, e_node_R2x2_2d);
+
+        auto e_edge_R3_3d = checkpointing::readItemArrayVariant("edge_R3_3d", symbol_table_group);
+        test_only::check_is_same(edge_R3_3d, e_edge_R3_3d);
+
+        auto e_face_R1x1_3d = checkpointing::readItemArrayVariant("face_R1x1_3d", symbol_table_group);
+        test_only::check_is_same(face_R1x1_3d, e_face_R1x1_3d);
+
+        auto e_node_R3x3_3d = checkpointing::readItemArrayVariant("node_R3x3_3d", symbol_table_group);
+        test_only::check_is_same(node_R3x3_3d, e_node_R3x3_3d);
+      }
+      checkpointing::ResumingData::destroy();
+    }
+  }
+
+  parallel::barrier();
+  if (parallel::rank() == 0) {
+    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
+  }
+}
-- 
GitLab