Skip to content
Snippets Groups Projects
Commit e464abba authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add support for parallel checkpoint/resume

parent 9d3c28e6
No related branches found
No related tags found
1 merge request!199Integrate checkpointing
...@@ -20,13 +20,12 @@ ...@@ -20,13 +20,12 @@
#ifdef PUGS_HAS_HDF5 #ifdef PUGS_HAS_HDF5
#include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <language/utils/DataHandler.hpp> #include <language/utils/DataHandler.hpp>
#include <mesh/MeshVariant.hpp> #include <mesh/MeshVariant.hpp>
#include <utils/GlobalVariableManager.hpp> #include <utils/GlobalVariableManager.hpp>
#include <utils/RandomEngine.hpp> #include <utils/RandomEngine.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
void void
checkpoint() checkpoint()
{ {
...@@ -34,11 +33,15 @@ checkpoint() ...@@ -34,11 +33,15 @@ checkpoint()
auto create_props = HighFive::FileCreateProps{}; auto create_props = HighFive::FileCreateProps{};
create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0)); create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0));
HighFive::FileAccessProps fapl;
fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
fapl.add(HighFive::MPIOCollectiveMetadata{});
uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber(); uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber();
const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite; const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite;
HighFive::File file("checkpoint.h5", file_openmode, create_props); HighFive::File file("checkpoint.h5", file_openmode, create_props, fapl);
std::string checkpoint_name = "checkpoint_" + std::to_string(checkpoint_number); std::string checkpoint_name = "checkpoint_" + std::to_string(checkpoint_number);
...@@ -57,7 +60,7 @@ checkpoint() ...@@ -57,7 +60,7 @@ checkpoint()
checkpoint.createAttribute("creation_date", time); checkpoint.createAttribute("creation_date", time);
checkpoint.createAttribute("name", checkpoint_name); checkpoint.createAttribute("name", checkpoint_name);
checkpoint.createAttribute("id", checkpoint_id); checkpoint.createAttribute("id", checkpoint_id);
checkpoint.createDataSet("data.pgs", ASTExecutionStack::getInstance().fileContent()); checkpoint.createAttribute("data.pgs", ASTExecutionStack::getInstance().fileContent());
{ {
HighFive::Group random_seed = checkpoint.createGroup("singleton/random_seed"); HighFive::Group random_seed = checkpoint.createGroup("singleton/random_seed");
......
...@@ -6,33 +6,57 @@ ...@@ -6,33 +6,57 @@
#include <language/utils/SymbolTable.hpp> #include <language/utils/SymbolTable.hpp>
#include <mesh/CellType.hpp> #include <mesh/CellType.hpp>
#include <mesh/ItemValue.hpp> #include <mesh/ItemValue.hpp>
#include <utils/Messenger.hpp>
template <typename DataType> template <typename DataType>
PUGS_INLINE void PUGS_INLINE void
write(HighFive::Group& group, const std::string& name, const Array<DataType>& array) write(HighFive::Group& group, const std::string& name, const Array<DataType>& array)
{ {
auto get_address = [](auto& x) { return (x.size() > 0) ? &(x[0]) : nullptr; };
Array<size_t> size_per_rank = parallel::allGather(array.size());
size_t global_size = sum(size_per_rank);
size_t current_offset = 0;
for (size_t i = 0; i < parallel::rank(); ++i) {
current_offset += size_per_rank[i];
}
std::vector<size_t> offset{current_offset, 0ul};
std::vector<size_t> count{array.size()};
using data_type = std::remove_const_t<DataType>; using data_type = std::remove_const_t<DataType>;
HighFive::DataSetCreateProps properties; HighFive::DataSetCreateProps properties;
properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, array.size())})); properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, global_size)}));
properties.add(HighFive::Shuffle()); properties.add(HighFive::Shuffle());
properties.add(HighFive::Deflate(3)); properties.add(HighFive::Deflate(3));
auto xfer_props = HighFive::DataTransferProps{};
xfer_props.add(HighFive::UseCollectiveIO{});
HighFive::DataSet dataset;
if constexpr (std::is_same_v<CellType, data_type>) { if constexpr (std::is_same_v<CellType, data_type>) {
using base_type = std::underlying_type_t<CellType>; using base_type = std::underlying_type_t<CellType>;
auto dataset = dataset = group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties);
group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); dataset.select(offset, count)
dataset.template write_raw<base_type>(reinterpret_cast<const base_type*>(&(array[0]))); .template write_raw<base_type>(reinterpret_cast<const base_type*>(get_address(array)), xfer_props);
} else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or
(std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) {
using base_type = typename data_type::base_type; using base_type = typename data_type::base_type;
auto dataset =
group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); dataset = group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties);
dataset.template write_raw<base_type>(reinterpret_cast<const base_type*>(&(array[0]))); dataset.select(offset, count)
.template write_raw<base_type>(reinterpret_cast<const base_type*>(get_address(array)), xfer_props);
} else { } else {
auto dataset = dataset = group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties);
group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); dataset.select(offset, count).template write_raw<data_type>(get_address(array), xfer_props);
dataset.template write_raw<data_type>(&(array[0])); }
std::vector<size_t> size_vector;
for (size_t i = 0; i < size_per_rank.size(); ++i) {
size_vector.push_back(size_per_rank[i]);
} }
dataset.createAttribute("size_per_rank", size_vector);
} }
template <typename DataType, ItemType item_type, typename ConnectivityPtr> template <typename DataType, ItemType item_type, typename ConnectivityPtr>
......
...@@ -134,6 +134,12 @@ printCheckpointInfo(const std::string& filename) ...@@ -134,6 +134,12 @@ printCheckpointInfo(const std::string& filename)
printAttributeValue<std::string>(attribute); printAttributeValue<std::string>(attribute);
break; break;
} }
default: {
std::ostringstream error_msg;
error_msg << "invalid data type class '" << rang::fgB::yellow << data_type.string() << rang::fg::reset
<< "' for symbol " << rang::fgB::cyan << symbol_name << rang::fg::reset;
throw UnexpectedError(error_msg.str());
}
} }
std::cout << rang::style::reset << '\n'; std::cout << rang::style::reset << '\n';
......
...@@ -7,24 +7,40 @@ ...@@ -7,24 +7,40 @@
#include <mesh/CellType.hpp> #include <mesh/CellType.hpp>
#include <mesh/ItemArray.hpp> #include <mesh/ItemArray.hpp>
#include <mesh/ItemValue.hpp> #include <mesh/ItemValue.hpp>
#include <utils/Messenger.hpp>
template <typename DataType> template <typename DataType>
PUGS_INLINE Array<DataType> PUGS_INLINE Array<DataType>
readArray(const HighFive::Group& group, const std::string& name) readArray(const HighFive::Group& group, const std::string& name)
{ {
auto get_address = [](auto& x) { return (x.size() > 0) ? &(x[0]) : nullptr; };
using data_type = std::remove_const_t<DataType>; using data_type = std::remove_const_t<DataType>;
auto dataset = group.getDataSet(name); auto dataset = group.getDataSet(name);
Array<DataType> array(dataset.getElementCount());
std::vector<size_t> size_per_rank = dataset.getAttribute("size_per_rank").read<std::vector<size_t>>();
if (size_per_rank.size() != parallel::size()) {
throw NormalError("cannot change number of processes");
}
std::vector<size_t> offset{0, 0ul};
for (size_t i = 0; i < parallel::rank(); ++i) {
offset[0] += size_per_rank[i];
}
std::vector<size_t> count{size_per_rank[parallel::rank()]};
Array<DataType> array(size_per_rank[parallel::rank()]);
if constexpr (std::is_same_v<CellType, data_type>) { if constexpr (std::is_same_v<CellType, data_type>) {
using base_type = std::underlying_type_t<CellType>; using base_type = std::underlying_type_t<CellType>;
dataset.read_raw(reinterpret_cast<base_type*>(&(array[0]))); dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(array)));
} else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or
(std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) {
using base_type = typename data_type::base_type; using base_type = typename data_type::base_type;
dataset.read_raw(reinterpret_cast<base_type*>(&(array[0]))); dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(array)));
} else { } else {
dataset.read_raw(&(array[0])); dataset.select(offset, count).read_raw(get_address(array));
} }
return array; return array;
......
...@@ -234,6 +234,7 @@ ResumingData::_getMeshVariantList(const HighFive::Group& checkpoint) ...@@ -234,6 +234,7 @@ ResumingData::_getMeshVariantList(const HighFive::Group& checkpoint)
void void
ResumingData::_getFunctionIds(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table) ResumingData::_getFunctionIds(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table)
{ {
if (checkpoint.exist("functions")) {
size_t symbol_table_id = 0; size_t symbol_table_id = 0;
const HighFive::Group function_group = checkpoint.getGroup("functions"); const HighFive::Group function_group = checkpoint.getGroup("functions");
while (p_symbol_table.use_count() > 0) { while (p_symbol_table.use_count() > 0) {
...@@ -277,6 +278,7 @@ ResumingData::_getFunctionIds(const HighFive::Group& checkpoint, std::shared_ptr ...@@ -277,6 +278,7 @@ ResumingData::_getFunctionIds(const HighFive::Group& checkpoint, std::shared_ptr
++symbol_table_id; ++symbol_table_id;
} }
} }
}
void void
ResumingData::readData(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table) ResumingData::readData(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table)
......
...@@ -12,7 +12,7 @@ std::string ...@@ -12,7 +12,7 @@ std::string
resumingDatafile(const std::string& filename) resumingDatafile(const std::string& filename)
{ {
HighFive::File file(filename, HighFive::File::ReadOnly); HighFive::File file(filename, HighFive::File::ReadOnly);
return file.getGroup("/resuming_checkpoint").getDataSet("data.pgs").read<std::string>(); return file.getGroup("/resuming_checkpoint").getAttribute("data.pgs").read<std::string>();
} }
#else // PUGS_HAS_HDF5 #else // PUGS_HAS_HDF5
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment