From 96543d0320702cfc5a492dff52361a73b070cf1e Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Sun, 12 May 2024 20:33:40 +0200 Subject: [PATCH] Add checkpoint/resume for ItemArrayVariant --- src/language/modules/MeshModule.cpp | 13 ++++ src/utils/checkpointing/CheckpointUtils.cpp | 41 ++++++++++ src/utils/checkpointing/CheckpointUtils.hpp | 6 ++ src/utils/checkpointing/ResumeUtils.cpp | 84 +++++++++++++++++++++ src/utils/checkpointing/ResumeUtils.hpp | 1 + 5 files changed, 145 insertions(+) diff --git a/src/language/modules/MeshModule.cpp b/src/language/modules/MeshModule.cpp index 395400f28..cab959a73 100644 --- a/src/language/modules/MeshModule.cpp +++ b/src/language/modules/MeshModule.cpp @@ -365,6 +365,19 @@ MeshModule::registerCheckpointResume() const std::function([](const std::string& symbol_name, const HighFive::Group& symbol_table_group) -> EmbeddedData { return readItemType(symbol_name, symbol_table_group); })); + CheckpointResumeRepository::instance() + .addCheckpointResume(ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>>, + std::function([](const std::string& symbol_name, const EmbeddedData& embedded_data, + HighFive::File& file, HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group) { + writeItemArrayVariant(symbol_name, embedded_data, file, checkpoint_group, + symbol_table_group); + }), + std::function([](const std::string& symbol_name, + const HighFive::Group& symbol_table_group) -> EmbeddedData { + return readItemArrayVariant(symbol_name, symbol_table_group); + })); + CheckpointResumeRepository::instance() .addCheckpointResume(ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>>, std::function([](const std::string& symbol_name, const EmbeddedData& embedded_data, diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp index bb370b7fa..5fdeb7793 100644 --- a/src/utils/checkpointing/CheckpointUtils.cpp +++ b/src/utils/checkpointing/CheckpointUtils.cpp @@ -8,6 +8,7 @@ #include <language/utils/DataHandler.hpp> #include <language/utils/OFStream.hpp> #include <language/utils/OStream.hpp> +#include <mesh/ItemArrayVariant.hpp> #include <mesh/ItemType.hpp> #include <mesh/ItemValueVariant.hpp> #include <mesh/Mesh.hpp> @@ -497,6 +498,46 @@ writeIQuadratureDescriptor(const std::string& symbol_name, variable_group.createAttribute("quadrature_degree", iquadrature_descriptor.degree()); } +void +writeItemArrayVariant(HighFive::Group& variable_group, + std::shared_ptr<const ItemArrayVariant> item_array_variant_v, + HighFive::File& file, + HighFive::Group& checkpoint_group) +{ + variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(item_array_variant_v)>)); + + std::visit( + [&](auto&& item_array) { + using ItemArrayT = std::decay_t<decltype(item_array)>; + + variable_group.createAttribute("item_type", ItemArrayT::item_t); + using data_type = std::decay_t<typename ItemArrayT::data_type>; + variable_group.createAttribute("data_type", dataTypeName(ast_node_data_type_from<data_type>)); + + const IConnectivity& connectivity = *item_array.connectivity_ptr(); + variable_group.createAttribute("connectivity_id", connectivity.id()); + writeConnectivity(connectivity, file, checkpoint_group); + + write(variable_group, "arrays", item_array); + }, + item_array_variant_v->itemArray()); +} + +void +writeItemArrayVariant(const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group) +{ + HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name); + + std::shared_ptr<const ItemArrayVariant> item_array_variant_p = + dynamic_cast<const DataHandler<const ItemArrayVariant>&>(embedded_data.get()).data_ptr(); + + writeItemArrayVariant(variable_group, item_array_variant_p, file, checkpoint_group); +} + void writeItemType(const std::string& symbol_name, const EmbeddedData& embedded_data, diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp index 7edee7b2c..89030bd51 100644 --- a/src/utils/checkpointing/CheckpointUtils.hpp +++ b/src/utils/checkpointing/CheckpointUtils.hpp @@ -155,6 +155,12 @@ void writeIQuadratureDescriptor(const std::string& symbol_name, HighFive::Group& checkpoint_group, HighFive::Group& symbol_table_group); +void writeItemArrayVariant(const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group); + void writeItemType(const std::string& symbol_name, const EmbeddedData& embedded_data, HighFive::File& file, diff --git a/src/utils/checkpointing/ResumeUtils.cpp b/src/utils/checkpointing/ResumeUtils.cpp index d07262779..15c557d67 100644 --- a/src/utils/checkpointing/ResumeUtils.cpp +++ b/src/utils/checkpointing/ResumeUtils.cpp @@ -6,6 +6,7 @@ #include <language/utils/DataHandler.hpp> #include <language/utils/OFStream.hpp> #include <language/utils/SymbolTable.hpp> +#include <mesh/ItemArrayVariant.hpp> #include <mesh/ItemValueVariant.hpp> #include <mesh/NamedBoundaryDescriptor.hpp> #include <mesh/NamedInterfaceDescriptor.hpp> @@ -277,6 +278,89 @@ readIQuadratureDescriptor(const std::string& symbol_name, const HighFive::Group& return {std::make_shared<DataHandler<const IQuadratureDescriptor>>(iquadrature_descrptor)}; } +template <ItemType item_type> +EmbeddedData +readItemArrayVariant(const HighFive::Group& item_array_variant_group) +{ + const std::string data_type = item_array_variant_group.getAttribute("data_type").read<std::string>(); + const size_t connectivity_id = item_array_variant_group.getAttribute("connectivity_id").read<size_t>(); + + const IConnectivity& connectivity = *ResumingData::instance().iConnectivity(connectivity_id); + + std::shared_ptr<ItemArrayVariant> p_item_array; + + if (data_type == dataTypeName(ast_node_data_type_from<bool>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<bool, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<long int>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<long int, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<unsigned long int>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<unsigned long int, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<double>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<double, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<TinyVector<1>>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<TinyVector<1>, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<TinyVector<2>>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<TinyVector<2>, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<TinyVector<3>>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<TinyVector<3>, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<TinyMatrix<1>>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<TinyMatrix<1>, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<TinyMatrix<2>>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<TinyMatrix<2>, item_type>(item_array_variant_group, "arrays", connectivity)); + } else if (data_type == dataTypeName(ast_node_data_type_from<TinyMatrix<3>>)) { + p_item_array = std::make_shared<ItemArrayVariant>( + readItemArray<TinyMatrix<3>, item_type>(item_array_variant_group, "arrays", connectivity)); + } else { + throw UnexpectedError("unexpected discrete function data type: " + data_type); + } + return {std::make_shared<DataHandler<const ItemArrayVariant>>(p_item_array)}; +} + +EmbeddedData +readItemArrayVariant(const HighFive::Group& item_array_variant_group) +{ + const ItemType item_type = item_array_variant_group.getAttribute("item_type").read<ItemType>(); + + EmbeddedData embedded_data; + + switch (item_type) { + case ItemType::cell: { + embedded_data = readItemArrayVariant<ItemType::cell>(item_array_variant_group); + break; + } + case ItemType::face: { + embedded_data = readItemArrayVariant<ItemType::face>(item_array_variant_group); + break; + } + case ItemType::edge: { + embedded_data = readItemArrayVariant<ItemType::edge>(item_array_variant_group); + break; + } + case ItemType::node: { + embedded_data = readItemArrayVariant<ItemType::node>(item_array_variant_group); + break; + } + } + + return embedded_data; +} + +EmbeddedData +readItemArrayVariant(const std::string& symbol_name, const HighFive::Group& symbol_table_group) +{ + const HighFive::Group item_array_variant_group = symbol_table_group.getGroup("embedded/" + symbol_name); + return readItemArrayVariant(item_array_variant_group); +} + EmbeddedData readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group) { diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp index 419751f5b..dbdc5611d 100644 --- a/src/utils/checkpointing/ResumeUtils.hpp +++ b/src/utils/checkpointing/ResumeUtils.hpp @@ -89,6 +89,7 @@ EmbeddedData readIDiscreteFunctionDescriptor(const std::string& symbol_name, con EmbeddedData readIInterfaceDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group); EmbeddedData readINamedDiscreteData(const std::string& symbol_name, const HighFive::Group& symbol_table_group); EmbeddedData readIQuadratureDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group); +EmbeddedData readItemArrayVariant(const std::string& symbol_name, const HighFive::Group& symbol_table_group); EmbeddedData readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group); EmbeddedData readItemValueVariant(const std::string& symbol_name, const HighFive::Group& symbol_table_group); EmbeddedData readIZoneDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group); -- GitLab