diff --git a/src/utils/checkpointing/Checkpoint.cpp b/src/utils/checkpointing/Checkpoint.cpp index ea5ede9cb1d9ac4954e742eedfc8400399891071..e7325fa862033a5865ee9cddf8f545d485f68808 100644 --- a/src/utils/checkpointing/Checkpoint.cpp +++ b/src/utils/checkpointing/Checkpoint.cpp @@ -20,13 +20,12 @@ #ifdef PUGS_HAS_HDF5 #include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/CheckpointResumeRepository.hpp> #include <language/utils/DataHandler.hpp> #include <mesh/MeshVariant.hpp> #include <utils/GlobalVariableManager.hpp> #include <utils/RandomEngine.hpp> -#include <language/utils/CheckpointResumeRepository.hpp> - void checkpoint() { @@ -34,11 +33,15 @@ checkpoint() auto create_props = HighFive::FileCreateProps{}; create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0)); + HighFive::FileAccessProps fapl; + fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL}); + fapl.add(HighFive::MPIOCollectiveMetadata{}); + uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber(); const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite; - HighFive::File file("checkpoint.h5", file_openmode, create_props); + HighFive::File file("checkpoint.h5", file_openmode, create_props, fapl); std::string checkpoint_name = "checkpoint_" + std::to_string(checkpoint_number); @@ -57,7 +60,7 @@ checkpoint() checkpoint.createAttribute("creation_date", time); checkpoint.createAttribute("name", checkpoint_name); checkpoint.createAttribute("id", checkpoint_id); - checkpoint.createDataSet("data.pgs", ASTExecutionStack::getInstance().fileContent()); + checkpoint.createAttribute("data.pgs", ASTExecutionStack::getInstance().fileContent()); { HighFive::Group random_seed = checkpoint.createGroup("singleton/random_seed"); diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp index 6e9af741e9a468e143f1571b672880881b22f2d2..984ff037f8dbae0b393127e59cc545f4e2cd95c1 100644 --- a/src/utils/checkpointing/CheckpointUtils.hpp +++ b/src/utils/checkpointing/CheckpointUtils.hpp @@ -6,33 +6,57 @@ #include <language/utils/SymbolTable.hpp> #include <mesh/CellType.hpp> #include <mesh/ItemValue.hpp> +#include <utils/Messenger.hpp> template <typename DataType> PUGS_INLINE void write(HighFive::Group& group, const std::string& name, const Array<DataType>& array) { + auto get_address = [](auto& x) { return (x.size() > 0) ? &(x[0]) : nullptr; }; + + Array<size_t> size_per_rank = parallel::allGather(array.size()); + size_t global_size = sum(size_per_rank); + + size_t current_offset = 0; + for (size_t i = 0; i < parallel::rank(); ++i) { + current_offset += size_per_rank[i]; + } + std::vector<size_t> offset{current_offset, 0ul}; + std::vector<size_t> count{array.size()}; + using data_type = std::remove_const_t<DataType>; HighFive::DataSetCreateProps properties; - properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, array.size())})); + properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, global_size)})); properties.add(HighFive::Shuffle()); properties.add(HighFive::Deflate(3)); + auto xfer_props = HighFive::DataTransferProps{}; + xfer_props.add(HighFive::UseCollectiveIO{}); + + HighFive::DataSet dataset; if constexpr (std::is_same_v<CellType, data_type>) { using base_type = std::underlying_type_t<CellType>; - auto dataset = - group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); - dataset.template write_raw<base_type>(reinterpret_cast<const base_type*>(&(array[0]))); + dataset = group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); + dataset.select(offset, count) + .template write_raw<base_type>(reinterpret_cast<const base_type*>(get_address(array)), xfer_props); } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { using base_type = typename data_type::base_type; - auto dataset = - group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); - dataset.template write_raw<base_type>(reinterpret_cast<const base_type*>(&(array[0]))); + + dataset = group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); + dataset.select(offset, count) + .template write_raw<base_type>(reinterpret_cast<const base_type*>(get_address(array)), xfer_props); } else { - auto dataset = - group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); - dataset.template write_raw<data_type>(&(array[0])); + dataset = group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); + dataset.select(offset, count).template write_raw<data_type>(get_address(array), xfer_props); + } + + std::vector<size_t> size_vector; + for (size_t i = 0; i < size_per_rank.size(); ++i) { + size_vector.push_back(size_per_rank[i]); } + + dataset.createAttribute("size_per_rank", size_vector); } template <typename DataType, ItemType item_type, typename ConnectivityPtr> diff --git a/src/utils/checkpointing/PrintCheckpointInfo.cpp b/src/utils/checkpointing/PrintCheckpointInfo.cpp index 561391b05be32933b5c76edda785951107b535ab..43dbcdc55ab07a59e7747eeb3f5540f82e592cc8 100644 --- a/src/utils/checkpointing/PrintCheckpointInfo.cpp +++ b/src/utils/checkpointing/PrintCheckpointInfo.cpp @@ -134,6 +134,12 @@ printCheckpointInfo(const std::string& filename) printAttributeValue<std::string>(attribute); break; } + default: { + std::ostringstream error_msg; + error_msg << "invalid data type class '" << rang::fgB::yellow << data_type.string() << rang::fg::reset + << "' for symbol " << rang::fgB::cyan << symbol_name << rang::fg::reset; + throw UnexpectedError(error_msg.str()); + } } std::cout << rang::style::reset << '\n'; diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp index 1882d2fc7597a96fd056b360258171d083687bda..5fcf4eb2e7d8d78f6f1e50bf92f7d1268a0b9c14 100644 --- a/src/utils/checkpointing/ResumeUtils.hpp +++ b/src/utils/checkpointing/ResumeUtils.hpp @@ -7,24 +7,40 @@ #include <mesh/CellType.hpp> #include <mesh/ItemArray.hpp> #include <mesh/ItemValue.hpp> +#include <utils/Messenger.hpp> template <typename DataType> PUGS_INLINE Array<DataType> readArray(const HighFive::Group& group, const std::string& name) { + auto get_address = [](auto& x) { return (x.size() > 0) ? &(x[0]) : nullptr; }; + using data_type = std::remove_const_t<DataType>; auto dataset = group.getDataSet(name); - Array<DataType> array(dataset.getElementCount()); + + std::vector<size_t> size_per_rank = dataset.getAttribute("size_per_rank").read<std::vector<size_t>>(); + + if (size_per_rank.size() != parallel::size()) { + throw NormalError("cannot change number of processes"); + } + + std::vector<size_t> offset{0, 0ul}; + for (size_t i = 0; i < parallel::rank(); ++i) { + offset[0] += size_per_rank[i]; + } + std::vector<size_t> count{size_per_rank[parallel::rank()]}; + + Array<DataType> array(size_per_rank[parallel::rank()]); if constexpr (std::is_same_v<CellType, data_type>) { using base_type = std::underlying_type_t<CellType>; - dataset.read_raw(reinterpret_cast<base_type*>(&(array[0]))); + dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(array))); } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { using base_type = typename data_type::base_type; - dataset.read_raw(reinterpret_cast<base_type*>(&(array[0]))); + dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(array))); } else { - dataset.read_raw(&(array[0])); + dataset.select(offset, count).read_raw(get_address(array)); } return array; diff --git a/src/utils/checkpointing/ResumingData.cpp b/src/utils/checkpointing/ResumingData.cpp index 86dd4428919dce9e6cc6b1dfbf1dd735b527bb89..8e121cd1c3614c799e57858e0539a764cf3b95bd 100644 --- a/src/utils/checkpointing/ResumingData.cpp +++ b/src/utils/checkpointing/ResumingData.cpp @@ -234,47 +234,49 @@ ResumingData::_getMeshVariantList(const HighFive::Group& checkpoint) void ResumingData::_getFunctionIds(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table) { - size_t symbol_table_id = 0; - const HighFive::Group function_group = checkpoint.getGroup("functions"); - while (p_symbol_table.use_count() > 0) { - for (auto symbol : p_symbol_table->symbolList()) { - if (symbol.attributes().dataType() == ASTNodeDataType::function_t) { - if (not function_group.exist(symbol.name())) { - std::ostringstream error_msg; - error_msg << "cannot find function " << rang::fgB::yellow << symbol.name() << rang::fg::reset << " in " - << rang::fgB::cyan << checkpoint.getFile().getName() << rang::fg::reset; - throw NormalError(error_msg.str()); - } else { - const HighFive::Group function = function_group.getGroup(symbol.name()); - const size_t stored_function_id = function.getAttribute("id").read<size_t>(); - const size_t function_id = std::get<size_t>(symbol.attributes().value()); - if (symbol_table_id != function.getAttribute("symbol_table_id").read<size_t>()) { + if (checkpoint.exist("functions")) { + size_t symbol_table_id = 0; + const HighFive::Group function_group = checkpoint.getGroup("functions"); + while (p_symbol_table.use_count() > 0) { + for (auto symbol : p_symbol_table->symbolList()) { + if (symbol.attributes().dataType() == ASTNodeDataType::function_t) { + if (not function_group.exist(symbol.name())) { std::ostringstream error_msg; - error_msg << "symbol table of function " << rang::fgB::yellow << symbol.name() << rang::fg::reset - << " does not match the one stored in " << rang::fgB::cyan << checkpoint.getFile().getName() - << rang::fg::reset; - throw NormalError(error_msg.str()); - } else if (function_id != stored_function_id) { - std::ostringstream error_msg; - error_msg << "id (" << function_id << ") of function " << rang::fgB::yellow << symbol.name() - << rang::fg::reset << " does not match the one stored in " << rang::fgB::cyan - << checkpoint.getFile().getName() << rang::fg::reset << "(" << stored_function_id << ")"; + error_msg << "cannot find function " << rang::fgB::yellow << symbol.name() << rang::fg::reset << " in " + << rang::fgB::cyan << checkpoint.getFile().getName() << rang::fg::reset; throw NormalError(error_msg.str()); } else { - if (m_id_to_function_symbol_id_map.contains(function_id)) { + const HighFive::Group function = function_group.getGroup(symbol.name()); + const size_t stored_function_id = function.getAttribute("id").read<size_t>(); + const size_t function_id = std::get<size_t>(symbol.attributes().value()); + if (symbol_table_id != function.getAttribute("symbol_table_id").read<size_t>()) { + std::ostringstream error_msg; + error_msg << "symbol table of function " << rang::fgB::yellow << symbol.name() << rang::fg::reset + << " does not match the one stored in " << rang::fgB::cyan << checkpoint.getFile().getName() + << rang::fg::reset; + throw NormalError(error_msg.str()); + } else if (function_id != stored_function_id) { std::ostringstream error_msg; error_msg << "id (" << function_id << ") of function " << rang::fgB::yellow << symbol.name() - << rang::fg::reset << " is duplicated"; - throw UnexpectedError(error_msg.str()); + << rang::fg::reset << " does not match the one stored in " << rang::fgB::cyan + << checkpoint.getFile().getName() << rang::fg::reset << "(" << stored_function_id << ")"; + throw NormalError(error_msg.str()); + } else { + if (m_id_to_function_symbol_id_map.contains(function_id)) { + std::ostringstream error_msg; + error_msg << "id (" << function_id << ") of function " << rang::fgB::yellow << symbol.name() + << rang::fg::reset << " is duplicated"; + throw UnexpectedError(error_msg.str()); + } + m_id_to_function_symbol_id_map[function_id] = + std::make_shared<FunctionSymbolId>(function_id, p_symbol_table); } - m_id_to_function_symbol_id_map[function_id] = - std::make_shared<FunctionSymbolId>(function_id, p_symbol_table); } } } + p_symbol_table = p_symbol_table->parentTable(); + ++symbol_table_id; } - p_symbol_table = p_symbol_table->parentTable(); - ++symbol_table_id; } } diff --git a/src/utils/checkpointing/ResumingUtils.cpp b/src/utils/checkpointing/ResumingUtils.cpp index 23b91d6689a914650c966c661282427e0574ed18..82e6610519a12710de96ac661d62f1590fc0277f 100644 --- a/src/utils/checkpointing/ResumingUtils.cpp +++ b/src/utils/checkpointing/ResumingUtils.cpp @@ -12,7 +12,7 @@ std::string resumingDatafile(const std::string& filename) { HighFive::File file(filename, HighFive::File::ReadOnly); - return file.getGroup("/resuming_checkpoint").getDataSet("data.pgs").read<std::string>(); + return file.getGroup("/resuming_checkpoint").getAttribute("data.pgs").read<std::string>(); } #else // PUGS_HAS_HDF5