From 02dc3f42f20e8f5cb9b0fc262eb0f9326b3704fc Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Tue, 16 Jul 2024 23:34:12 +0200
Subject: [PATCH] Add tests for IWriter checkpointing

---
 src/output/WriterBase.hpp            |  21 +++++
 tests/CMakeLists.txt                 |   1 +
 tests/test_checkpointing_IWriter.cpp | 124 +++++++++++++++++++++++++++
 3 files changed, 146 insertions(+)
 create mode 100644 tests/test_checkpointing_IWriter.cpp

diff --git a/src/output/WriterBase.hpp b/src/output/WriterBase.hpp
index dcd61da2d..6de7ab558 100644
--- a/src/output/WriterBase.hpp
+++ b/src/output/WriterBase.hpp
@@ -151,6 +151,27 @@ class WriterBase : public IWriter
   virtual void _writeMesh(const MeshVariant& mesh_v) const = 0;
 
  public:
+  PUGS_INLINE
+  const std::string&
+  baseFilename() const
+  {
+    return m_base_filename;
+  }
+
+  PUGS_INLINE
+  const std::optional<PeriodManager>&
+  periodManager() const
+  {
+    return m_period_manager;
+  }
+
+  PUGS_INLINE
+  const std::optional<std::string>&
+  signature() const
+  {
+    return m_signature;
+  }
+
   void write(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const final;
 
   void writeIfNeeded(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list,
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 409d2a5b6..258fcbde8 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -167,6 +167,7 @@ if(PUGS_HAS_HDF5)
     test_checkpointing_IDiscreteFunctionDescriptor.cpp
     test_checkpointing_IInterfaceDescriptor.cpp
     test_checkpointing_ItemType.cpp
+    test_checkpointing_IWriter.cpp
     test_checkpointing_IZoneDescriptor.cpp
   )
 endif(PUGS_HAS_HDF5)
diff --git a/tests/test_checkpointing_IWriter.cpp b/tests/test_checkpointing_IWriter.cpp
new file mode 100644
index 000000000..33d00bd90
--- /dev/null
+++ b/tests/test_checkpointing_IWriter.cpp
@@ -0,0 +1,124 @@
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_all.hpp>
+
+#include <utils/Messenger.hpp>
+
+#include <language/utils/DataHandler.hpp>
+#include <language/utils/EmbeddedData.hpp>
+#include <output/GnuplotWriter.hpp>
+#include <output/GnuplotWriter1D.hpp>
+#include <output/NamedDiscreteFunction.hpp>
+#include <output/VTKWriter.hpp>
+#include <scheme/DiscreteFunctionP0.hpp>
+#include <scheme/DiscreteFunctionVariant.hpp>
+#include <utils/checkpointing/ReadIWriter.hpp>
+#include <utils/checkpointing/WriteIWriter.hpp>
+
+#include <MeshDataBaseForTests.hpp>
+
+#include <filesystem>
+
+// clazy:excludeall=non-pod-global-static
+
+TEST_CASE("checkpointing_IWriter", "[utils/checkpointing]")
+{
+  std::string tmp_dirname;
+  {
+    {
+      if (parallel::rank() == 0) {
+        tmp_dirname = [&]() -> std::string {
+          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
+          return std::string{mkdtemp(&temp_filename[0])};
+        }();
+      }
+      parallel::broadcast(tmp_dirname, 0);
+    }
+    std::filesystem::path path = tmp_dirname;
+    const std::string filename = path / "checkpoint.h5";
+
+    HighFive::FileAccessProps fapl;
+    fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+    fapl.add(HighFive::MPIOCollectiveMetadata{});
+    HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl);
+
+    SECTION("IWriter")
+    {
+      HighFive::Group symbol_table_group = file.createGroup("symbol_table");
+      HighFive::Group useless_group;
+
+      auto p_gnuplot_writer = std::make_shared<const GnuplotWriter>("gnuplot_basename");
+      checkpointing::writeIWriter("gnuplot",
+                                  EmbeddedData{std::make_shared<DataHandler<const IWriter>>(p_gnuplot_writer)}, file,
+                                  useless_group, symbol_table_group);
+
+      auto p_gnuplot_writer_1d = std::make_shared<const GnuplotWriter1D>("gnuplot_1d_basename", 1.12);
+      p_gnuplot_writer_1d->periodManager().value().setSaveTime(2);
+      checkpointing::writeIWriter("gnuplot_1d",
+                                  EmbeddedData{std::make_shared<DataHandler<const IWriter>>(p_gnuplot_writer_1d)}, file,
+                                  useless_group, symbol_table_group);
+
+      const std::string vtk_filename = path / "vtk_example";
+
+      auto mesh_v = MeshDataBaseForTests::get().cartesian1DMesh();
+
+      DiscreteFunctionP0<double> fh{mesh_v};
+      fh.fill(0);
+      std::shared_ptr<const DiscreteFunctionVariant> discrete_function = std::make_shared<DiscreteFunctionVariant>(fh);
+
+      std::shared_ptr<const INamedDiscreteData> value =
+        std::make_shared<const NamedDiscreteFunction>(discrete_function, "fh");
+
+      auto p_vtk_writer = std::make_shared<const VTKWriter>(vtk_filename, 0.02);
+
+      p_vtk_writer->writeIfNeeded({value}, 1);
+
+      checkpointing::writeIWriter("vtk", EmbeddedData{std::make_shared<DataHandler<const IWriter>>(p_vtk_writer)}, file,
+                                  useless_group, symbol_table_group);
+
+      file.flush();
+
+      EmbeddedData read_gnuplot_writer    = checkpointing::readIWriter("gnuplot", symbol_table_group);
+      EmbeddedData read_gnuplot_writer_1d = checkpointing::readIWriter("gnuplot_1d", symbol_table_group);
+      EmbeddedData read_vtk_writer        = checkpointing::readIWriter("vtk", symbol_table_group);
+
+      auto get_value = [](const EmbeddedData& embedded_data) -> const IWriter& {
+        return *dynamic_cast<const DataHandler<const IWriter>&>(embedded_data.get()).data_ptr();
+      };
+
+      REQUIRE_NOTHROW(get_value(read_gnuplot_writer));
+      REQUIRE_NOTHROW(dynamic_cast<const GnuplotWriter&>(get_value(read_gnuplot_writer)));
+      const GnuplotWriter& read_gp_writer = dynamic_cast<const GnuplotWriter&>(get_value(read_gnuplot_writer));
+      REQUIRE(read_gp_writer.type() == IWriter::Type::gnuplot);
+      REQUIRE(read_gp_writer.baseFilename() == "gnuplot_basename");
+      REQUIRE(not read_gp_writer.periodManager().has_value());
+      REQUIRE(not read_gp_writer.signature().has_value());
+
+      REQUIRE_NOTHROW(get_value(read_gnuplot_writer_1d));
+      REQUIRE_NOTHROW(dynamic_cast<const GnuplotWriter1D&>(get_value(read_gnuplot_writer_1d)));
+      const GnuplotWriter1D& read_gp_writer_1d =
+        dynamic_cast<const GnuplotWriter1D&>(get_value(read_gnuplot_writer_1d));
+      REQUIRE(read_gp_writer_1d.type() == IWriter::Type::gnuplot_1d);
+      REQUIRE(read_gp_writer_1d.baseFilename() == "gnuplot_1d_basename");
+      REQUIRE(read_gp_writer_1d.periodManager().has_value());
+      REQUIRE(read_gp_writer_1d.periodManager().value().timePeriod() == 1.12);
+      REQUIRE(read_gp_writer_1d.periodManager().value().nextTime() == 3.12);
+      REQUIRE(not read_gp_writer_1d.signature().has_value());
+
+      REQUIRE_NOTHROW(get_value(read_vtk_writer));
+      REQUIRE_NOTHROW(dynamic_cast<const VTKWriter&>(get_value(read_vtk_writer)));
+      const VTKWriter& read_vtk = dynamic_cast<const VTKWriter&>(get_value(read_vtk_writer));
+      REQUIRE(read_vtk.type() == IWriter::Type::vtk);
+      REQUIRE(read_vtk.baseFilename() == vtk_filename);
+      REQUIRE(read_vtk.periodManager().has_value());
+      REQUIRE(read_vtk.periodManager().value().timePeriod() == 0.02);
+      REQUIRE(read_vtk.periodManager().value().nextTime() == 1.02);
+      REQUIRE(read_vtk.signature().has_value());
+      REQUIRE(read_vtk.signature().value() == p_vtk_writer->signature().value());
+    }
+  }
+
+  parallel::barrier();
+  if (parallel::rank() == 0) {
+    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
+  }
+}
-- 
GitLab