From 6359571d866ab701cb9e4df22553be720f317d94 Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Sat, 7 Sep 2024 23:17:13 +0200
Subject: [PATCH] Add tests for DiscreteFunctionVariant checkpointing

---
 .../ReadDiscreteFunctionVariant.cpp           |   4 +
 tests/CMakeLists.txt                          |   1 +
 ..._checkpointing_DiscreteFunctionVariant.cpp | 330 ++++++++++++++++++
 3 files changed, 335 insertions(+)
 create mode 100644 tests/test_checkpointing_DiscreteFunctionVariant.cpp

diff --git a/src/utils/checkpointing/ReadDiscreteFunctionVariant.cpp b/src/utils/checkpointing/ReadDiscreteFunctionVariant.cpp
index 0fdb68c7d..15efc050a 100644
--- a/src/utils/checkpointing/ReadDiscreteFunctionVariant.cpp
+++ b/src/utils/checkpointing/ReadDiscreteFunctionVariant.cpp
@@ -66,7 +66,9 @@ readDiscreteFunctionVariant(const HighFive::Group& discrete_function_group)
                                                                                              "values",
                                                                                              mesh_v->connectivity())));
     } else {
+      // LCOV_EXCL_START
       throw UnexpectedError("unexpected discrete function data type: " + data_type);
+      // LCOV_EXCL_STOP
     }
     break;
   }
@@ -77,7 +79,9 @@ readDiscreteFunctionVariant(const HighFive::Group& discrete_function_group)
                                                readItemArray<double, ItemType::cell>(discrete_function_group, "values",
                                                                                      mesh_v->connectivity())));
     } else {
+      // LCOV_EXCL_START
       throw UnexpectedError("unexpected discrete function vector data type: " + data_type);
+      // LCOV_EXCL_STOP
     }
     break;
   }
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index aea0b508b..d83622ee1 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -162,6 +162,7 @@ if(PUGS_HAS_HDF5)
   list(APPEND checkpointing_TESTS
     test_checkpointing_Array.cpp
     test_checkpointing_Connectivity.cpp
+    test_checkpointing_DiscreteFunctionVariant.cpp
     test_checkpointing_HFTypes.cpp
     test_checkpointing_ItemArray.cpp
     test_checkpointing_ItemArrayVariant.cpp
diff --git a/tests/test_checkpointing_DiscreteFunctionVariant.cpp b/tests/test_checkpointing_DiscreteFunctionVariant.cpp
new file mode 100644
index 000000000..6dbcbafad
--- /dev/null
+++ b/tests/test_checkpointing_DiscreteFunctionVariant.cpp
@@ -0,0 +1,330 @@
+#include <catch2/catch_test_macros.hpp>
+#include <catch2/matchers/catch_matchers_all.hpp>
+
+#include <utils/Messenger.hpp>
+
+#include <language/utils/DataHandler.hpp>
+#include <language/utils/EmbeddedData.hpp>
+#include <mesh/Mesh.hpp>
+#include <scheme/DiscreteFunctionVariant.hpp>
+#include <utils/GlobalVariableManager.hpp>
+#include <utils/checkpointing/ReadDiscreteFunctionVariant.hpp>
+#include <utils/checkpointing/ResumingData.hpp>
+#include <utils/checkpointing/WriteDiscreteFunctionVariant.hpp>
+
+#include <MeshDataBaseForTests.hpp>
+#include <checkpointing_Mesh_utilities.hpp>
+
+#include <filesystem>
+
+// clazy:excludeall=non-pod-global-static
+
+namespace test_only
+{
+
+template <typename DataType>
+PUGS_INLINE void
+DiscreteFunctionVariant_check_is_same(const DiscreteFunctionP0<DataType>& reference, const EmbeddedData& e_read_data)
+{
+  auto same_value = [](const auto& a, const auto& b) -> bool {
+    bool same = true;
+    for (size_t i = 0; i < a.size(); ++i) {
+      same &= (a[i] == b[i]);
+    }
+    return parallel::allReduceAnd(same);
+  };
+
+  REQUIRE_NOTHROW(dynamic_cast<const DataHandler<const DiscreteFunctionVariant>&>(e_read_data.get()));
+
+  std::shared_ptr<const DiscreteFunctionVariant> p_new_data_v =
+    dynamic_cast<const DataHandler<const DiscreteFunctionVariant>&>(e_read_data.get()).data_ptr();
+
+  using DiscreteFunctionT = DiscreteFunctionP0<const DataType>;
+
+  DiscreteFunctionT read_data = p_new_data_v->get<DiscreteFunctionT>();
+
+  REQUIRE(test_only::isSameMesh(read_data.meshVariant(), reference.meshVariant()));
+
+  REQUIRE(same_value(reference.cellValues().arrayView(), read_data.cellValues().arrayView()));
+}
+
+template <typename DataType>
+PUGS_INLINE void
+DiscreteFunctionVariant_check_is_same(const DiscreteFunctionP0Vector<DataType>& reference,
+                                      const EmbeddedData& e_read_data)
+{
+  auto same_value = [](const auto& a, const auto& b) -> bool {
+    bool same = true;
+    for (size_t i = 0; i < a.numberOfRows(); ++i) {
+      for (size_t j = 0; j < a.numberOfColumns(); ++j) {
+        same &= (a(i, j) == b(i, j));
+      }
+    }
+    return parallel::allReduceAnd(same);
+  };
+
+  REQUIRE_NOTHROW(dynamic_cast<const DataHandler<const DiscreteFunctionVariant>&>(e_read_data.get()));
+
+  std::shared_ptr<const DiscreteFunctionVariant> p_new_data_v =
+    dynamic_cast<const DataHandler<const DiscreteFunctionVariant>&>(e_read_data.get()).data_ptr();
+
+  using DiscreteFunctionT = DiscreteFunctionP0Vector<const DataType>;
+
+  DiscreteFunctionT read_data = p_new_data_v->get<DiscreteFunctionT>();
+
+  REQUIRE(test_only::isSameMesh(read_data.meshVariant(), reference.meshVariant()));
+
+  REQUIRE(same_value(reference.cellArrays().tableView(), read_data.cellArrays().tableView()));
+}
+
+}   // namespace test_only
+
+TEST_CASE("checkpointing_DiscreteFunctionVariant", "[utils/checkpointing]")
+{
+  std::string tmp_dirname;
+  {
+    {
+      if (parallel::rank() == 0) {
+        tmp_dirname = [&]() -> std::string {
+          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
+          return std::string{mkdtemp(&temp_filename[0])};
+        }();
+      }
+      parallel::broadcast(tmp_dirname, 0);
+    }
+    std::filesystem::path path = tmp_dirname;
+    const std::string filename = path / "checkpoint.h5";
+
+    HighFive::FileAccessProps fapl;
+    fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
+    fapl.add(HighFive::MPIOCollectiveMetadata{});
+    HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl);
+
+    const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();
+    const size_t initial_mesh_id         = GlobalVariableManager::instance().getMeshId();
+
+    SECTION("Mesh")
+    {
+      using R1   = TinyVector<1>;
+      using R2   = TinyVector<2>;
+      using R3   = TinyVector<3>;
+      using R1x1 = TinyMatrix<1>;
+      using R2x2 = TinyMatrix<2>;
+      using R3x3 = TinyMatrix<3>;
+
+      HighFive::Group checkpoint_group   = file.createGroup("checkpoint");
+      HighFive::Group symbol_table_group = checkpoint_group.createGroup("symbol_table");
+
+      auto mesh_1d = MeshDataBaseForTests::get().unordered1DMesh()->get<Mesh<1>>();
+
+      DiscreteFunctionP0<double> df_R_1d{mesh_1d};
+      for (CellId cell_id = 0; cell_id < mesh_1d->numberOfCells(); ++cell_id) {
+        df_R_1d[cell_id] = std::rand() / (1. * RAND_MAX / mesh_1d->numberOfCells());
+      }
+
+      DiscreteFunctionP0<R1> df_R1_1d{mesh_1d};
+      for (CellId cell_id = 0; cell_id < mesh_1d->numberOfCells(); ++cell_id) {
+        df_R1_1d[cell_id] = R1{std::rand() / (1. * RAND_MAX / mesh_1d->numberOfCells())};
+      }
+
+      DiscreteFunctionP0<R2> df_R2_1d{mesh_1d};
+      for (CellId cell_id = 0; cell_id < mesh_1d->numberOfCells(); ++cell_id) {
+        df_R2_1d[cell_id] = R2{std::rand() / (1. * RAND_MAX / mesh_1d->numberOfCells()),
+                               std::rand() / (1. * RAND_MAX / mesh_1d->numberOfCells())};
+      }
+
+      auto mesh_2d = MeshDataBaseForTests::get().hybrid2DMesh()->get<Mesh<2>>();
+
+      DiscreteFunctionP0<R3> df_R3_2d{mesh_2d};
+      for (CellId cell_id = 0; cell_id < mesh_2d->numberOfCells(); ++cell_id) {
+        df_R3_2d[cell_id] = R3{std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells()),
+                               std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells()),
+                               std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells())};
+      }
+
+      DiscreteFunctionP0<R2x2> df_R2x2_2d{mesh_2d};
+      for (CellId cell_id = 0; cell_id < mesh_2d->numberOfCells(); ++cell_id) {
+        df_R2x2_2d[cell_id] = R2x2{std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells())};
+      }
+
+      DiscreteFunctionP0Vector<double> dfv_R_2d{mesh_2d, 3};
+      for (CellId cell_id = 0; cell_id < mesh_2d->numberOfCells(); ++cell_id) {
+        for (size_t i = 0; i < dfv_R_2d.size(); ++i) {
+          dfv_R_2d[cell_id][i] = std::rand() / (1. * RAND_MAX / mesh_2d->numberOfCells());
+        }
+      }
+
+      auto mesh_3d = MeshDataBaseForTests::get().hybrid3DMesh()->get<Mesh<3>>();
+
+      DiscreteFunctionP0<R3> df_R3_3d{mesh_3d};
+      for (CellId cell_id = 0; cell_id < mesh_3d->numberOfCells(); ++cell_id) {
+        df_R3_3d[cell_id] = R3{std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                               std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                               std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells())};
+      }
+
+      DiscreteFunctionP0<R1x1> df_R1x1_3d{mesh_3d};
+      for (CellId cell_id = 0; cell_id < mesh_3d->numberOfCells(); ++cell_id) {
+        df_R1x1_3d[cell_id] = R1x1{std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells())};
+      }
+
+      DiscreteFunctionP0<R3x3> df_R3x3_3d{mesh_3d};
+      for (CellId cell_id = 0; cell_id < mesh_3d->numberOfCells(); ++cell_id) {
+        df_R3x3_3d[cell_id] = R3x3{std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells()),
+                                   std::rand() / (1. * RAND_MAX / mesh_3d->numberOfCells())};
+      }
+
+      {   // Write
+        using DataHandlerT = DataHandler<const DiscreteFunctionVariant>;
+
+        auto new_mesh_1d_v              = test_only::duplicateMesh(std::make_shared<MeshVariant>(mesh_1d));
+        auto new_mesh_1d                = new_mesh_1d_v->get<const Mesh<1>>();
+        const auto& new_connectivity_1d = new_mesh_1d->connectivity();
+
+        DiscreteFunctionP0<const double> df_R_1d_new{new_mesh_1d_v,
+                                                     CellValue<const double>{new_connectivity_1d,
+                                                                             df_R_1d.cellValues().arrayView()}};
+        DiscreteFunctionP0<const R1> df_R1_1d_new{new_mesh_1d_v,
+                                                  CellValue<const R1>{new_connectivity_1d,
+                                                                      df_R1_1d.cellValues().arrayView()}};
+        DiscreteFunctionP0<const R2> df_R2_1d_new{new_mesh_1d_v,
+                                                  CellValue<const R2>{new_connectivity_1d,
+                                                                      df_R2_1d.cellValues().arrayView()}};
+
+        checkpointing::writeDiscreteFunctionVariant("df_R_1d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R_1d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeDiscreteFunctionVariant("df_R1_1d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R1_1d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeDiscreteFunctionVariant("df_R2_1d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R2_1d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        auto new_mesh_2d_v              = test_only::duplicateMesh(std::make_shared<MeshVariant>(mesh_2d));
+        auto new_mesh_2d                = new_mesh_2d_v->get<const Mesh<2>>();
+        const auto& new_connectivity_2d = new_mesh_2d->connectivity();
+
+        DiscreteFunctionP0<const R3> df_R3_2d_new{new_mesh_2d_v,
+                                                  CellValue<const R3>{new_connectivity_2d,
+                                                                      df_R3_2d.cellValues().arrayView()}};
+        DiscreteFunctionP0<const R2x2> df_R2x2_2d_new{new_mesh_2d_v,
+                                                      CellValue<const R2x2>{new_connectivity_2d,
+                                                                            df_R2x2_2d.cellValues().arrayView()}};
+
+        DiscreteFunctionP0Vector<const double> dfv_R_2d_new{new_mesh_2d_v,
+                                                            CellArray<const double>{new_connectivity_2d,
+                                                                                    dfv_R_2d.cellArrays().tableView()}};
+
+        checkpointing::writeDiscreteFunctionVariant("df_R3_2d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R3_2d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeDiscreteFunctionVariant("df_R2x2_2d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R2x2_2d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeDiscreteFunctionVariant("dfv_R_2d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(dfv_R_2d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        auto new_mesh_3d_v              = test_only::duplicateMesh(std::make_shared<MeshVariant>(mesh_3d));
+        auto new_mesh_3d                = new_mesh_3d_v->get<const Mesh<3>>();
+        const auto& new_connectivity_3d = new_mesh_3d->connectivity();
+
+        DiscreteFunctionP0<const R3> df_R3_3d_new{new_mesh_3d, CellValue<const R3>{new_connectivity_3d,
+                                                                                   df_R3_3d.cellValues().arrayView()}};
+        DiscreteFunctionP0<const R1x1> df_R1x1_3d_new{new_mesh_3d,
+                                                      CellValue<const R1x1>{new_connectivity_3d,
+                                                                            df_R1x1_3d.cellValues().arrayView()}};
+        DiscreteFunctionP0<const R3x3> df_R3x3_3d_new{new_mesh_3d,
+                                                      CellValue<const R3x3>{new_connectivity_3d,
+                                                                            df_R3x3_3d.cellValues().arrayView()}};
+
+        checkpointing::writeDiscreteFunctionVariant("df_R3_3d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R3_3d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeDiscreteFunctionVariant("df_R1x1_3d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R1x1_3d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        checkpointing::writeDiscreteFunctionVariant("df_R3x3_3d",
+                                                    EmbeddedData{std::make_shared<DataHandlerT>(
+                                                      std::make_shared<const DiscreteFunctionVariant>(df_R3x3_3d_new))},
+                                                    file, checkpoint_group, symbol_table_group);
+
+        HighFive::Group global_variables_group = checkpoint_group.createGroup("singleton/global_variables");
+        global_variables_group.createAttribute("connectivity_id",
+                                               GlobalVariableManager::instance().getConnectivityId());
+        global_variables_group.createAttribute("mesh_id", GlobalVariableManager::instance().getMeshId());
+      }
+
+      // reset to reuse after resuming
+      GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);
+      GlobalVariableManager::instance().setMeshId(initial_mesh_id);
+
+      file.flush();
+
+      checkpointing::ResumingData::create();
+      checkpointing::ResumingData::instance().readData(checkpoint_group, nullptr);
+
+      GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);
+      GlobalVariableManager::instance().setMeshId(initial_mesh_id);
+      {   // Read
+        auto e_df_R_1d = checkpointing::readDiscreteFunctionVariant("df_R_1d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R_1d, e_df_R_1d);
+
+        auto e_df_R1_1d = checkpointing::readDiscreteFunctionVariant("df_R1_1d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R1_1d, e_df_R1_1d);
+
+        auto e_df_R2_1d = checkpointing::readDiscreteFunctionVariant("df_R2_1d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R2_1d, e_df_R2_1d);
+
+        auto e_df_R3_2d = checkpointing::readDiscreteFunctionVariant("df_R3_2d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R3_2d, e_df_R3_2d);
+
+        auto e_df_R2x2_2d = checkpointing::readDiscreteFunctionVariant("df_R2x2_2d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R2x2_2d, e_df_R2x2_2d);
+
+        auto e_dfv_R_2d = checkpointing::readDiscreteFunctionVariant("dfv_R_2d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(dfv_R_2d, e_dfv_R_2d);
+
+        auto e_df_R3_3d = checkpointing::readDiscreteFunctionVariant("df_R3_3d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R3_3d, e_df_R3_3d);
+
+        auto e_df_R1x1_3d = checkpointing::readDiscreteFunctionVariant("df_R1x1_3d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R1x1_3d, e_df_R1x1_3d);
+
+        auto e_df_R3x3_3d = checkpointing::readDiscreteFunctionVariant("df_R3x3_3d", symbol_table_group);
+        test_only::DiscreteFunctionVariant_check_is_same(df_R3x3_3d, e_df_R3x3_3d);
+      }
+      checkpointing::ResumingData::destroy();
+    }
+  }
+
+  parallel::barrier();
+  if (parallel::rank() == 0) {
+    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
+  }
+}
-- 
GitLab