From 02dc3f42f20e8f5cb9b0fc262eb0f9326b3704fc Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Tue, 16 Jul 2024 23:34:12 +0200 Subject: [PATCH] Add tests for IWriter checkpointing --- src/output/WriterBase.hpp | 21 +++++ tests/CMakeLists.txt | 1 + tests/test_checkpointing_IWriter.cpp | 124 +++++++++++++++++++++++++++ 3 files changed, 146 insertions(+) create mode 100644 tests/test_checkpointing_IWriter.cpp diff --git a/src/output/WriterBase.hpp b/src/output/WriterBase.hpp index dcd61da2d..6de7ab558 100644 --- a/src/output/WriterBase.hpp +++ b/src/output/WriterBase.hpp @@ -151,6 +151,27 @@ class WriterBase : public IWriter virtual void _writeMesh(const MeshVariant& mesh_v) const = 0; public: + PUGS_INLINE + const std::string& + baseFilename() const + { + return m_base_filename; + } + + PUGS_INLINE + const std::optional<PeriodManager>& + periodManager() const + { + return m_period_manager; + } + + PUGS_INLINE + const std::optional<std::string>& + signature() const + { + return m_signature; + } + void write(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const final; void writeIfNeeded(const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 409d2a5b6..258fcbde8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -167,6 +167,7 @@ if(PUGS_HAS_HDF5) test_checkpointing_IDiscreteFunctionDescriptor.cpp test_checkpointing_IInterfaceDescriptor.cpp test_checkpointing_ItemType.cpp + test_checkpointing_IWriter.cpp test_checkpointing_IZoneDescriptor.cpp ) endif(PUGS_HAS_HDF5) diff --git a/tests/test_checkpointing_IWriter.cpp b/tests/test_checkpointing_IWriter.cpp new file mode 100644 index 000000000..33d00bd90 --- /dev/null +++ b/tests/test_checkpointing_IWriter.cpp @@ -0,0 +1,124 @@ +#include <catch2/catch_test_macros.hpp> +#include <catch2/matchers/catch_matchers_all.hpp> + +#include <utils/Messenger.hpp> + +#include <language/utils/DataHandler.hpp> +#include <language/utils/EmbeddedData.hpp> +#include <output/GnuplotWriter.hpp> +#include <output/GnuplotWriter1D.hpp> +#include <output/NamedDiscreteFunction.hpp> +#include <output/VTKWriter.hpp> +#include <scheme/DiscreteFunctionP0.hpp> +#include <scheme/DiscreteFunctionVariant.hpp> +#include <utils/checkpointing/ReadIWriter.hpp> +#include <utils/checkpointing/WriteIWriter.hpp> + +#include <MeshDataBaseForTests.hpp> + +#include <filesystem> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("checkpointing_IWriter", "[utils/checkpointing]") +{ + std::string tmp_dirname; + { + { + if (parallel::rank() == 0) { + tmp_dirname = [&]() -> std::string { + std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX"; + return std::string{mkdtemp(&temp_filename[0])}; + }(); + } + parallel::broadcast(tmp_dirname, 0); + } + std::filesystem::path path = tmp_dirname; + const std::string filename = path / "checkpoint.h5"; + + HighFive::FileAccessProps fapl; + fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL}); + fapl.add(HighFive::MPIOCollectiveMetadata{}); + HighFive::File file = HighFive::File(filename, HighFive::File::Truncate, fapl); + + SECTION("IWriter") + { + HighFive::Group symbol_table_group = file.createGroup("symbol_table"); + HighFive::Group useless_group; + + auto p_gnuplot_writer = std::make_shared<const GnuplotWriter>("gnuplot_basename"); + checkpointing::writeIWriter("gnuplot", + EmbeddedData{std::make_shared<DataHandler<const IWriter>>(p_gnuplot_writer)}, file, + useless_group, symbol_table_group); + + auto p_gnuplot_writer_1d = std::make_shared<const GnuplotWriter1D>("gnuplot_1d_basename", 1.12); + p_gnuplot_writer_1d->periodManager().value().setSaveTime(2); + checkpointing::writeIWriter("gnuplot_1d", + EmbeddedData{std::make_shared<DataHandler<const IWriter>>(p_gnuplot_writer_1d)}, file, + useless_group, symbol_table_group); + + const std::string vtk_filename = path / "vtk_example"; + + auto mesh_v = MeshDataBaseForTests::get().cartesian1DMesh(); + + DiscreteFunctionP0<double> fh{mesh_v}; + fh.fill(0); + std::shared_ptr<const DiscreteFunctionVariant> discrete_function = std::make_shared<DiscreteFunctionVariant>(fh); + + std::shared_ptr<const INamedDiscreteData> value = + std::make_shared<const NamedDiscreteFunction>(discrete_function, "fh"); + + auto p_vtk_writer = std::make_shared<const VTKWriter>(vtk_filename, 0.02); + + p_vtk_writer->writeIfNeeded({value}, 1); + + checkpointing::writeIWriter("vtk", EmbeddedData{std::make_shared<DataHandler<const IWriter>>(p_vtk_writer)}, file, + useless_group, symbol_table_group); + + file.flush(); + + EmbeddedData read_gnuplot_writer = checkpointing::readIWriter("gnuplot", symbol_table_group); + EmbeddedData read_gnuplot_writer_1d = checkpointing::readIWriter("gnuplot_1d", symbol_table_group); + EmbeddedData read_vtk_writer = checkpointing::readIWriter("vtk", symbol_table_group); + + auto get_value = [](const EmbeddedData& embedded_data) -> const IWriter& { + return *dynamic_cast<const DataHandler<const IWriter>&>(embedded_data.get()).data_ptr(); + }; + + REQUIRE_NOTHROW(get_value(read_gnuplot_writer)); + REQUIRE_NOTHROW(dynamic_cast<const GnuplotWriter&>(get_value(read_gnuplot_writer))); + const GnuplotWriter& read_gp_writer = dynamic_cast<const GnuplotWriter&>(get_value(read_gnuplot_writer)); + REQUIRE(read_gp_writer.type() == IWriter::Type::gnuplot); + REQUIRE(read_gp_writer.baseFilename() == "gnuplot_basename"); + REQUIRE(not read_gp_writer.periodManager().has_value()); + REQUIRE(not read_gp_writer.signature().has_value()); + + REQUIRE_NOTHROW(get_value(read_gnuplot_writer_1d)); + REQUIRE_NOTHROW(dynamic_cast<const GnuplotWriter1D&>(get_value(read_gnuplot_writer_1d))); + const GnuplotWriter1D& read_gp_writer_1d = + dynamic_cast<const GnuplotWriter1D&>(get_value(read_gnuplot_writer_1d)); + REQUIRE(read_gp_writer_1d.type() == IWriter::Type::gnuplot_1d); + REQUIRE(read_gp_writer_1d.baseFilename() == "gnuplot_1d_basename"); + REQUIRE(read_gp_writer_1d.periodManager().has_value()); + REQUIRE(read_gp_writer_1d.periodManager().value().timePeriod() == 1.12); + REQUIRE(read_gp_writer_1d.periodManager().value().nextTime() == 3.12); + REQUIRE(not read_gp_writer_1d.signature().has_value()); + + REQUIRE_NOTHROW(get_value(read_vtk_writer)); + REQUIRE_NOTHROW(dynamic_cast<const VTKWriter&>(get_value(read_vtk_writer))); + const VTKWriter& read_vtk = dynamic_cast<const VTKWriter&>(get_value(read_vtk_writer)); + REQUIRE(read_vtk.type() == IWriter::Type::vtk); + REQUIRE(read_vtk.baseFilename() == vtk_filename); + REQUIRE(read_vtk.periodManager().has_value()); + REQUIRE(read_vtk.periodManager().value().timePeriod() == 0.02); + REQUIRE(read_vtk.periodManager().value().nextTime() == 1.02); + REQUIRE(read_vtk.signature().has_value()); + REQUIRE(read_vtk.signature().value() == p_vtk_writer->signature().value()); + } + } + + parallel::barrier(); + if (parallel::rank() == 0) { + std::filesystem::remove_all(std::filesystem::path{tmp_dirname}); + } +} -- GitLab