diff --git a/src/dev_utils/ParallelChecker.hpp b/src/dev_utils/ParallelChecker.hpp index 22253202a282f42ed51aeccc36601ccc37b47474..4db7c33b7874cb9f6f9665147c293bca9166620b 100644 --- a/src/dev_utils/ParallelChecker.hpp +++ b/src/dev_utils/ParallelChecker.hpp @@ -23,10 +23,17 @@ void parallel_check(const ItemArray<DataType, item_type, ConnectivityPtr>& item_ const std::string& name, const SourceLocation& source_location = SourceLocation{}); -#ifdef PUGS_HAS_HDF5 - class ParallelChecker { + public: + enum class Mode + { + automatic, // write in sequential, read in parallel + read, + write + }; + +#ifdef PUGS_HAS_HDF5 template <typename DataType, ItemType item_type, typename ConnectivityPtr> friend void parallel_check(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value, const std::string& name, @@ -68,14 +75,13 @@ class ParallelChecker static ParallelChecker* m_instance; + Mode m_mode = Mode::automatic; size_t m_tag = 0; std::string m_filename = "testme/parallel_checker.h5"; ParallelChecker() = default; - std::unique_ptr<HighFive::SilenceHDF5> m_silence_hdf5 = std::make_unique<HighFive::SilenceHDF5>(true); - HighFive::File _createOrOpenFileRW() const { @@ -319,16 +325,6 @@ class ParallelChecker } } - public: - static void create(); - static void destroy(); - - static ParallelChecker& - instance() - { - return *m_instance; - } - private: template <typename DataType, ItemType item_type, typename ConnectivityPtr> void @@ -646,16 +642,9 @@ class ParallelChecker throw NormalError(e.what()); } } -}; -#else // PUGS_HAS_HDF5 +#else // PUGS_HAS_HDF5 -class ParallelChecker -{ - private: - static ParallelChecker* m_instance; - - public: template <typename T> void write(const T&, const std::string&, const SourceLocation&) @@ -669,7 +658,9 @@ class ParallelChecker { throw UnexpectedError("parallel checker cannot be used without HDF5 support"); } +#endif // PUGS_HAS_HDF5 + public: static void create(); static void destroy(); @@ -678,9 +669,39 @@ class ParallelChecker { return *m_instance; } -}; -#endif // PUGS_HAS_HDF5 + void + setMode(const Mode& mode) + { + m_mode = mode; + } + + bool + isWriting() const + { + bool is_writting = false; + switch (m_mode) { + case Mode::automatic: { + is_writting = (parallel::size() == 1); + break; + } + case Mode::write: { + is_writting = true; + break; + } + case Mode::read: { + is_writting = false; + break; + } + } + + if ((is_writting) and (parallel::size() > 1)) { + throw NotImplementedError("parallel check write in parallel"); + } + + return is_writting; + } +}; template <typename DataType, ItemType item_type, typename ConnectivityPtr> void @@ -688,9 +709,8 @@ parallel_check(const ItemArray<DataType, item_type, ConnectivityPtr>& item_array const std::string& name, const SourceLocation& source_location) { - const bool write_mode = (parallel::size() == 1); - - if (write_mode) { + HighFive::SilenceHDF5 m_silence_hdf5{true}; + if (ParallelChecker::instance().isWriting()) { ParallelChecker::instance().write(item_array, name, source_location); } else { ParallelChecker::instance().compare(item_array, name, source_location); @@ -703,9 +723,8 @@ parallel_check(const ItemValue<DataType, item_type, ConnectivityPtr>& item_value const std::string& name, const SourceLocation& source_location) { - const bool write_mode = (parallel::size() == 1); - - if (write_mode) { + HighFive::SilenceHDF5 m_silence_hdf5{true}; + if (ParallelChecker::instance().isWriting()) { ParallelChecker::instance().write(item_value, name, source_location); } else { ParallelChecker::instance().compare(item_value, name, source_location); diff --git a/src/main.cpp b/src/main.cpp index b89a36ba7931dd0c58e3d0074b7862fe1ca32690..0cae9351ff97396484cf28f52c242692ce9299f5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -11,6 +11,8 @@ int main(int argc, char* argv[]) { + ParallelChecker::create(); + std::string filename = initialize(argc, argv); SynchronizerManager::create(); @@ -19,11 +21,9 @@ main(int argc, char* argv[]) MeshDataManager::create(); DualConnectivityManager::create(); DualMeshManager::create(); - ParallelChecker::create(); parser(filename); - ParallelChecker::destroy(); DualMeshManager::destroy(); DualConnectivityManager::destroy(); MeshDataManager::destroy(); @@ -33,5 +33,7 @@ main(int argc, char* argv[]) finalize(); + ParallelChecker::destroy(); + return 0; } diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp index cbb06c49eb673d48bdcff1854015305411f56a2e..b54135bb068df963560675e9fe8c85b31c465aec 100644 --- a/src/utils/PugsUtils.cpp +++ b/src/utils/PugsUtils.cpp @@ -1,5 +1,6 @@ #include <utils/PugsUtils.hpp> +#include <dev_utils/ParallelChecker.hpp> #include <utils/BacktraceManager.hpp> #include <utils/BuildInfo.hpp> #include <utils/CommunicatorManager.hpp> @@ -87,6 +88,8 @@ initialize(int& argc, char* argv[]) bool enable_signals = true; int nb_threads = -1; + ParallelChecker::Mode pc_mode = ParallelChecker::Mode::automatic; + std::string filename; { CLI::App app{"pugs help"}; @@ -123,6 +126,14 @@ initialize(int& argc, char* argv[]) app.add_flag("--reproducible-sums,!--no-reproducible-sums", show_preamble, "Special treatment of array sums to ensure reproducibility [default: true]"); + std::map<std::string, ParallelChecker::Mode> pc_mode_map{{"auto", ParallelChecker::Mode::automatic}, + {"write", ParallelChecker::Mode::write}, + {"read", ParallelChecker::Mode::read}}; + app + .add_option("--parallel-checker-mode", pc_mode, + "Parallel checker mode (auto: sequential write/parallel read) [default: auto]") + ->transform(CLI::CheckedTransformer(pc_mode_map)); + int mpi_split_color = -1; app.add_option("--mpi-split-color", mpi_split_color, "Sets the MPI split color value (for MPMD applications)") ->check(CLI::Range(0, std::numeric_limits<decltype(mpi_split_color)>::max())); @@ -166,6 +177,8 @@ initialize(int& argc, char* argv[]) Kokkos::initialize(args); } + ParallelChecker::instance().setMode(pc_mode); + if (ConsoleManager::showPreamble()) { std::cout << "----------------- " << rang::fg::green << "pugs exec info" << rang::fg::reset << " ----------------------" << '\n';