Skip to content
Snippets Groups Projects
Commit 0f5f99db authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Add checpoint/resume for DiscreteFunctionP0Vector

parent e464abba
Branches
No related tags found
1 merge request!199Integrate checkpointing
...@@ -499,8 +499,7 @@ writeDiscreteFunctionVariant(const std::string& symbol_name, ...@@ -499,8 +499,7 @@ writeDiscreteFunctionVariant(const std::string& symbol_name,
} else if constexpr (is_discrete_function_P0_vector_v<DFType>) { } else if constexpr (is_discrete_function_P0_vector_v<DFType>) {
using data_type = std::decay_t<typename DFType::data_type>; using data_type = std::decay_t<typename DFType::data_type>;
variable_group.createAttribute("data_type", dataTypeName(ast_node_data_type_from<data_type>)); variable_group.createAttribute("data_type", dataTypeName(ast_node_data_type_from<data_type>));
throw NotImplementedError("P0Vector"); write(variable_group, "values", discrete_function.cellArrays());
// write(variable_group, "values", discrete_function.cellArrays());
} }
}, },
discrete_function_v->discreteFunction()); discrete_function_v->discreteFunction());
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <language/utils/SymbolTable.hpp> #include <language/utils/SymbolTable.hpp>
#include <mesh/CellType.hpp> #include <mesh/CellType.hpp>
#include <mesh/ItemArray.hpp>
#include <mesh/ItemValue.hpp> #include <mesh/ItemValue.hpp>
#include <utils/Messenger.hpp> #include <utils/Messenger.hpp>
...@@ -59,6 +60,49 @@ write(HighFive::Group& group, const std::string& name, const Array<DataType>& ar ...@@ -59,6 +60,49 @@ write(HighFive::Group& group, const std::string& name, const Array<DataType>& ar
dataset.createAttribute("size_per_rank", size_vector); dataset.createAttribute("size_per_rank", size_vector);
} }
template <typename DataType>
PUGS_INLINE void
write(HighFive::Group& group, const std::string& name, const Table<DataType>& table)
{
const size_t number_of_columns = parallel::allReduceMax(table.numberOfColumns());
if ((table.numberOfColumns() != number_of_columns) and (table.numberOfRows() > 0)) {
throw UnexpectedError("table must have same number of columns in parallel");
}
auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; };
Array<size_t> number_of_rows_per_rank = parallel::allGather(table.numberOfRows());
size_t global_size = sum(number_of_rows_per_rank) * number_of_columns;
size_t current_offset = 0;
for (size_t i = 0; i < parallel::rank(); ++i) {
current_offset += number_of_rows_per_rank[i] * table.numberOfColumns();
}
std::vector<size_t> offset{current_offset, 0ul};
std::vector<size_t> count{table.numberOfRows() * table.numberOfColumns()};
using data_type = std::remove_const_t<DataType>;
HighFive::DataSetCreateProps properties;
properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, global_size)}));
properties.add(HighFive::Shuffle());
properties.add(HighFive::Deflate(3));
auto xfer_props = HighFive::DataTransferProps{};
xfer_props.add(HighFive::UseCollectiveIO{});
HighFive::DataSet dataset =
group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{global_size}}, properties);
dataset.select(offset, count).template write_raw<data_type>(get_address(table), xfer_props);
std::vector<size_t> number_of_rows_per_rank_vector;
for (size_t i = 0; i < number_of_rows_per_rank.size(); ++i) {
number_of_rows_per_rank_vector.push_back(number_of_rows_per_rank[i]);
}
dataset.createAttribute("number_of_rows_per_rank", number_of_rows_per_rank_vector);
dataset.createAttribute("number_of_columns", number_of_columns);
}
template <typename DataType, ItemType item_type, typename ConnectivityPtr> template <typename DataType, ItemType item_type, typename ConnectivityPtr>
PUGS_INLINE void PUGS_INLINE void
write(HighFive::Group& group, write(HighFive::Group& group,
...@@ -68,6 +112,15 @@ write(HighFive::Group& group, ...@@ -68,6 +112,15 @@ write(HighFive::Group& group,
write(group, name, item_value.arrayView()); write(group, name, item_value.arrayView());
} }
template <typename DataType, ItemType item_type, typename ConnectivityPtr>
PUGS_INLINE void
write(HighFive::Group& group,
const std::string& name,
const ItemArray<DataType, item_type, ConnectivityPtr>& item_array)
{
write(group, name, item_array.tableView());
}
void writeDiscreteFunctionVariant(const std::string& symbol_name, void writeDiscreteFunctionVariant(const std::string& symbol_name,
const EmbeddedData& embedded_data, const EmbeddedData& embedded_data,
HighFive::File& file, HighFive::File& file,
......
...@@ -50,19 +50,32 @@ template <typename DataType> ...@@ -50,19 +50,32 @@ template <typename DataType>
PUGS_INLINE Table<DataType> PUGS_INLINE Table<DataType>
readTable(const HighFive::Group& group, const std::string& name) readTable(const HighFive::Group& group, const std::string& name)
{ {
auto get_address = [](auto& t) { return (t.numberOfRows() * t.numberOfColumns() > 0) ? &(t(0, 0)) : nullptr; };
using data_type = std::remove_const_t<DataType>; using data_type = std::remove_const_t<DataType>;
auto dataset = group.getDataSet(name); auto dataset = group.getDataSet(name);
Table<DataType> table(dataset.getDimensions()[0], dataset.getDimensions()[1]);
const size_t number_of_columns = dataset.getAttribute("number_of_columns").read<size_t>();
const std::vector<size_t> number_of_rows_per_rank =
dataset.getAttribute("number_of_rows_per_rank").read<std::vector<size_t>>();
std::vector<size_t> offset{0, 0ul};
for (size_t i = 0; i < parallel::rank(); ++i) {
offset[0] += number_of_rows_per_rank[i] * number_of_columns;
}
std::vector<size_t> count{number_of_rows_per_rank[parallel::rank()]};
Table<DataType> table(number_of_rows_per_rank[parallel::rank()], number_of_columns);
if constexpr (std::is_same_v<CellType, data_type>) { if constexpr (std::is_same_v<CellType, data_type>) {
using base_type = std::underlying_type_t<CellType>; using base_type = std::underlying_type_t<CellType>;
dataset.read_raw(reinterpret_cast<base_type*>(&(table(0, 0)))); dataset.select(offset, count).read_raw(reinterpret_cast<base_type*>(get_address(table)));
} else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or
(std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) {
using base_type = typename data_type::base_type; using base_type = typename data_type::base_type;
dataset.read_raw<base_type>(reinterpret_cast<base_type*>(&(table(0, 0)))); dataset.select(offset, count).read_raw<base_type>(reinterpret_cast<base_type*>(get_address(table)));
} else { } else {
dataset.read_raw<data_type>(&(table(0, 0))); dataset.select(offset, count).read_raw<data_type>(get_address(table));
} }
return table; return table;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment