From f30eb8e2638e1248382de163ab6410fd9ebf2d3d Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Tue, 16 Jul 2024 08:49:27 +0200
Subject: [PATCH] Add tests for IDiscreteFunctionDescriptor checkpointing

---
 tests/CMakeLists.txt                          |  1 +
 ...ckpointing_IDiscreteFunctionDescriptor.cpp | 84 +++++++++++++++++++
 2 files changed, 85 insertions(+)
 create mode 100644 tests/test_checkpointing_IDiscreteFunctionDescriptor.cpp

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index cb564c1d6..ac5a6d7bb 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -163,6 +163,7 @@ if(PUGS_HAS_HDF5)
     test_checkpointing_HFTypes.cpp
     test_checkpointing_IBoundaryDescriptor.cpp
     test_checkpointing_IBoundaryConditionDescriptor.cpp
+    test_checkpointing_IDiscreteFunctionDescriptor.cpp
     test_checkpointing_IInterfaceDescriptor.cpp
     test_checkpointing_ItemType.cpp
     test_checkpointing_IZoneDescriptor.cpp
diff --git a/tests/test_checkpointing_IDiscreteFunctionDescriptor.cpp b/tests/test_checkpointing_IDiscreteFunctionDescriptor.cpp
new file mode 100644
index 000000000..da98859ed
--- /dev/null
+++ b/tests/test_checkpointing_IDiscreteFunctionDescriptor.cpp
@@ -0,0 +1,84 @@
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_all.hpp>
+
+#include <utils/Messenger.hpp>
+
+#include <language/utils/DataHandler.hpp>
+#include <language/utils/EmbeddedData.hpp>
+#include <scheme/DiscreteFunctionDescriptorP0.hpp>
+#include <scheme/DiscreteFunctionDescriptorP0Vector.hpp>
+#include <scheme/IDiscreteFunctionDescriptor.hpp>
+#include <utils/checkpointing/ReadIDiscreteFunctionDescriptor.hpp>
+#include <utils/checkpointing/WriteIDiscreteFunctionDescriptor.hpp>
+
+#include <filesystem>
+
+// clazy:excludeall=non-pod-global-static
+
+TEST_CASE("checkpointing_IDiscreteFunctionDescriptor", "[utils/checkpointing]")
+{
+  std::string tmp_dirname;
+  {
+    {
+      if (parallel::rank() == 0) {
+        tmp_dirname = [&]() -> std::string {
+          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
+          return std::string{mkdtemp(&temp_filename[0])};
+        }();
+      }
+      parallel::broadcast(tmp_dirname, 0);
+    }
+    std::filesystem::path path = tmp_dirname;
+    const std::string filename = path / "checkpoint.h5";
+
+    HighFive::FileAccessProps fapl;
+    fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+    fapl.add(HighFive::MPIOCollectiveMetadata{});
+    HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl);
+
+    SECTION("IDiscreteFunctionDescriptor")
+    {
+      HighFive::Group symbol_table_group = file.createGroup("symbol_table");
+      HighFive::Group useless_group;
+
+      auto p_discrete_function_p0 = std::make_shared<DiscreteFunctionDescriptorP0>();
+      checkpointing::writeIDiscreteFunctionDescriptor("P0",
+                                                      EmbeddedData{std::make_shared<
+                                                        DataHandler<const IDiscreteFunctionDescriptor>>(
+                                                        p_discrete_function_p0)},
+                                                      file, useless_group, symbol_table_group);
+
+      auto p_discrete_function_p0_vector = std::make_shared<DiscreteFunctionDescriptorP0Vector>();
+      checkpointing::writeIDiscreteFunctionDescriptor("P0Vector",
+                                                      EmbeddedData{std::make_shared<
+                                                        DataHandler<const IDiscreteFunctionDescriptor>>(
+                                                        p_discrete_function_p0_vector)},
+                                                      file, useless_group, symbol_table_group);
+
+      file.flush();
+
+      EmbeddedData read_df_descriptor_p0 = checkpointing::readIDiscreteFunctionDescriptor("P0", symbol_table_group);
+
+      EmbeddedData read_df_descriptor_p0_vector =
+        checkpointing::readIDiscreteFunctionDescriptor("P0Vector", symbol_table_group);
+
+      auto get_value = [](const EmbeddedData& embedded_data) -> const IDiscreteFunctionDescriptor& {
+        return *dynamic_cast<const DataHandler<const IDiscreteFunctionDescriptor>&>(embedded_data.get()).data_ptr();
+      };
+
+      REQUIRE_NOTHROW(get_value(read_df_descriptor_p0));
+      REQUIRE_NOTHROW(get_value(read_df_descriptor_p0_vector));
+
+      REQUIRE_NOTHROW(dynamic_cast<const DiscreteFunctionDescriptorP0&>(get_value(read_df_descriptor_p0)));
+      REQUIRE_NOTHROW(dynamic_cast<const DiscreteFunctionDescriptorP0Vector&>(get_value(read_df_descriptor_p0_vector)));
+
+      REQUIRE(get_value(read_df_descriptor_p0).type() == DiscreteFunctionType::P0);
+      REQUIRE(get_value(read_df_descriptor_p0_vector).type() == DiscreteFunctionType::P0Vector);
+    }
+  }
+
+  parallel::barrier();
+  if (parallel::rank() == 0) {
+    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
+  }
+}
-- 
GitLab