From b6bc6e016dcb65f6931a595b5cf15cd1bdf054d5 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Tue, 16 Jul 2024 00:17:01 +0200
Subject: [PATCH] Add tests for ItemType checkpointing

---
 tests/CMakeLists.txt                  |  1 +
 tests/test_checkpointing_ItemType.cpp | 92 +++++++++++++++++++++++++++
 2 files changed, 93 insertions(+)
 create mode 100644 tests/test_checkpointing_ItemType.cpp

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index a3048e7c8..cb564c1d6 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -164,6 +164,7 @@ if(PUGS_HAS_HDF5)
     test_checkpointing_IBoundaryDescriptor.cpp
     test_checkpointing_IBoundaryConditionDescriptor.cpp
     test_checkpointing_IInterfaceDescriptor.cpp
+    test_checkpointing_ItemType.cpp
     test_checkpointing_IZoneDescriptor.cpp
   )
 endif(PUGS_HAS_HDF5)
diff --git a/tests/test_checkpointing_ItemType.cpp b/tests/test_checkpointing_ItemType.cpp
new file mode 100644
index 000000000..12b80ca4d
--- /dev/null
+++ b/tests/test_checkpointing_ItemType.cpp
@@ -0,0 +1,92 @@
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_all.hpp>
+
+#include <utils/Messenger.hpp>
+
+#include <language/utils/DataHandler.hpp>
+#include <language/utils/EmbeddedData.hpp>
+#include <mesh/ItemType.hpp>
+#include <utils/checkpointing/ReadItemType.hpp>
+#include <utils/checkpointing/WriteItemType.hpp>
+
+#include <filesystem>
+
+// clazy:excludeall=non-pod-global-static
+
+TEST_CASE("checkpointing_ItemType", "[utils/checkpointing]")
+{
+  std::string tmp_dirname;
+  {
+    {
+      if (parallel::rank() == 0) {
+        tmp_dirname = [&]() -> std::string {
+          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
+          return std::string{mkdtemp(&temp_filename[0])};
+        }();
+      }
+      parallel::broadcast(tmp_dirname, 0);
+    }
+    std::filesystem::path path = tmp_dirname;
+    const std::string filename = path / "checkpoint.h5";
+
+    HighFive::FileAccessProps fapl;
+    fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+    fapl.add(HighFive::MPIOCollectiveMetadata{});
+    HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl);
+
+    SECTION("ItemType")
+    {
+      HighFive::Group symbol_table_group = file.createGroup("symbol_table");
+      HighFive::Group useless_group;
+
+      auto p_cell_type = std::make_shared<const ItemType>(ItemType::cell);
+      checkpointing::writeItemType("cell_type",
+                                   EmbeddedData{std::make_shared<DataHandler<const ItemType>>(p_cell_type)}, file,
+                                   useless_group, symbol_table_group);
+
+      auto p_face_type = std::make_shared<const ItemType>(ItemType::face);
+      checkpointing::writeItemType("face_type",
+                                   EmbeddedData{std::make_shared<DataHandler<const ItemType>>(p_face_type)}, file,
+                                   useless_group, symbol_table_group);
+
+      auto p_edge_type = std::make_shared<const ItemType>(ItemType::edge);
+      checkpointing::writeItemType("edge_type",
+                                   EmbeddedData{std::make_shared<DataHandler<const ItemType>>(p_edge_type)}, file,
+                                   useless_group, symbol_table_group);
+
+      auto p_node_type = std::make_shared<const ItemType>(ItemType::node);
+      checkpointing::writeItemType("node_type",
+                                   EmbeddedData{std::make_shared<DataHandler<const ItemType>>(p_node_type)}, file,
+                                   useless_group, symbol_table_group);
+
+      file.flush();
+
+      EmbeddedData read_cell_type = checkpointing::readItemType("cell_type", symbol_table_group);
+
+      EmbeddedData read_face_type = checkpointing::readItemType("face_type", symbol_table_group);
+
+      EmbeddedData read_edge_type = checkpointing::readItemType("edge_type", symbol_table_group);
+
+      EmbeddedData read_node_type = checkpointing::readItemType("node_type", symbol_table_group);
+
+      auto get_value = [](const EmbeddedData& embedded_data) -> const ItemType& {
+        return *dynamic_cast<const DataHandler<const ItemType>&>(embedded_data.get()).data_ptr();
+      };
+
+      REQUIRE_NOTHROW(get_value(read_cell_type));
+      REQUIRE_NOTHROW(get_value(read_face_type));
+      REQUIRE_NOTHROW(get_value(read_edge_type));
+      REQUIRE_NOTHROW(get_value(read_node_type));
+
+      REQUIRE(get_value(read_cell_type) == ItemType::cell);
+      REQUIRE(get_value(read_face_type) == ItemType::face);
+      REQUIRE(get_value(read_edge_type) == ItemType::edge);
+      REQUIRE(get_value(read_node_type) == ItemType::node);
+    }
+  }
+
+  parallel::barrier();
+  if (parallel::rank() == 0) {
+    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
+  }
+}
-- 
GitLab