From 0f5f99db9e69d22a8ed7d45787c14dfe3ea8b67d Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Thu, 9 May 2024 20:48:16 +0200
Subject: [PATCH] Add checkpoint/resume for DiscreteFunctionP0Vector

Tables (and ItemArrays, through their table view) are written as a flat
1D dataset. Each rank writes its rows * columns values after those of
the lower ranks. The number of rows per rank and the number of columns
are stored as dataset attributes, so that readTable can select and
rebuild the local table on resume.

The hyperslab offset and count are both one-dimensional, to match the
dataset, and the count used when reading covers rows * columns values.
---
 src/utils/checkpointing/CheckpointUtils.cpp |  3 +-
 src/utils/checkpointing/CheckpointUtils.hpp | 58 +++++++++++++++++++++
 src/utils/checkpointing/ResumeUtils.hpp     | 23 +++++++--
 3 files changed, 78 insertions(+), 6 deletions(-)

diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp
index c67514c10..4d77a25fa 100644
--- a/src/utils/checkpointing/CheckpointUtils.cpp
+++ b/src/utils/checkpointing/CheckpointUtils.cpp
@@ -499,8 +499,7 @@ writeDiscreteFunctionVariant(const std::string& symbol_name,
       } else if constexpr (is_discrete_function_P0_vector_v<DFType>) {
         using data_type = std::decay_t<typename DFType::data_type>;
         variable_group.createAttribute("data_type", dataTypeName(ast_node_data_type_from<data_type>));
-        throw NotImplementedError("P0Vector");
-        // write(variable_group, "values", discrete_function.cellArrays());
+        write(variable_group, "values", discrete_function.cellArrays());
       }
     },
     discrete_function_v->discreteFunction());
diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp
index 984ff037f..491d95386 100644
--- a/src/utils/checkpointing/CheckpointUtils.hpp
+++ b/src/utils/checkpointing/CheckpointUtils.hpp
@@ -5,6 +5,7 @@
 
 #include <language/utils/SymbolTable.hpp>
 #include <mesh/CellType.hpp>
+#include <mesh/ItemArray.hpp>
 #include <mesh/ItemValue.hpp>
 #include <utils/Messenger.hpp>
 
@@ -59,6 +60,53 @@ write(HighFive::Group& group, const std::string& name, const Array<DataType>& ar
   dataset.createAttribute("size_per_rank", size_vector);
 }
 
+// Writes the table as a flat 1D dataset: each rank stores its
+// numberOfRows() * numberOfColumns() values after those of lower ranks.
+// The "number_of_rows_per_rank" and "number_of_columns" attributes
+// describe the layout so that readTable() can extract the local part.
+template <typename DataType>
+PUGS_INLINE void
+write(HighFive::Group& group, const std::string& name, const Table<DataType>& table)
+{
+  const size_t number_of_columns = parallel::allReduceMax(table.numberOfColumns());
+  if ((table.numberOfColumns() != number_of_columns) and (table.numberOfRows() > 0)) {
+    throw UnexpectedError("table must have same number of columns in parallel");
+  }
+
+  auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; };
+
+  Array<size_t> number_of_rows_per_rank = parallel::allGather(table.numberOfRows());
+  size_t global_size                    = sum(number_of_rows_per_rank) * number_of_columns;
+
+  size_t current_offset = 0;
+  for (size_t i = 0; i < parallel::rank(); ++i) {
+    current_offset += number_of_rows_per_rank[i] * number_of_columns;
+  }
+  std::vector<size_t> offset{current_offset};
+  std::vector<size_t> count{table.numberOfRows() * table.numberOfColumns()};
+
+  using data_type = std::remove_const_t<DataType>;
+  HighFive::DataSetCreateProps properties;
+  properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, global_size)}));
+  properties.add(HighFive::Shuffle());
+  properties.add(HighFive::Deflate(3));
+
+  auto xfer_props = HighFive::DataTransferProps{};
+  xfer_props.add(HighFive::UseCollectiveIO{});
+
+  HighFive::DataSet dataset =
+    group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties);
+  dataset.select(offset, count).template write_raw<data_type>(get_address(table), xfer_props);
+
+  std::vector<size_t> number_of_rows_per_rank_vector;
+  for (size_t i = 0; i < number_of_rows_per_rank.size(); ++i) {
+    number_of_rows_per_rank_vector.push_back(number_of_rows_per_rank[i]);
+  }
+
+  dataset.createAttribute("number_of_rows_per_rank", number_of_rows_per_rank_vector);
+  dataset.createAttribute("number_of_columns", number_of_columns);
+}
+
 template <typename DataType, ItemType item_type, typename ConnectivityPtr>
 PUGS_INLINE void
 write(HighFive::Group& group,
@@ -68,6 +116,16 @@ write(HighFive::Group& group,
   write(group, name, item_value.arrayView());
 }
 
+// Writes an ItemArray by forwarding its table view to the Table overload
+template <typename DataType, ItemType item_type, typename ConnectivityPtr>
+PUGS_INLINE void
+write(HighFive::Group& group,
+      const std::string& name,
+      const ItemArray<DataType, item_type, ConnectivityPtr>& item_array)
+{
+  write(group, name, item_array.tableView());
+}
+
 void writeDiscreteFunctionVariant(const std::string& symbol_name,
                                   const EmbeddedData& embedded_data,
                                   HighFive::File& file,
diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp
index 5fcf4eb2e..2e4fb05d6 100644
--- a/src/utils/checkpointing/ResumeUtils.hpp
+++ b/src/utils/checkpointing/ResumeUtils.hpp
@@ -50,19 +50,34 @@ template <typename DataType>
 PUGS_INLINE Table<DataType>
 readTable(const HighFive::Group& group, const std::string& name)
 {
+  auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; };
+
   using data_type = std::remove_const_t<DataType>;
 
   auto dataset = group.getDataSet(name);
-  Table<DataType> table(dataset.getDimensions()[0], dataset.getDimensions()[1]);
+
+  const size_t number_of_columns = dataset.getAttribute("number_of_columns").read<size_t>();
+  const std::vector<size_t> number_of_rows_per_rank =
+    dataset.getAttribute("number_of_rows_per_rank").read<std::vector<size_t>>();
+
+  // values are stored as a flat 1D dataset: this rank owns
+  // number_of_rows * number_of_columns consecutive entries
+  std::vector<size_t> offset{0ul};
+  for (size_t i = 0; i < parallel::rank(); ++i) {
+    offset[0] += number_of_rows_per_rank[i] * number_of_columns;
+  }
+  std::vector<size_t> count{number_of_rows_per_rank[parallel::rank()] * number_of_columns};
+
+  Table<DataType> table(number_of_rows_per_rank[parallel::rank()], number_of_columns);
   if constexpr (std::is_same_v<CellType, data_type>) {
     using base_type = std::underlying_type_t<CellType>;
-    dataset.read_raw(reinterpret_cast<base_type*>(&(table(0, 0))));
+    dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(table)));
   } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or
                        (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) {
     using base_type = typename data_type::base_type;
-    dataset.read_raw<base_type>(reinterpret_cast<base_type*>(&(table(0, 0))));
+    dataset.select(offset, count).read_raw<base_type>(reinterpret_cast<base_type*>(get_address(table)));
   } else {
-    dataset.read_raw<data_type>(&(table(0, 0)));
+    dataset.select(offset, count).read_raw<data_type>(get_address(table));
   }
 
   return table;
-- 
GitLab