diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp index c67514c1002216386b63ac3591709d3b3760213e..4d77a25fa58ec05042dd1892981b56a1b6510328 100644 --- a/src/utils/checkpointing/CheckpointUtils.cpp +++ b/src/utils/checkpointing/CheckpointUtils.cpp @@ -499,8 +499,7 @@ writeDiscreteFunctionVariant(const std::string& symbol_name, } else if constexpr (is_discrete_function_P0_vector_v<DFType>) { using data_type = std::decay_t<typename DFType::data_type>; variable_group.createAttribute("data_type", dataTypeName(ast_node_data_type_from<data_type>)); - throw NotImplementedError("P0Vector"); - // write(variable_group, "values", discrete_function.cellArrays()); + write(variable_group, "values", discrete_function.cellArrays()); } }, discrete_function_v->discreteFunction()); diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp index 984ff037f8dbae0b393127e59cc545f4e2cd95c1..491d95386fa982c4f51c03e5160933f5fe3f8270 100644 --- a/src/utils/checkpointing/CheckpointUtils.hpp +++ b/src/utils/checkpointing/CheckpointUtils.hpp @@ -5,6 +5,7 @@ #include <language/utils/SymbolTable.hpp> #include <mesh/CellType.hpp> +#include <mesh/ItemArray.hpp> #include <mesh/ItemValue.hpp> #include <utils/Messenger.hpp> @@ -59,6 +60,49 @@ write(HighFive::Group& group, const std::string& name, const Array<DataType>& ar dataset.createAttribute("size_per_rank", size_vector); } +template <typename DataType> +PUGS_INLINE void +write(HighFive::Group& group, const std::string& name, const Table<DataType>& table) +{ + const size_t number_of_columns = parallel::allReduceMax(table.numberOfColumns()); + if ((table.numberOfColumns() != number_of_columns) and (table.numberOfRows() > 0)) { + throw UnexpectedError("table must have same number of columns in parallel"); + } + + auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; }; + + Array<size_t> number_of_rows_per_rank = parallel::allGather(table.numberOfRows()); + size_t global_size = sum(number_of_rows_per_rank) * number_of_columns; + + size_t current_offset = 0; + for (size_t i = 0; i < parallel::rank(); ++i) { + current_offset += number_of_rows_per_rank[i] * table.numberOfColumns(); + } + std::vector<size_t> offset{current_offset, 0ul}; + std::vector<size_t> count{table.numberOfRows() * table.numberOfColumns()}; + + using data_type = std::remove_const_t<DataType>; + HighFive::DataSetCreateProps properties; + properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, global_size)})); + properties.add(HighFive::Shuffle()); + properties.add(HighFive::Deflate(3)); + + auto xfer_props = HighFive::DataTransferProps{}; + xfer_props.add(HighFive::UseCollectiveIO{}); + + HighFive::DataSet dataset = + group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties); + dataset.select(offset, count).template write_raw<data_type>(get_address(table), xfer_props); + + std::vector<size_t> number_of_rows_per_rank_vector; + for (size_t i = 0; i < number_of_rows_per_rank.size(); ++i) { + number_of_rows_per_rank_vector.push_back(number_of_rows_per_rank[i]); + } + + dataset.createAttribute("number_of_rows_per_rank", number_of_rows_per_rank_vector); + dataset.createAttribute("number_of_columns", number_of_columns); +} + template <typename DataType, ItemType item_type, typename ConnectivityPtr> PUGS_INLINE void write(HighFive::Group& group, @@ -68,6 +112,15 @@ write(HighFive::Group& group, write(group, name, item_value.arrayView()); } +template <typename DataType, ItemType item_type, typename ConnectivityPtr> +PUGS_INLINE void +write(HighFive::Group& group, + const std::string& name, + const ItemArray<DataType, item_type, ConnectivityPtr>& item_array) +{ + write(group, name, item_array.tableView()); +} + void writeDiscreteFunctionVariant(const std::string& symbol_name, const EmbeddedData& embedded_data, HighFive::File& file, diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp index 5fcf4eb2e7d8d78f6f1e50bf92f7d1268a0b9c14..2e4fb05d62c47c9b7a0919904210dc1dcc6a4f67 100644 --- a/src/utils/checkpointing/ResumeUtils.hpp +++ b/src/utils/checkpointing/ResumeUtils.hpp @@ -50,19 +50,32 @@ template <typename DataType> PUGS_INLINE Table<DataType> readTable(const HighFive::Group& group, const std::string& name) { + auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; }; + using data_type = std::remove_const_t<DataType>; auto dataset = group.getDataSet(name); - Table<DataType> table(dataset.getDimensions()[0], dataset.getDimensions()[1]); + + const size_t number_of_columns = dataset.getAttribute("number_of_columns").read<size_t>(); + const std::vector<size_t> number_of_rows_per_rank = + dataset.getAttribute("number_of_rows_per_rank").read<std::vector<size_t>>(); + + std::vector<size_t> offset{0, 0ul}; + for (size_t i = 0; i < parallel::rank(); ++i) { + offset[0] += number_of_rows_per_rank[i] * number_of_columns; + } + std::vector<size_t> count{number_of_rows_per_rank[parallel::rank()]}; + + Table<DataType> table(number_of_rows_per_rank[parallel::rank()], number_of_columns); if constexpr (std::is_same_v<CellType, data_type>) { using base_type = std::underlying_type_t<CellType>; - dataset.read_raw(reinterpret_cast<base_type*>(&(table(0, 0)))); + dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(table))); } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { using base_type = typename data_type::base_type; - dataset.read_raw<base_type>(reinterpret_cast<base_type*>(&(table(0, 0)))); + dataset.select(offset, count).read_raw<base_type>(reinterpret_cast<base_type*>(get_address(table))); } else { - dataset.read_raw<data_type>(&(table(0, 0))); + dataset.select(offset, count).read_raw<data_type>(get_address(table)); } return table;