diff --git a/src/language/CMakeLists.txt b/src/language/CMakeLists.txt index 2403c045d9e5d26a30f763cc1814ef9f550a4d35..e5e09ed768bd1e8796ad9d4cdc214de65f4cb5b2 100644 --- a/src/language/CMakeLists.txt +++ b/src/language/CMakeLists.txt @@ -10,8 +10,13 @@ add_library( PugsLanguage PugsParser.cpp) +target_link_libraries( + PugsLanguage + ${HIGHFIVE_TARGET} +) + # Additional dependencies -add_dependencies(PugsLanguage +add_dependencies( PugsLanguage PugsLanguageAlgorithms PugsLanguageAST diff --git a/src/language/PugsParser.cpp b/src/language/PugsParser.cpp index b5b446f69ea47d1827c800afd7bfcb87139efe45..cb42fc45acf4f4bfaf14fe2c11a3b0ab84938af6 100644 --- a/src/language/PugsParser.cpp +++ b/src/language/PugsParser.cpp @@ -16,6 +16,7 @@ #include <language/ast/ASTSymbolTableBuilder.hpp> #include <language/utils/ASTCheckpointsInfo.hpp> #include <language/utils/ASTExecutionInfo.hpp> +#include <language/utils/CheckpointResumeRepository.hpp> #include <language/utils/Exit.hpp> #include <language/utils/OperatorRepository.hpp> #include <language/utils/SymbolTable.hpp> @@ -59,6 +60,7 @@ parser(const std::string& filename) auto parse_and_execute = [](auto& input, const std::string& file_content) { OperatorRepository::create(); + CheckpointResumeRepository::create(); ASTExecutionStack::create(input, file_content); std::unique_ptr<ASTNode> root_node = ASTBuilder::build(*input); @@ -105,6 +107,7 @@ parser(const std::string& filename) root_node->m_symbol_table->clearValues(); + CheckpointResumeRepository::destroy(); OperatorRepository::destroy(); }; diff --git a/src/language/ast/ASTModulesImporter.cpp b/src/language/ast/ASTModulesImporter.cpp index 4154a7d8ca2bc874ba7df224c14a23838a21f493..eb03cf56ad57eb53395a47f34f66dc0b4e9782ab 100644 --- a/src/language/ast/ASTModulesImporter.cpp +++ b/src/language/ast/ASTModulesImporter.cpp @@ -27,6 +27,7 @@ ASTModulesImporter::_importModule(ASTNode& import_node) m_module_repository.populateSymbolTable(module_name_node, m_symbol_table); m_module_repository.registerOperators(module_name); + m_module_repository.registerCheckpointResume(module_name); } void diff --git a/src/language/modules/CoreModule.cpp b/src/language/modules/CoreModule.cpp index 72b46b30153fe3c6a7291eaef6cedab820231c95..05643f6566aca6728b9cffb8bdd497322c12d313 100644 --- a/src/language/modules/CoreModule.cpp +++ b/src/language/modules/CoreModule.cpp @@ -192,3 +192,7 @@ CoreModule::registerOperators() const UnaryOperatorRegisterForRnxn<2>{}; UnaryOperatorRegisterForRnxn<3>{}; } + +void +CoreModule::registerCheckpointResume() const +{} diff --git a/src/language/modules/CoreModule.hpp b/src/language/modules/CoreModule.hpp index 88c673d65a7d9635aba8698c678dda7bd2081fb1..26059267d445c742894522ff5d0c2101bb15a1c6 100644 --- a/src/language/modules/CoreModule.hpp +++ b/src/language/modules/CoreModule.hpp @@ -13,6 +13,7 @@ class CoreModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; CoreModule(); ~CoreModule() = default; diff --git a/src/language/modules/DevUtilsModule.cpp b/src/language/modules/DevUtilsModule.cpp index a1738294600a5048893691217e01655d308ca6aa..361b4357eda0ebbfe4350cc6c64bbe72cfaa0829 100644 --- a/src/language/modules/DevUtilsModule.cpp +++ b/src/language/modules/DevUtilsModule.cpp @@ -1,6 +1,8 @@ #include <language/modules/DevUtilsModule.hpp> #include <dev/ParallelChecker.hpp> +#include <language/modules/MeshModuleTypes.hpp> +#include <language/modules/SchemeModuleTypes.hpp> #include <language/utils/ASTDotPrinter.hpp> #include <language/utils/ASTExecutionInfo.hpp> #include <language/utils/ASTPrinter.hpp> @@ -9,31 +11,6 @@ #include <fstream> -class DiscreteFunctionVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const DiscreteFunctionVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("Vh"); - -class ItemValueVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_value"); - -class ItemArrayVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_array"); - -class SubItemValuePerItemVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const SubItemValuePerItemVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("sub_item_value"); - -class SubItemArrayPerItemVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const SubItemArrayPerItemVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("sub_item_array"); - DevUtilsModule::DevUtilsModule() { this->_addBuiltinFunction("getAST", std::function( @@ -145,3 +122,7 @@ DevUtilsModule::DevUtilsModule() void DevUtilsModule::registerOperators() const {} + +void +DevUtilsModule::registerCheckpointResume() const +{} diff --git a/src/language/modules/DevUtilsModule.hpp b/src/language/modules/DevUtilsModule.hpp index 96392a8dca7b4ddf1c158d19e12276fd6842b835..1f9eb7ec4d39aa92452998117637dedf98cb38a4 100644 --- a/src/language/modules/DevUtilsModule.hpp +++ b/src/language/modules/DevUtilsModule.hpp @@ -13,6 +13,7 @@ class DevUtilsModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; DevUtilsModule(); ~DevUtilsModule() = default; diff --git a/src/language/modules/IModule.hpp b/src/language/modules/IModule.hpp index 04aefbe91b57ab1925f55bae62ad233c1ea135cf..98fa56c117196931cd86ea49efb9c91236fe1967 100644 --- a/src/language/modules/IModule.hpp +++ b/src/language/modules/IModule.hpp @@ -19,8 +19,8 @@ class IModule using NameTypeMap = std::unordered_map<std::string, std::shared_ptr<TypeDescriptor>>; using NameValueMap = std::unordered_map<std::string, std::shared_ptr<ValueDescriptor>>; - IModule() = default; - IModule(IModule&&) = default; + IModule() = default; + IModule(IModule&&) = default; IModule& operator=(IModule&&) = default; virtual bool isMandatory() const = 0; @@ -31,7 +31,8 @@ class IModule virtual const NameValueMap& getNameValueMap() const = 0; - virtual void registerOperators() const = 0; + virtual void registerOperators() const = 0; + virtual void registerCheckpointResume() const = 0; virtual std::string_view name() const = 0; diff --git a/src/language/modules/LinearSolverModule.cpp b/src/language/modules/LinearSolverModule.cpp index 123a930073dfb0735dd30df1893613b5cb30a292..aef6470d2455909f5ad2d307aa4e5cb859b8bcbb 100644 --- a/src/language/modules/LinearSolverModule.cpp +++ b/src/language/modules/LinearSolverModule.cpp @@ -94,3 +94,7 @@ LinearSolverModule::LinearSolverModule() void LinearSolverModule::registerOperators() const {} + +void +LinearSolverModule::registerCheckpointResume() const +{} diff --git a/src/language/modules/LinearSolverModule.hpp b/src/language/modules/LinearSolverModule.hpp index 7e30c6b2f7c12cf4ede853578e6c064c7dce6292..78bbf46817fbdf4be72702be12312e1fb443504e 100644 --- a/src/language/modules/LinearSolverModule.hpp +++ b/src/language/modules/LinearSolverModule.hpp @@ -13,6 +13,7 @@ class LinearSolverModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; LinearSolverModule(); ~LinearSolverModule() = default; diff --git a/src/language/modules/MathModule.cpp b/src/language/modules/MathModule.cpp index 4a1e7b35b76a4d9b8b05ebb1355f3cdca4f751d9..afb9dbe0cf16f100e1367c9843325aa5bdec80ad 100644 --- a/src/language/modules/MathModule.cpp +++ b/src/language/modules/MathModule.cpp @@ -105,3 +105,7 @@ MathModule::MathModule() void MathModule::registerOperators() const {} + +void +MathModule::registerCheckpointResume() const +{} diff --git a/src/language/modules/MathModule.hpp b/src/language/modules/MathModule.hpp index c80a74d2da37807750f25a286b9ccd9d12928850..9902f4c7ab8cf2654df062e541344b850fa53d15 100644 --- a/src/language/modules/MathModule.hpp +++ b/src/language/modules/MathModule.hpp @@ -13,6 +13,7 @@ class MathModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; MathModule(); diff --git a/src/language/modules/MeshModule.cpp b/src/language/modules/MeshModule.cpp index 548f42ddac09caded1533a96c572e9722bf6a63b..cae2399da5002bf356ba33ffe32d899153c1ff89 100644 --- a/src/language/modules/MeshModule.cpp +++ b/src/language/modules/MeshModule.cpp @@ -4,6 +4,7 @@ #include <language/node_processor/ExecutionPolicy.hpp> #include <language/utils/BinaryOperatorProcessorBuilder.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/CheckpointResumeRepository.hpp> #include <language/utils/FunctionTable.hpp> #include <language/utils/ItemArrayVariantFunctionInterpoler.hpp> #include <language/utils/ItemValueVariantFunctionInterpoler.hpp> @@ -34,6 +35,8 @@ #include <mesh/SubItemArrayPerItemVariant.hpp> #include <mesh/SubItemValuePerItemVariant.hpp> #include <utils/Exceptions.hpp> +#include <utils/checkpointing/CheckpointUtils.hpp> +#include <utils/checkpointing/ResumeUtils.hpp> #include <Kokkos_Core.hpp> @@ -300,3 +303,20 @@ MeshModule::registerOperators() const BinaryOperatorProcessorBuilder<language::shift_left_op, std::shared_ptr<const OStream>, std::shared_ptr<const OStream>, std::shared_ptr<const MeshVariant>>>()); } + +void +MeshModule::registerCheckpointResume() const +{ + CheckpointResumeRepository::instance() + .addCheckpointResume(ast_node_data_type_from<std::shared_ptr<const MeshVariant>>, + std::function([](const std::string& symbol_name, const EmbeddedData& embedded_data, + HighFive::File& file, HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group) { + writeMesh(symbol_name, embedded_data, file, checkpoint_group, symbol_table_group); + }), + std::function([](const HighFive::File& file, const HighFive::Group& checkpoint_group, + const std::string& symbol_name, + const HighFive::Group& symbol_table_group) -> EmbeddedData { + return readMesh(file, checkpoint_group, symbol_name, symbol_table_group); + })); +} diff --git a/src/language/modules/MeshModule.hpp b/src/language/modules/MeshModule.hpp index 10bfc9e1cae386702386508d6861e97ba7857857..f9c4ca3de4a520354109d5e64e2364a4b7784c91 100644 --- a/src/language/modules/MeshModule.hpp +++ b/src/language/modules/MeshModule.hpp @@ -2,52 +2,7 @@ #define MESH_MODULE_HPP #include <language/modules/BuiltinModule.hpp> -#include <language/utils/ASTNodeDataTypeTraits.hpp> -#include <utils/PugsMacros.hpp> - -class MeshVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const MeshVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("mesh"); - -class IBoundaryDescriptor; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IBoundaryDescriptor>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("boundary"); - -class IInterfaceDescriptor; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IInterfaceDescriptor>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("interface"); - -class IZoneDescriptor; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IZoneDescriptor>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("zone"); - -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemType>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_type"); - -class ItemValueVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_value"); - -class ItemArrayVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_array"); - -class SubItemValuePerItemVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const SubItemValuePerItemVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("sub_item_value"); - -class SubItemArrayPerItemVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const SubItemArrayPerItemVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("sub_item_array"); +#include <language/modules/MeshModuleTypes.hpp> class MeshModule : public BuiltinModule { @@ -59,6 +14,7 @@ class MeshModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; MeshModule(); diff --git a/src/language/modules/MeshModuleTypes.hpp b/src/language/modules/MeshModuleTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cfd26a817e872cda2e421adaf25262d982ce0489 --- /dev/null +++ b/src/language/modules/MeshModuleTypes.hpp @@ -0,0 +1,50 @@ +#ifndef MESH_MODULE_TYPES_HPP +#define MESH_MODULE_TYPES_HPP + +#include <language/utils/ASTNodeDataTypeTraits.hpp> + +class MeshVariant; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const MeshVariant>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("mesh"); + +class IBoundaryDescriptor; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IBoundaryDescriptor>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("boundary"); + +class IInterfaceDescriptor; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IInterfaceDescriptor>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("interface"); + +class IZoneDescriptor; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IZoneDescriptor>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("zone"); + +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemType>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_type"); + +class ItemValueVariant; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_value"); + +class ItemArrayVariant; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_array"); + +class SubItemValuePerItemVariant; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const SubItemValuePerItemVariant>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("sub_item_value"); + +class SubItemArrayPerItemVariant; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const SubItemArrayPerItemVariant>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("sub_item_array"); + +#endif // MESH_MODULE_TYPES_HPP diff --git a/src/language/modules/ModuleRepository.cpp b/src/language/modules/ModuleRepository.cpp index e1e53796f31b117fee9f38c20cfbc6e8ff030349..e291ca761ae2bd2721110f645bb7b93bc395b7b8 100644 --- a/src/language/modules/ModuleRepository.cpp +++ b/src/language/modules/ModuleRepository.cpp @@ -203,6 +203,17 @@ ModuleRepository::registerOperators(const std::string& module_name) } } +void +ModuleRepository::registerCheckpointResume(const std::string& module_name) +{ + auto i_module = m_module_set.find(module_name); + if (i_module != m_module_set.end()) { + i_module->second->registerCheckpointResume(); + } else { + throw NormalError(std::string{"could not find module "} + module_name); + } +} + std::string ModuleRepository::getModuleInfo(const std::string& module_name) const { diff --git a/src/language/modules/ModuleRepository.hpp b/src/language/modules/ModuleRepository.hpp index c7289c8d1c4b9c7f795925bfdbf65e3f526082e7..3e3d178c37b9b24977599d1a80d1e6f794fb5a96 100644 --- a/src/language/modules/ModuleRepository.hpp +++ b/src/language/modules/ModuleRepository.hpp @@ -35,12 +35,13 @@ class ModuleRepository void populateSymbolTable(const ASTNode& module_name_node, SymbolTable& symbol_table); void populateMandatorySymbolTable(const ASTNode& root_node, SymbolTable& symbol_table); void registerOperators(const std::string& module_name); + void registerCheckpointResume(const std::string& module_name); std::string getAvailableModules() const; std::string getModuleInfo(const std::string& module_name) const; const ModuleRepository& operator=(const ModuleRepository&) = delete; - const ModuleRepository& operator=(ModuleRepository&&) = delete; + const ModuleRepository& operator=(ModuleRepository&&) = delete; ModuleRepository(const ModuleRepository&) = delete; ModuleRepository(ModuleRepository&&) = delete; diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp index 7b404fcf670c5217721efba890728ce97e3e91a4..14d45998b80b79b531eef5b1550f0543d77d504e 100644 --- a/src/language/modules/SchemeModule.cpp +++ b/src/language/modules/SchemeModule.cpp @@ -45,6 +45,9 @@ #include <scheme/VariableBCDescriptor.hpp> #include <utils/Socket.hpp> +#include <language/modules/MeshModule.hpp> +#include <language/modules/SocketModule.hpp> + #include <memory> SchemeModule::SchemeModule() @@ -679,3 +682,9 @@ SchemeModule::registerOperators() const BinaryOperatorRegisterForVh{}; UnaryOperatorRegisterForVh{}; } + +void +SchemeModule::registerCheckpointResume() const +{ + throw NotImplementedError("registerCheckpointResume()"); +} diff --git a/src/language/modules/SchemeModule.hpp b/src/language/modules/SchemeModule.hpp index 68cffcf18dc9438cfa5607bdf70739bf223a175d..c61b3084406dbd8ae9889c79b4172195daaa6fe8 100644 --- a/src/language/modules/SchemeModule.hpp +++ b/src/language/modules/SchemeModule.hpp @@ -2,33 +2,7 @@ #define SCHEME_MODULE_HPP #include <language/modules/BuiltinModule.hpp> -#include <language/utils/ASTNodeDataTypeTraits.hpp> -#include <utils/PugsMacros.hpp> - -class IBoundaryConditionDescriptor; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IBoundaryConditionDescriptor>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("boundary_condition"); - -class VariableBCDescriptor; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const VariableBCDescriptor>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("variable_boundary_condition"); - -class DiscreteFunctionVariant; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const DiscreteFunctionVariant>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("Vh"); - -class IDiscreteFunctionDescriptor; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IDiscreteFunctionDescriptor>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("Vh_type"); - -class IQuadratureDescriptor; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IQuadratureDescriptor>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("quadrature"); +#include <language/modules/SchemeModuleTypes.hpp> class SchemeModule : public BuiltinModule { @@ -42,6 +16,7 @@ class SchemeModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; SchemeModule(); diff --git a/src/language/modules/SchemeModuleTypes.hpp b/src/language/modules/SchemeModuleTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..027dbd4419efed294ced2998573b546f86ae9f93 --- /dev/null +++ b/src/language/modules/SchemeModuleTypes.hpp @@ -0,0 +1,31 @@ +#ifndef SCHEME_MODULE_TYPES_HPP +#define SCHEME_MODULE_TYPES_HPP + +#include <language/utils/ASTNodeDataTypeTraits.hpp> + +class IBoundaryConditionDescriptor; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IBoundaryConditionDescriptor>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("boundary_condition"); + +class VariableBCDescriptor; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const VariableBCDescriptor>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("variable_boundary_condition"); + +class DiscreteFunctionVariant; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const DiscreteFunctionVariant>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("Vh"); + +class IDiscreteFunctionDescriptor; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IDiscreteFunctionDescriptor>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("Vh_type"); + +class IQuadratureDescriptor; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IQuadratureDescriptor>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("quadrature"); + +#endif // SCHEME_MODULE_TYPES_HPP diff --git a/src/language/modules/SocketModule.cpp b/src/language/modules/SocketModule.cpp index 426a561372289256182d0a5d981eaeaf00de2e08..68ed4c9e57f54ec6966d10404283479860797f32 100644 --- a/src/language/modules/SocketModule.cpp +++ b/src/language/modules/SocketModule.cpp @@ -257,3 +257,9 @@ SocketModule::registerOperators() const std::make_shared<BinaryOperatorProcessorBuilder<language::shift_left_op, std::shared_ptr<const OStream>, std::shared_ptr<const OStream>, std::shared_ptr<const Socket>>>()); } + +void +SocketModule::registerCheckpointResume() const +{ + throw NotImplementedError("registerCheckpointResume()"); +} diff --git a/src/language/modules/SocketModule.hpp b/src/language/modules/SocketModule.hpp index 0b32b64fafc72829e6598cfa78a95ce14f61ed9b..dc14c4b273d3e21c62c6cb42078ad4d2d4212480 100644 --- a/src/language/modules/SocketModule.hpp +++ b/src/language/modules/SocketModule.hpp @@ -2,13 +2,7 @@ #define SOCKET_MODULE_HPP #include <language/modules/BuiltinModule.hpp> -#include <language/utils/ASTNodeDataTypeTraits.hpp> - -class Socket; - -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const Socket>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("socket"); +#include <language/modules/SocketModuleTypes.hpp> class SocketModule : public BuiltinModule { @@ -20,6 +14,7 @@ class SocketModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; SocketModule(); ~SocketModule() = default; diff --git a/src/language/modules/SocketModuleTypes.hpp b/src/language/modules/SocketModuleTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b12370feb7af814e2ee725dd84aca40568a72c72 --- /dev/null +++ b/src/language/modules/SocketModuleTypes.hpp @@ -0,0 +1,11 @@ +#ifndef SOCKET_MODULE_TYPES_HPP +#define SOCKET_MODULE_TYPES_HPP + +#include <language/utils/ASTNodeDataTypeTraits.hpp> + +class Socket; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const Socket>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("socket"); + +#endif // SOCKET_MODULE_TYPES_HPP diff --git a/src/language/modules/WriterModule.cpp b/src/language/modules/WriterModule.cpp index fc9d5703ed793f23c7f9768f4f2137585c3be318..ad48be7f92fd73614af8d0ba7e21697729f07520 100644 --- a/src/language/modules/WriterModule.cpp +++ b/src/language/modules/WriterModule.cpp @@ -16,6 +16,9 @@ #include <output/VTKWriter.hpp> #include <scheme/DiscreteFunctionVariant.hpp> +#include <language/modules/MeshModule.hpp> +#include <language/modules/SchemeModule.hpp> + WriterModule::WriterModule() { this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const INamedDiscreteData>>); @@ -182,3 +185,9 @@ WriterModule::WriterModule() void WriterModule::registerOperators() const {} + +void +WriterModule::registerCheckpointResume() const +{ + throw NotImplementedError("registerCheckpointResume()"); +} diff --git a/src/language/modules/WriterModule.hpp b/src/language/modules/WriterModule.hpp index 97bef5a78ac6f90cff1af303bed901c4c46c665a..7abdba663361ce627db6a406108f3accd4689000 100644 --- a/src/language/modules/WriterModule.hpp +++ b/src/language/modules/WriterModule.hpp @@ -2,22 +2,7 @@ #define WRITER_MODULE_HPP #include <language/modules/BuiltinModule.hpp> -#include <language/utils/ASTNodeDataTypeTraits.hpp> -#include <utils/PugsMacros.hpp> - -class OutputNamedItemValueSet; -class INamedDiscreteData; - -#include <string> - -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const INamedDiscreteData>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("output"); - -class IWriter; -template <> -inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IWriter>> = - ASTNodeDataType::build<ASTNodeDataType::type_id_t>("writer"); +#include <language/modules/WriterModuleTypes.hpp> class WriterModule : public BuiltinModule { @@ -29,6 +14,7 @@ class WriterModule : public BuiltinModule } void registerOperators() const final; + void registerCheckpointResume() const final; WriterModule(); diff --git a/src/language/modules/WriterModuleTypes.hpp b/src/language/modules/WriterModuleTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..428dbbaa2bc73025b2b8d6ea440a58ef8eb884f9 --- /dev/null +++ b/src/language/modules/WriterModuleTypes.hpp @@ -0,0 +1,18 @@ +#ifndef WRITER_MODULE_TYPES_HPP +#define WRITER_MODULE_TYPES_HPP + +#include <language/utils/ASTNodeDataTypeTraits.hpp> + +class OutputNamedItemValueSet; +class INamedDiscreteData; + +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const INamedDiscreteData>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("output"); + +class IWriter; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const IWriter>> = + ASTNodeDataType::build<ASTNodeDataType::type_id_t>("writer"); + +#endif // WRITER_MODULE_TYPES_HPP diff --git a/src/language/node_processor/BinaryExpressionProcessor.hpp b/src/language/node_processor/BinaryExpressionProcessor.hpp index d91f39d905c1f055ed7d95d6a15ff7f258fc2385..5688f86db6c5cfb34bf3b5f2eb10818b1b56f4a9 100644 --- a/src/language/node_processor/BinaryExpressionProcessor.hpp +++ b/src/language/node_processor/BinaryExpressionProcessor.hpp @@ -8,7 +8,7 @@ #include <type_traits> -template <typename DataType> +template <typename DataT> class DataHandler; template <typename Op> diff --git a/src/language/utils/ASTNodeDataTypeTraits.hpp b/src/language/utils/ASTNodeDataTypeTraits.hpp index 2ba019646dac53291dc73169fadc47e737f68bad..7d4ad7852d6f90f9b1efb0384e4fbd20fd242d82 100644 --- a/src/language/utils/ASTNodeDataTypeTraits.hpp +++ b/src/language/utils/ASTNodeDataTypeTraits.hpp @@ -9,7 +9,15 @@ #include <vector> template <typename T> -inline ASTNodeDataType ast_node_data_type_from = ASTNodeDataType{}; +inline ASTNodeDataType +_ast_node_data_type_undefined() +{ + constexpr bool type_is_undefined = not std::is_same_v<T, T>; + static_assert(type_is_undefined, "Module header defining this data type must be included"); + return {}; +} +template <typename T> +inline ASTNodeDataType ast_node_data_type_from = _ast_node_data_type_undefined<T>(); template <> inline ASTNodeDataType ast_node_data_type_from<void> = ASTNodeDataType::build<ASTNodeDataType::void_t>(); diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 0455cd053bb57ed70a7c534dec1f172172a79963..08ea07e3326921adb268edd70122c7c1b5251434 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(PugsLanguageUtils BinaryOperatorRegisterForString.cpp BinaryOperatorRegisterForZ.cpp BuiltinFunctionEmbedderUtils.cpp + CheckpointResumeRepository.cpp DataVariant.cpp EmbeddedData.cpp EmbeddedDiscreteFunctionMathFunctions.cpp diff --git a/src/language/utils/CheckpointResumeRepository.cpp b/src/language/utils/CheckpointResumeRepository.cpp new file mode 100644 index 0000000000000000000000000000000000000000..180f60b66100898b4b5033230956e48a4e8e836d --- /dev/null +++ b/src/language/utils/CheckpointResumeRepository.cpp @@ -0,0 +1,58 @@ +#include <language/utils/CheckpointResumeRepository.hpp> + +CheckpointResumeRepository* CheckpointResumeRepository::m_instance = nullptr; + +void +CheckpointResumeRepository::checkpoint(const ASTNodeDataType& data_type, + const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group) const +{ + std::string data_type_name = dataTypeName(data_type); + if (auto i_dt_function = m_data_type_checkpointing.find(data_type_name); + i_dt_function != m_data_type_checkpointing.end()) { + const CheckpointFunction& function = i_dt_function->second; + function(symbol_name, embedded_data, file, checkpoint_group, symbol_table_group); + } else { + std::ostringstream error_msg; + error_msg << "cannot find checkpointing function for type '" << rang::fgB::yellow << data_type_name + << rang::fg::reset << "'"; + throw UnexpectedError(error_msg.str()); + } +} + +EmbeddedData +CheckpointResumeRepository::resume(const HighFive::File& file, + const HighFive::Group& checkpoint_group, + const ASTNodeDataType& data_type, + const std::string& symbol_name, + const HighFive::Group& symbol_table_group) const +{ + std::string data_type_name = dataTypeName(data_type); + if (auto i_dt_function = m_data_type_resuming.find(data_type_name); i_dt_function != m_data_type_resuming.end()) { + const ResumeFunction& function = i_dt_function->second; + return function(file, checkpoint_group, symbol_name, symbol_table_group); + } else { + std::ostringstream error_msg; + error_msg << "cannot find resuming function for type '" << rang::fgB::yellow << data_type_name << rang::fg::reset + << "'"; + throw UnexpectedError(error_msg.str()); + } +} + +void +CheckpointResumeRepository::create() +{ + Assert(m_instance == nullptr, "CheckpointResumeRepository was already created"); + m_instance = new CheckpointResumeRepository; +} + +void +CheckpointResumeRepository::destroy() +{ + Assert(m_instance != nullptr, "CheckpointResumeRepository was not created"); + delete m_instance; + m_instance = nullptr; +} diff --git a/src/language/utils/CheckpointResumeRepository.hpp b/src/language/utils/CheckpointResumeRepository.hpp new file mode 100644 index 0000000000000000000000000000000000000000..491c235901fa58b690db28a85603b536a6a26b2f --- /dev/null +++ b/src/language/utils/CheckpointResumeRepository.hpp @@ -0,0 +1,88 @@ +#ifndef CHECKPOINT_RESUME_REPOSITORY_HPP +#define CHECKPOINT_RESUME_REPOSITORY_HPP + +#include <language/utils/SymbolTable.hpp> +#include <utils/Exceptions.hpp> +#include <utils/HighFivePugsUtils.hpp> +#include <utils/PugsAssert.hpp> +#include <utils/PugsUtils.hpp> + +#include <string> +#include <unordered_map> + +class CheckpointResumeRepository +{ + public: + using CheckpointFunction = std::function<void(const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group)>; + + using ResumeFunction = std::function<EmbeddedData(const HighFive::File& symbol, + const HighFive::Group& file, + const std::string& symbol_name, + const HighFive::Group& symbol_table_group)>; + + private: + std::unordered_map<std::string, CheckpointFunction> m_data_type_checkpointing; + std::unordered_map<std::string, ResumeFunction> m_data_type_resuming; + + public: + void + addCheckpointResume(const ASTNodeDataType& ast_node_data_type, + CheckpointFunction&& checkpoint_function, + ResumeFunction&& resume_function) + { + const std::string& data_type_name = dataTypeName(ast_node_data_type); + { + const auto [i, inserted] = + m_data_type_checkpointing.insert({data_type_name, std::forward<CheckpointFunction>(checkpoint_function)}); + if (not(inserted)) { + std::ostringstream error_msg; + error_msg << "checkpointing for type '" << rang::fgB::yellow << data_type_name << rang::fg::reset + << "' has already be defined"; + throw UnexpectedError(error_msg.str()); + } + } + { + const auto [i, inserted] = + m_data_type_resuming.insert({data_type_name, std::forward<ResumeFunction>(resume_function)}); + Assert(inserted); + } + } + + void checkpoint(const ASTNodeDataType& data_type, + const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group) const; + + EmbeddedData resume(const HighFive::File& file, + const HighFive::Group& checkpoint_group, + const ASTNodeDataType& data_type, + const std::string& symbol_name, + const HighFive::Group& symbol_table_group) const; + + static void create(); + + PUGS_INLINE + static CheckpointResumeRepository& + instance() + { + Assert(m_instance != nullptr); + return *m_instance; + } + + static void destroy(); + + private: + static CheckpointResumeRepository* m_instance; + + CheckpointResumeRepository() = default; + + ~CheckpointResumeRepository() = default; +}; + +#endif // CHECKPOINT_RESUME_REPOSITORY_HPP diff --git a/src/language/utils/DataHandler.hpp b/src/language/utils/DataHandler.hpp index c09256e2a8379cbc934394304fd35d1f98610348..66ff5fcb0c0113393cc7b45034be5beaf50ae50e 100644 --- a/src/language/utils/DataHandler.hpp +++ b/src/language/utils/DataHandler.hpp @@ -16,6 +16,9 @@ class IDataHandler virtual ~IDataHandler() = default; }; +template <typename DataT> +void checkpointStore(const DataT&); + template <typename DataT> class DataHandler : public IDataHandler { @@ -29,6 +32,8 @@ class DataHandler : public IDataHandler return m_data; } + friend void checkpointStore<DataT>(const DataT&); + DataHandler(std::shared_ptr<DataT> data) : m_data(data) {} ~DataHandler() = default; }; diff --git a/src/language/utils/OStream.hpp b/src/language/utils/OStream.hpp index bef02277ef8e07be09cc5731b33c6f5c3363e37d..24b1e62dd0a2232829ab896af036a1bdcc95edfb 100644 --- a/src/language/utils/OStream.hpp +++ b/src/language/utils/OStream.hpp @@ -40,6 +40,8 @@ class OStream virtual ~OStream() = default; }; +void checkpointStore(const OStream&); + template <> inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const OStream>> = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("ostream"); diff --git a/src/utils/GlobalVariableManager.hpp b/src/utils/GlobalVariableManager.hpp index f720252fdd7c04b1ddd8e234d890bf5b1956c899..fa7f52a3db7f2dfedfe797cc81a5b7ee8fd40525 100644 --- a/src/utils/GlobalVariableManager.hpp +++ b/src/utils/GlobalVariableManager.hpp @@ -18,6 +18,13 @@ class GlobalVariableManager ~GlobalVariableManager() = default; public: + PUGS_INLINE + size_t + getConnectivityId() const + { + return m_connectivity_id; + } + PUGS_INLINE size_t getAndIncrementConnectivityId() @@ -25,6 +32,13 @@ class GlobalVariableManager return m_connectivity_id++; } + PUGS_INLINE + size_t + getMeshId() const + { + return m_mesh_id; + } + PUGS_INLINE size_t getAndIncrementMeshId() diff --git a/src/utils/PugsUtils.cpp b/src/utils/PugsUtils.cpp index f3ae399b052c15a87ab9f9ea09a385fad54b7c3e..ed6f6803de0ace073192f155fbef1883bfe8e7a4 100644 --- a/src/utils/PugsUtils.cpp +++ b/src/utils/PugsUtils.cpp @@ -13,6 +13,7 @@ #include <utils/RevisionInfo.hpp> #include <utils/SLEPcWrapper.hpp> #include <utils/SignalManager.hpp> +#include <utils/checkpointing/PrintCheckpointInfo.hpp> #include <utils/checkpointing/ResumingManager.hpp> #include <utils/pugs_build_info.hpp> @@ -90,6 +91,8 @@ initialize(int& argc, char* argv[]) bool enable_signals = true; int nb_threads = -1; + bool print_checkpoint_info = false; + ParallelChecker::Mode pc_mode = ParallelChecker::Mode::automatic; std::string pc_filename = ParallelChecker::instance().filename(); @@ -102,6 +105,8 @@ initialize(int& argc, char* argv[]) bool is_resuming = false; app.add_flag("--resume", is_resuming, "Resume at checkpoint"); + app.add_flag("--print-checkpoint-info", print_checkpoint_info, "Print checkpoint info and exit"); + app.set_version_flag("-v,--version", []() { ConsoleManager::init(true); std::stringstream os; @@ -220,6 +225,12 @@ initialize(int& argc, char* argv[]) std::cout << "-------------------------------------------------------\n"; } + if (print_checkpoint_info) { + printCheckpointInfo(filename); + finalize(); + std::exit(0); + } + return filename; } diff --git a/src/utils/checkpointing/CMakeLists.txt b/src/utils/checkpointing/CMakeLists.txt index a2a0cdc827bf3bf4c866419fa6ea0d4a15058121..aa95560ee62bf433341165a29122f6620e82103f 100644 --- a/src/utils/checkpointing/CMakeLists.txt +++ b/src/utils/checkpointing/CMakeLists.txt @@ -3,7 +3,11 @@ add_library( PugsCheckpointing Checkpoint.cpp + CheckpointUtils.cpp + PrintCheckpointInfo.cpp Resume.cpp + ResumeUtils.cpp + ResumingData.cpp ResumingManager.cpp ResumingUtils.cpp) diff --git a/src/utils/checkpointing/Checkpoint.cpp b/src/utils/checkpointing/Checkpoint.cpp index 07fae9f4f93006bb34faa2e80c6a62e758945955..15c5f1bb242e298762bc492a1a840ed39d729232 100644 --- a/src/utils/checkpointing/Checkpoint.cpp +++ b/src/utils/checkpointing/Checkpoint.cpp @@ -10,6 +10,7 @@ #include <language/utils/SymbolTable.hpp> #include <iostream> +#include <map> #endif // PUGS_HAS_HDF5 #include <language/utils/ASTCheckpointsInfo.hpp> @@ -17,90 +18,138 @@ #include <utils/checkpointing/ResumingManager.hpp> #ifdef PUGS_HAS_HDF5 + +#include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/DataHandler.hpp> +#include <mesh/MeshVariant.hpp> +#include <utils/GlobalVariableManager.hpp> +#include <utils/RandomEngine.hpp> + +#include <language/utils/CheckpointResumeRepository.hpp> + void checkpoint() { - auto create_props = HighFive::FileCreateProps{}; - create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0)); + try { + auto create_props = HighFive::FileCreateProps{}; + create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0)); - uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber(); + uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber(); - const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite; + const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite; - HighFive::File file("checkpoint.h5", file_openmode, create_props); + HighFive::File file("checkpoint.h5", file_openmode, create_props); - HighFive::Group checkpoint = file.createGroup("checkpoint_" + std::to_string(checkpoint_number)); + std::string checkpoint_name = "checkpoint_" + std::to_string(checkpoint_number); - uint64_t checkpoint_id = - ASTCheckpointsInfo::getInstance().getCheckpointId((ASTExecutionStack::getInstance().currentNode())); + HighFive::Group checkpoint = file.createGroup(checkpoint_name); - checkpoint.createAttribute("checkpoint_id", checkpoint_id); - checkpoint.createDataSet("data.pgs", ASTExecutionStack::getInstance().fileContent()); + uint64_t checkpoint_id = + ASTCheckpointsInfo::getInstance().getCheckpointId((ASTExecutionStack::getInstance().currentNode())); - std::shared_ptr<const SymbolTable> p_symbol_table = ASTExecutionStack::getInstance().currentNode().m_symbol_table; - auto symbol_table_group = checkpoint; - while (p_symbol_table.use_count() > 0) { - symbol_table_group = symbol_table_group.createGroup("symbol table"); + std::string time = [] { + std::ostringstream os; + auto t = std::time(nullptr); + os << std::put_time(std::localtime(&t), "%c"); + return os.str(); + }(); - const SymbolTable& symbol_table = *p_symbol_table; + checkpoint.createAttribute("creation_date", time); + checkpoint.createAttribute("name", checkpoint_name); + checkpoint.createAttribute("id", checkpoint_id); + checkpoint.createDataSet("data.pgs", ASTExecutionStack::getInstance().fileContent()); - const auto& symbol_list = symbol_table.symbolList(); + { + HighFive::Group random_seed = checkpoint.createGroup("singleton/random_seed"); + random_seed.createAttribute("current_seed", RandomEngine::instance().getCurrentSeed()); + } + { + HighFive::Group global_variables = checkpoint.createGroup("singleton/global_variables"); + global_variables.createAttribute("connectivity_id", GlobalVariableManager::instance().getConnectivityId()); + global_variables.createAttribute("mesh_id", GlobalVariableManager::instance().getMeshId()); + } + { + std::cout << rang::fgB::magenta << "Checkpoint DualConnectivityManager NIY" << rang::fg::reset << '\n'; + std::cout << rang::fgB::magenta << "Checkpoint DualMeshManager NIY" << rang::fg::reset << '\n'; + } - for (auto& symbol : symbol_list) { - switch (symbol.attributes().dataType()) { - case ASTNodeDataType::builtin_function_t: - case ASTNodeDataType::function_t: - case ASTNodeDataType::type_name_id_t: { - break; - } - default: { - if ((symbol_table.has(symbol.name(), ASTExecutionStack::getInstance().currentNode().begin())) and - (symbol.attributes().dataType() != ASTNodeDataType::builtin_function_t) and - (not symbol.attributes().isModuleVariable())) { - std::visit( - [&](auto&& data) { - using DataT = std::decay_t<decltype(data)>; - if constexpr (std::is_same_v<DataT, std::monostate>) { - } else if constexpr ((std::is_arithmetic_v<DataT>) or (std::is_same_v<DataT, std::string>) or - (is_tiny_vector_v<DataT>) or (is_tiny_matrix_v<DataT>)) { - symbol_table_group.createAttribute(symbol.name(), data); - } else if constexpr (is_std_vector_v<DataT>) { - using value_type = typename DataT::value_type; - if constexpr ((std::is_arithmetic_v<value_type>) or (std::is_same_v<value_type, std::string>) or - (is_tiny_vector_v<value_type>) or (is_tiny_matrix_v<value_type>)) { + std::shared_ptr<const SymbolTable> p_symbol_table = ASTExecutionStack::getInstance().currentNode().m_symbol_table; + auto symbol_table_group = checkpoint; + while (p_symbol_table.use_count() > 0) { + symbol_table_group = symbol_table_group.createGroup("symbol table"); + + const SymbolTable& symbol_table = *p_symbol_table; + + const auto& symbol_list = symbol_table.symbolList(); + + for (auto& symbol : symbol_list) { + switch (symbol.attributes().dataType()) { + case ASTNodeDataType::builtin_function_t: + case ASTNodeDataType::function_t: + case ASTNodeDataType::type_name_id_t: { + break; + } + default: { + if ((symbol_table.has(symbol.name(), ASTExecutionStack::getInstance().currentNode().begin())) and + (symbol.attributes().dataType() != ASTNodeDataType::builtin_function_t) and + (not symbol.attributes().isModuleVariable())) { + std::visit( + [&](auto&& data) { + using DataT = std::decay_t<decltype(data)>; + if constexpr (std::is_same_v<DataT, std::monostate>) { + } else if constexpr ((std::is_arithmetic_v<DataT>) or (std::is_same_v<DataT, std::string>) or + (is_tiny_vector_v<DataT>) or (is_tiny_matrix_v<DataT>)) { symbol_table_group.createAttribute(symbol.name(), data); + } else if constexpr (std::is_same_v<DataT, EmbeddedData>) { + CheckpointResumeRepository::instance().checkpoint(symbol.attributes().dataType(), symbol.name(), data, + file, checkpoint, symbol_table_group); + } else if constexpr (is_std_vector_v<DataT>) { + using value_type = typename DataT::value_type; + if constexpr ((std::is_arithmetic_v<value_type>) or (std::is_same_v<value_type, std::string>) or + (is_tiny_vector_v<value_type>) or (is_tiny_matrix_v<value_type>)) { + symbol_table_group.createAttribute(symbol.name(), data); + } else if constexpr (std::is_same_v<value_type, EmbeddedData>) { + for (size_t i = 0; i < data.size(); ++i) { + CheckpointResumeRepository::instance().checkpoint(symbol.attributes().dataType().contentType(), + symbol.name() + "/" + std::to_string(i), + data[i], file, checkpoint, symbol_table_group); + } + } else { + throw UnexpectedError("unexpected data type"); + } } else { - throw NotImplementedError("datatype is not handled yet"); + throw UnexpectedError("unexpected data type"); } - } else { - throw NotImplementedError("datatype is not handled yet"); - } - }, - symbol.attributes().value()); + }, + symbol.attributes().value()); + } + } } } - } + + p_symbol_table = symbol_table.parentTable(); } - p_symbol_table = symbol_table.parentTable(); - } + if (file.exist("last_checkpoint")) { + file.unlink("last_checkpoint"); + } + file.createHardLink("last_checkpoint", checkpoint); - if (file.exist("last_checkpoint")) { - file.unlink("last_checkpoint"); - } - file.createHardLink("last_checkpoint", checkpoint); + if (file.exist("resuming_checkpoint")) { + file.unlink("resuming_checkpoint"); + } + file.createHardLink("resuming_checkpoint", checkpoint); - if (file.exist("resuming_checkpoint")) { - file.unlink("resuming_checkpoint"); - } - file.createHardLink("resuming_checkpoint", checkpoint); + if (file.hasAttribute("checkpoint_number")) { + file.deleteAttribute("checkpoint_number"); + } + file.createAttribute("checkpoint_number", checkpoint_number); - if (file.hasAttribute("checkpoint_number")) { - file.deleteAttribute("checkpoint_number"); + ++checkpoint_number; + } + catch (HighFive::Exception& e) { + throw NormalError(e.what()); } - file.createAttribute("checkpoint_number", checkpoint_number); - - ++checkpoint_number; } #else // PUGS_HAS_HDF5 diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..aace16bdd2093c011f0d2ce349fa6ceaf97d8fae --- /dev/null +++ b/src/utils/checkpointing/CheckpointUtils.cpp @@ -0,0 +1,180 @@ +#include <utils/checkpointing/CheckpointUtils.hpp> + +#include <language/modules/MeshModuleTypes.hpp> +#include <language/utils/ASTNodeDataTypeTraits.hpp> +#include <language/utils/DataHandler.hpp> +#include <mesh/MeshVariant.hpp> +#include <utils/checkpointing/RefItemListHFType.hpp> + +#include <mesh/Mesh.hpp> + +template <ItemType item_type, size_t Dimension> +void +writeRefItemList(const Connectivity<Dimension>& connectivity, HighFive::Group& connectivity_group) +{ + for (size_t i_item_list = 0; i_item_list < connectivity.template numberOfRefItemList<item_type>(); ++i_item_list) { + auto ref_item_list = connectivity.template refItemList<item_type>(i_item_list); + + std::ostringstream ref_item_list_group_name; + ref_item_list_group_name << "item_ref_list/" << itemName(item_type) << '/' << ref_item_list.refId().tagName(); + HighFive::Group ref_item_list_group = connectivity_group.createGroup(ref_item_list_group_name.str()); + ref_item_list_group.createAttribute("tag_name", ref_item_list.refId().tagName()); + ref_item_list_group.createAttribute("tag_number", ref_item_list.refId().tagNumber()); + ref_item_list_group.createAttribute("type", ref_item_list.type()); + + write(ref_item_list_group, "list", ref_item_list.list()); + } +} + +template <size_t Dimension> +void +writeConnectivity(const Connectivity<Dimension>& connectivity, HighFive::File& file, HighFive::Group& checkpoint_group) +{ + std::string connectivity_group_name = "connectivity/" + std::to_string(connectivity.id()); + if (not checkpoint_group.exist(connectivity_group_name)) { + bool linked = false; + for (auto group_name : file.listObjectNames()) { + if (file.exist(group_name + "/" + connectivity_group_name)) { + checkpoint_group.createHardLink(connectivity_group_name, + file.getGroup(group_name + "/" + connectivity_group_name)); + linked = true; + break; + } + } + + if (not linked) { + HighFive::Group connectivity_group = checkpoint_group.createGroup(connectivity_group_name); + + connectivity_group.createAttribute("dimension", connectivity.dimension()); + connectivity_group.createAttribute("id", connectivity.id()); + connectivity_group.createAttribute("type", std::string{"unstructured"}); + + write(connectivity_group, "cell_to_node_matrix_values", + connectivity.getMatrix(ItemType::cell, ItemType::node).values()); + write(connectivity_group, "cell_to_node_matrix_rowsMap", + connectivity.getMatrix(ItemType::cell, ItemType::node).rowsMap()); + + if constexpr (Dimension > 1) { + write(connectivity_group, "cell_to_face_matrix_values", + connectivity.getMatrix(ItemType::cell, ItemType::face).values()); + write(connectivity_group, "cell_to_face_matrix_rowsMap", + connectivity.getMatrix(ItemType::cell, ItemType::face).rowsMap()); + + write(connectivity_group, "face_to_node_matrix_values", + connectivity.getMatrix(ItemType::face, ItemType::node).values()); + write(connectivity_group, "face_to_node_matrix_rowsMap", + connectivity.getMatrix(ItemType::face, ItemType::node).rowsMap()); + + write(connectivity_group, "node_to_face_matrix_values", + connectivity.getMatrix(ItemType::node, ItemType::face).values()); + write(connectivity_group, "node_to_face_matrix_rowsMap", + connectivity.getMatrix(ItemType::node, ItemType::face).rowsMap()); + + write(connectivity_group, "cell_face_is_reversed", connectivity.cellFaceIsReversed().arrayView()); + } + + if constexpr (Dimension > 2) { + write(connectivity_group, "cell_to_edge_matrix_values", + connectivity.getMatrix(ItemType::cell, ItemType::edge).values()); + write(connectivity_group, "cell_to_edge_matrix_rowsMap", + connectivity.getMatrix(ItemType::cell, ItemType::edge).rowsMap()); + + write(connectivity_group, "face_to_edge_matrix_values", + connectivity.getMatrix(ItemType::face, ItemType::edge).values()); + write(connectivity_group, "face_to_edge_matrix_rowsMap", + connectivity.getMatrix(ItemType::face, ItemType::edge).rowsMap()); + + write(connectivity_group, "edge_to_node_matrix_values", + connectivity.getMatrix(ItemType::edge, ItemType::node).values()); + write(connectivity_group, "edge_to_node_matrix_rowsMap", + connectivity.getMatrix(ItemType::edge, ItemType::node).rowsMap()); + + write(connectivity_group, "node_to_edge_matrix_values", + connectivity.getMatrix(ItemType::node, ItemType::edge).values()); + write(connectivity_group, "node_to_edge_matrix_rowsMap", + connectivity.getMatrix(ItemType::node, ItemType::edge).rowsMap()); + + write(connectivity_group, "face_edge_is_reversed", connectivity.faceEdgeIsReversed().arrayView()); + } + + write(connectivity_group, "cell_type", connectivity.cellType()); + + write(connectivity_group, "cell_numbers", connectivity.cellNumber()); + write(connectivity_group, "node_numbers", connectivity.nodeNumber()); + + write(connectivity_group, "cell_owner", connectivity.cellOwner()); + write(connectivity_group, "node_owner", connectivity.nodeOwner()); + + if constexpr (Dimension > 1) { + write(connectivity_group, "face_numbers", connectivity.faceNumber()); + + write(connectivity_group, "face_owner", connectivity.faceOwner()); + } + if constexpr (Dimension > 2) { + write(connectivity_group, "edge_numbers", connectivity.edgeNumber()); + + write(connectivity_group, "edge_owner", connectivity.edgeOwner()); + } + + writeRefItemList<ItemType::cell>(connectivity, connectivity_group); + writeRefItemList<ItemType::face>(connectivity, connectivity_group); + writeRefItemList<ItemType::edge>(connectivity, connectivity_group); + writeRefItemList<ItemType::node>(connectivity, connectivity_group); + } + } +} + +void +writeMesh(const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group) +{ + HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name); + + std::shared_ptr<const MeshVariant> mesh_v = + dynamic_cast<const DataHandler<const MeshVariant>&>(embedded_data.get()).data_ptr(); + + variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(mesh_v)>)); + variable_group.createAttribute("id", mesh_v->id()); + + std::string mesh_group_name = "mesh/" + std::to_string(mesh_v->id()); + if (not checkpoint_group.exist(mesh_group_name)) { + bool linked = false; + for (auto group_name : file.listObjectNames()) { + if (file.exist(group_name + "/" + mesh_group_name)) { + checkpoint_group.createHardLink(mesh_group_name, file.getGroup(group_name + "/" + mesh_group_name)); + linked = true; + break; + } + } + + if (not linked) { + HighFive::Group mesh_group = checkpoint_group.createGroup(mesh_group_name); + mesh_group.createAttribute("connectivity", mesh_v->connectivity().id()); + std::visit( + [&](auto&& mesh) { + using MeshType = mesh_type_t<decltype(mesh)>; + if constexpr (is_polygonal_mesh_v<MeshType>) { + mesh_group.createAttribute("id", mesh->id()); + mesh_group.createAttribute("type", std::string{"polygonal"}); + mesh_group.createAttribute("dimension", mesh->dimension()); + write(mesh_group, "xr", mesh->xr()); + } else { + throw UnexpectedError("unexpected mesh type"); + } + }, + mesh_v->variant()); + } + } + + std::visit( + [&](auto&& mesh) { + using MeshType = mesh_type_t<decltype(mesh)>; + if constexpr (is_polygonal_mesh_v<MeshType>) { + writeConnectivity(mesh->connectivity(), file, checkpoint_group); + } + }, + mesh_v->variant()); +} diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9b79c8ec9c898c96ea03c252d756d70b9d9ea388 --- /dev/null +++ b/src/utils/checkpointing/CheckpointUtils.hpp @@ -0,0 +1,51 @@ +#ifndef CHECKPOINT_UTILS_HPP +#define CHECKPOINT_UTILS_HPP + +#include <utils/HighFivePugsUtils.hpp> + +#include <language/utils/SymbolTable.hpp> +#include <mesh/CellType.hpp> +#include <mesh/ItemValue.hpp> + +template <typename DataType> +PUGS_INLINE void +write(HighFive::Group& group, const std::string& name, const Array<DataType>& array) +{ + using data_type = std::remove_const_t<DataType>; + HighFive::DataSetCreateProps properties; + properties.add(HighFive::Chunking(std::vector<hsize_t>{std::min(4ul * 1024ul * 1024ul, array.size())})); + properties.add(HighFive::Shuffle()); + properties.add(HighFive::Deflate(3)); + + if constexpr (std::is_same_v<CellType, data_type>) { + auto dataset = group.createDataSet<short>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); + dataset.template write_raw<short>(reinterpret_cast<const short*>(&(array[0]))); + } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or + (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { + using base_type = typename data_type::base_type; + auto dataset = + group.createDataSet<base_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); + dataset.template write_raw<base_type>(reinterpret_cast<const base_type*>(&(array[0]))); + } else { + auto dataset = + group.createDataSet<data_type>(name, HighFive::DataSpace{std::vector<size_t>{array.size()}}, properties); + dataset.template write_raw<data_type>(&(array[0])); + } +} + +template <typename DataType, ItemType item_type, typename ConnectivityPtr> +PUGS_INLINE void +write(HighFive::Group& group, + const std::string& name, + const ItemValue<DataType, item_type, ConnectivityPtr>& item_value) +{ + write(group, name, item_value.arrayView()); +} + +void writeMesh(const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group); + +#endif // CHECKPOINT_UTILS_HPP diff --git a/src/utils/checkpointing/PrintCheckpointInfo.cpp b/src/utils/checkpointing/PrintCheckpointInfo.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30ee76bca359c55f3d2adc0f5640bda0853ae37e --- /dev/null +++ b/src/utils/checkpointing/PrintCheckpointInfo.cpp @@ -0,0 +1,183 @@ +#include <utils/checkpointing/PrintCheckpointInfo.hpp> + +#include <utils/Exceptions.hpp> +#include <utils/Messenger.hpp> +#include <utils/pugs_config.hpp> + +#include <rang.hpp> + +#ifdef PUGS_HAS_HDF5 + +#include <utils/HighFivePugsUtils.hpp> + +#include <algebra/TinyMatrix.hpp> +#include <algebra/TinyVector.hpp> + +#include <iostream> +#include <regex> +#endif // PUGS_HAS_HDF5 + +#ifdef PUGS_HAS_HDF5 + +template <typename T> +void +printAttributeValue(const HighFive::Attribute& attribute) +{ + std::string delim = ""; + + if constexpr (std::is_same_v<T, std::string>) { + delim = "\""; + } + + HighFive::DataSpace data_space = attribute.getSpace(); + if (data_space.getNumberDimensions() == 0) { + std::cout << std::boolalpha << delim << attribute.read<T>() << delim; + } else if (data_space.getNumberDimensions() == 1) { + std::vector value = attribute.read<std::vector<T>>(); + if (value.size() > 0) { + std::cout << '(' << std::boolalpha << delim << value[0] << delim; + for (size_t i = 1; i < value.size(); ++i) { + std::cout << ", " << std::boolalpha << delim << value[i] << delim; + } + std::cout << ')'; + } + } +} + +void +printCheckpointInfo(const std::string& filename) +{ + if (parallel::rank() == 0) { + try { + HighFive::File file(filename, HighFive::File::ReadOnly); + + std::map<size_t, std::string> checkpoint_name_list; + + for (auto name : file.listObjectNames()) { + std::smatch number_string; + const std::regex checkpoint_regex("checkpoint_([0-9]+)"); + if (std::regex_match(name, number_string, checkpoint_regex)) { + std::stringstream os; + os << number_string[1].str(); + + size_t id = 0; + os >> id; + + checkpoint_name_list[id] = name; + } + } + + for (auto&& [id, checkpoint_name] : checkpoint_name_list) { + HighFive::Group checkpoint = file.getGroup(checkpoint_name); + const std::string creation_date = checkpoint.getAttribute("creation_date").read<std::string>(); + + std::cout << rang::fgB::yellow << " * " << rang::fg::reset << rang::fgB::magenta << checkpoint_name + << rang::fg::reset << " [" << rang::fg::green << creation_date << rang::fg::reset << "]\n"; + + HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table"); + + bool finished = true; + do { + finished = true; + + for (auto symbol_name : saved_symbol_table.listAttributeNames()) { + HighFive::Attribute attribute = saved_symbol_table.getAttribute(symbol_name); + HighFive::DataType data_type = attribute.getDataType(); + + std::cout << " "; + std::cout << std::setw(25) << std::setfill('.') // << rang::style::bold; + << std::left << symbol_name + ' ' << std::setfill(' '); + std::cout << ' '; + + switch (data_type.getClass()) { + case HighFive::DataTypeClass::Float: { + printAttributeValue<double>(attribute); + break; + } + case HighFive::DataTypeClass::Integer: { + if (data_type == HighFive::AtomicType<uint64_t>()) { + printAttributeValue<uint64_t>(attribute); + } else if (data_type == HighFive::AtomicType<int64_t>()) { + printAttributeValue<int64_t>(attribute); + } + break; + } + case HighFive::DataTypeClass::Array: { + HighFive::DataSpace data_space = attribute.getSpace(); + + if (data_type == HighFive::AtomicType<TinyVector<1>>()) { + printAttributeValue<TinyVector<1>>(attribute); + } else if (data_type == HighFive::AtomicType<TinyVector<2>>()) { + printAttributeValue<TinyVector<2>>(attribute); + } else if (data_type == HighFive::AtomicType<TinyVector<3>>()) { + printAttributeValue<TinyVector<3>>(attribute); + } else if (data_type == HighFive::AtomicType<TinyMatrix<1>>()) { + printAttributeValue<TinyMatrix<1>>(attribute); + } else if (data_type == HighFive::AtomicType<TinyMatrix<2>>()) { + printAttributeValue<TinyMatrix<2>>(attribute); + } else if (data_type == HighFive::AtomicType<TinyMatrix<3>>()) { + printAttributeValue<TinyMatrix<3>>(attribute); + } + break; + } + case HighFive::DataTypeClass::Enum: { + if (data_type == HighFive::create_datatype<bool>()) { + printAttributeValue<bool>(attribute); + } else { + std::cout << "????"; + } + break; + } + case HighFive::DataTypeClass::String: { + printAttributeValue<std::string>(attribute); + break; + } + } + + std::cout << rang::style::reset << '\n'; + } + + if (saved_symbol_table.exist("embedded")) { + HighFive::Group embedded_data_list = saved_symbol_table.getGroup("embedded"); + for (auto name : embedded_data_list.listObjectNames()) { + std::cout << " "; + std::cout << std::setw(25) << std::setfill('.') // << rang::style::bold; + << std::left << name + ' ' << std::setfill(' '); + std::cout << ' '; + std::cout << embedded_data_list.getGroup(name).getAttribute("type").read<std::string>() << '\n'; + } + } + + const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table"); + + if (saved_symbol_table_has_parent) { + saved_symbol_table = saved_symbol_table.getGroup("symbol table"); + + finished = false; + } + + } while (not finished); + } + std::cout << "-------------------------------------------------------\n"; + for (auto path : std::array{"resuming_checkpoint", "last_checkpoint"}) { + std::cout << rang::fgB::yellow << " * " << rang::fg::reset << rang::style::bold << path << rang::style::reset + << " -> " << rang::fgB::green << file.getGroup(path).getAttribute("name").read<std::string>() + << rang::style::reset << '\n'; + } + } + catch (HighFive::Exception& e) { + std::cerr << rang::fgB::red << "error: " << rang::fg::reset << rang::style::bold << e.what() << rang::style::reset + << '\n'; + } + } +} + +#else // PUGS_HAS_HDF5 + +void +printCheckpointInfo(const std::string&) +{ + std::cerr << rang::fgB::red << "error: " << rang::fg::reset << "checkpoint info requires HDF5\n"; +} + +#endif // PUGS_HAS_HDF5 diff --git a/src/utils/checkpointing/PrintCheckpointInfo.hpp b/src/utils/checkpointing/PrintCheckpointInfo.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bbe5b4860b18e50df524b26671704e852ccce7c0 --- /dev/null +++ b/src/utils/checkpointing/PrintCheckpointInfo.hpp @@ -0,0 +1,8 @@ +#ifndef PRINT_CHECKPOINT_INFO_HPP +#define PRINT_CHECKPOINT_INFO_HPP + +#include <string> + +void printCheckpointInfo(const std::string& filename); + +#endif // PRINT_CHECKPOINT_INFO_HPP diff --git a/src/utils/checkpointing/RefItemListHFType.hpp b/src/utils/checkpointing/RefItemListHFType.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d5bc94c274b2f89ad8543608e1cce4669223a392 --- /dev/null +++ b/src/utils/checkpointing/RefItemListHFType.hpp @@ -0,0 +1,17 @@ +#ifndef REF_ITEM_LIST_HF_TYPE_HPP +#define REF_ITEM_LIST_HF_TYPE_HPP + +#include <mesh/RefItemList.hpp> +#include <utils/checkpointing/CheckpointUtils.hpp> + +HighFive::EnumType<RefItemListBase::Type> PUGS_INLINE +create_enum_ref_item_list_type() +{ + return {{"boundary", RefItemListBase::Type::boundary}, + {"interface", RefItemListBase::Type::interface}, + {"set", RefItemListBase::Type::set}, + {"undefined", RefItemListBase::Type::undefined}}; +} +HIGHFIVE_REGISTER_TYPE(RefItemListBase::Type, create_enum_ref_item_list_type); + +#endif // REF_ITEM_LIST_HF_TYPE_HPP diff --git a/src/utils/checkpointing/Resume.cpp b/src/utils/checkpointing/Resume.cpp index 1fb9355aad63cc7b0a9ef7277d667be8b4342da3..e34c621782c741a98825b5ee7c63d43e117abf93 100644 --- a/src/utils/checkpointing/Resume.cpp +++ b/src/utils/checkpointing/Resume.cpp @@ -14,145 +14,93 @@ #include <language/utils/ASTCheckpointsInfo.hpp> #include <utils/Exceptions.hpp> -#include <utils/checkpointing/ResumingManager.hpp> #ifdef PUGS_HAS_HDF5 +#include <mesh/Connectivity.hpp> +#include <utils/RandomEngine.hpp> +#include <utils/checkpointing/ResumeUtils.hpp> +#include <utils/checkpointing/ResumingData.hpp> +#include <utils/checkpointing/ResumingManager.hpp> + +#include <language/utils/CheckpointResumeRepository.hpp> + +#include <map> + void resume() { - HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); + try { + ResumingData::create(); + HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly); - HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); + HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); - HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table"); + HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table"); - const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode(); - auto p_symbol_table = p_node->m_symbol_table; + const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode(); + auto p_symbol_table = p_node->m_symbol_table; - ResumingManager& resuming_manager = ResumingManager::getInstance(); + ResumingManager& resuming_manager = ResumingManager::getInstance(); - resuming_manager.checkpointNumber() = file.getAttribute("checkpoint_number").read<uint64_t>(); + resuming_manager.checkpointNumber() = file.getAttribute("checkpoint_number").read<uint64_t>() + 1; - std::cout << " * " - << "Using " << rang::fgB::green << "checkpoint" << rang::fg::reset << " number " - << resuming_manager.checkpointNumber()++ << '\n'; + std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line " + << rang::fgB::yellow << p_node->begin().line << rang::fg::reset << " [using " << rang::fgB::cyan + << checkpoint.getAttribute("name").read<std::string>() << rang::fg::reset << "]\n"; - std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line " << rang::fgB::yellow - << p_node->begin().line << rang::fg::reset << " [checkpoint id " << rang::fgB::cyan - << resuming_manager.checkpointId() << rang::fg::reset << "]\n"; + { + HighFive::Group random_seed = checkpoint.getGroup("singleton/random_seed"); + RandomEngine::instance().setRandomSeed(random_seed.getAttribute("current_seed").read<uint64_t>()); + } - bool finished = true; - do { - finished = true; + { + std::cout << rang::fgB::magenta << "Resume DualConnectivityManager NIY" << rang::fg::reset << '\n'; + std::cout << rang::fgB::magenta << "Resume DualMeshManager NIY" << rang::fg::reset << '\n'; + } - for (auto symbol_name : saved_symbol_table.listAttributeNames()) { - auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin()); - auto& attribute = p_symbol->attributes(); - switch (attribute.dataType()) { - case ASTNodeDataType::bool_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>(); - break; - } - case ASTNodeDataType::unsigned_int_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>(); - break; - } - case ASTNodeDataType::int_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>(); - break; - } - case ASTNodeDataType::double_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double_t>(); - break; - } - case ASTNodeDataType::string_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>(); - break; - } - case ASTNodeDataType::vector_t: { - switch (attribute.dataType().dimension()) { - case 1: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>(); - break; - } - case 2: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>(); - break; - } - case 3: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>(); - break; - } - // LCOV_EXCL_START - default: { - throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension"); - } - // LCOV_EXCL_STOP - } - break; - } - case ASTNodeDataType::matrix_t: { - // LCOV_EXCL_START - if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) { - throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension"); - } - // LCOV_EXCL_STOP - switch (attribute.dataType().numberOfRows()) { - case 1: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>(); - break; - } - case 2: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>(); - break; - } - case 3: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>(); - break; - } - // LCOV_EXCL_START - default: { - throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension"); - } - // LCOV_EXCL_STOP - } - break; - } - case ASTNodeDataType::tuple_t: { - switch (attribute.dataType().contentType()) { + ResumingData::instance().readData(checkpoint); + + bool finished = true; + do { + finished = true; + + for (auto symbol_name : saved_symbol_table.listAttributeNames()) { + auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin()); + auto& attribute = p_symbol->attributes(); + switch (attribute.dataType()) { case ASTNodeDataType::bool_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>(); break; } case ASTNodeDataType::unsigned_int_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>(); break; } case ASTNodeDataType::int_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>(); break; } case ASTNodeDataType::double_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double_t>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double_t>(); break; } case ASTNodeDataType::string_t: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>(); break; } case ASTNodeDataType::vector_t: { - switch (attribute.dataType().contentType().dimension()) { + switch (attribute.dataType().dimension()) { case 1: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>(); break; } case 2: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>(); break; } case 3: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>(); break; } // LCOV_EXCL_START @@ -165,57 +113,160 @@ resume() } case ASTNodeDataType::matrix_t: { // LCOV_EXCL_START - if (attribute.dataType().contentType().numberOfRows() != - attribute.dataType().contentType().numberOfColumns()) { - throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension"); + if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) { + throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension"); } // LCOV_EXCL_STOP - switch (attribute.dataType().contentType().numberOfRows()) { + switch (attribute.dataType().numberOfRows()) { case 1: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>(); break; } case 2: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>(); break; } case 3: { - attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>(); + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>(); break; } // LCOV_EXCL_START default: { - throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension"); + throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension"); } // LCOV_EXCL_STOP } break; } + case ASTNodeDataType::tuple_t: { + switch (attribute.dataType().contentType()) { + case ASTNodeDataType::bool_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>(); + break; + } + case ASTNodeDataType::unsigned_int_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>(); + break; + } + case ASTNodeDataType::int_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>(); + break; + } + case ASTNodeDataType::double_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double_t>>(); + break; + } + case ASTNodeDataType::string_t: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>(); + break; + } + case ASTNodeDataType::vector_t: { + switch (attribute.dataType().contentType().dimension()) { + case 1: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>(); + break; + } + case 2: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>(); + break; + } + case 3: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>(); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension"); + } + // LCOV_EXCL_STOP + } + break; + } + case ASTNodeDataType::matrix_t: { + // LCOV_EXCL_START + if (attribute.dataType().contentType().numberOfRows() != + attribute.dataType().contentType().numberOfColumns()) { + throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension"); + } + // LCOV_EXCL_STOP + switch (attribute.dataType().contentType().numberOfRows()) { + case 1: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>(); + break; + } + case 2: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>(); + break; + } + case 3: { + attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>(); + break; + } + // LCOV_EXCL_START + default: { + throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension"); + } + // LCOV_EXCL_STOP + } + break; + } + default: { + throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType())); + } + } + break; + } default: { - throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType())); + throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType())); } } - break; - } - default: { - throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType())); } + + if (saved_symbol_table.exist("embedded")) { + HighFive::Group embedded = saved_symbol_table.getGroup("embedded"); + + for (auto symbol_name : embedded.listObjectNames()) { + auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin()); + if (p_symbol->attributes().dataType() == ASTNodeDataType::tuple_t) { + HighFive::Group embedded_tuple_group = embedded.getGroup(symbol_name); + const size_t number_of_components = embedded_tuple_group.getNumberObjects(); + std::vector<EmbeddedData> embedded_tuple(number_of_components); + + for (size_t i_component = 0; i_component < number_of_components; ++i_component) { + embedded_tuple[i_component] = + CheckpointResumeRepository::instance().resume(file, checkpoint, + p_symbol->attributes().dataType().contentType(), + p_symbol->name() + "/" + std::to_string(i_component), + saved_symbol_table); + } + p_symbol->attributes().value() = embedded_tuple; + } else { + p_symbol->attributes().value() = + CheckpointResumeRepository::instance().resume(file, checkpoint, p_symbol->attributes().dataType(), + p_symbol->name(), saved_symbol_table); + } + } } - } - const bool symbol_table_has_parent = p_symbol_table->hasParentTable(); - const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table"); + const bool symbol_table_has_parent = p_symbol_table->hasParentTable(); + const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table"); - Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent)); + Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent)); - if (symbol_table_has_parent and saved_symbol_table_has_parent) { - p_symbol_table = p_symbol_table->parentTable(); - saved_symbol_table = saved_symbol_table.getGroup("symbol table"); + if (symbol_table_has_parent and saved_symbol_table_has_parent) { + p_symbol_table = p_symbol_table->parentTable(); + saved_symbol_table = saved_symbol_table.getGroup("symbol table"); - finished = false; - } + finished = false; + } + + } while (not finished); - } while (not finished); + ResumingData::destroy(); + } + catch (HighFive::Exception& e) { + throw NormalError(e.what()); + } } #else // PUGS_HAS_HDF5 diff --git a/src/utils/checkpointing/ResumeUtils.cpp b/src/utils/checkpointing/ResumeUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..91a6e9faed6d506a95bc1eb870eaabc95710b38c --- /dev/null +++ b/src/utils/checkpointing/ResumeUtils.cpp @@ -0,0 +1,18 @@ +#include <utils/checkpointing/ResumeUtils.hpp> + +#include <language/utils/DataHandler.hpp> +#include <language/utils/SymbolTable.hpp> +#include <utils/checkpointing/ResumingData.hpp> + +EmbeddedData +readMesh([[maybe_unused]] const HighFive::File& file, + [[maybe_unused]] const HighFive::Group& checkpoint_group, + const std::string& symbol_name, + const HighFive::Group& symbol_table_group) +{ + const HighFive::Group mesh_group = symbol_table_group.getGroup("embedded/" + symbol_name); + + const size_t mesh_id = mesh_group.getAttribute("id").read<uint64_t>(); + + return {std::make_shared<DataHandler<const MeshVariant>>(ResumingData::instance().meshVariant(mesh_id))}; +} diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..64e978449b5dfaa33642f9393a536271f2b4968a --- /dev/null +++ b/src/utils/checkpointing/ResumeUtils.hpp @@ -0,0 +1,36 @@ +#ifndef RESUME_UTILS_HPP +#define RESUME_UTILS_HPP + +#include <utils/HighFivePugsUtils.hpp> + +#include <language/utils/SymbolTable.hpp> +#include <mesh/CellType.hpp> +#include <mesh/ItemValue.hpp> + +template <typename DataType> +PUGS_INLINE Array<DataType> +read(const HighFive::Group& group, const std::string& name) +{ + using data_type = std::remove_const_t<DataType>; + + auto dataset = group.getDataSet(name); + Array<DataType> array(dataset.getElementCount()); + if constexpr (std::is_same_v<CellType, data_type>) { + dataset.template read<short>(reinterpret_cast<short*>(&(array[0]))); + } else if constexpr ((std::is_same_v<CellId, data_type>) or (std::is_same_v<FaceId, data_type>) or + (std::is_same_v<EdgeId, data_type>) or (std::is_same_v<NodeId, data_type>)) { + using base_type = typename data_type::base_type; + dataset.template read<base_type>(reinterpret_cast<base_type*>(&(array[0]))); + } else { + dataset.template read<data_type>(&(array[0])); + } + + return array; +} + +EmbeddedData readMesh(const HighFive::File& file, + const HighFive::Group& checkpoint_group, + const std::string& symbol_name, + const HighFive::Group& symbol_table_group); + +#endif // RESUME_UTILS_HPP diff --git a/src/utils/checkpointing/ResumingData.cpp b/src/utils/checkpointing/ResumingData.cpp new file mode 100644 index 0000000000000000000000000000000000000000..39b88a6d892582527fdfd8a3280200dc0f8a2fc5 --- /dev/null +++ b/src/utils/checkpointing/ResumingData.cpp @@ -0,0 +1,275 @@ +#include <utils/checkpointing/ResumingData.hpp> + +#include <mesh/ConnectivityDescriptor.hpp> +#include <mesh/Mesh.hpp> +#include <utils/Exceptions.hpp> +#include <utils/checkpointing/RefItemListHFType.hpp> +#include <utils/checkpointing/ResumeUtils.hpp> + +ResumingData* ResumingData::m_instance = nullptr; + +void +ResumingData::_getConnectivityList(const HighFive::Group& checkpoint) +{ + if (checkpoint.exist("connectivity")) { + HighFive::Group connectivity_group = checkpoint.getGroup("connectivity"); + + std::map<size_t, std::string> id_name_map; + for (auto connectivity_id_name : connectivity_group.listObjectNames()) { + HighFive::Group connectivity_info = connectivity_group.getGroup(connectivity_id_name); + id_name_map[connectivity_info.getAttribute("id").read<uint64_t>()] = connectivity_id_name; + } + + for (auto [id, name] : id_name_map) { + HighFive::Group connectivity_data = connectivity_group.getGroup(name); + const uint64_t dimension = connectivity_data.getAttribute("dimension").read<uint64_t>(); + const std::string type = connectivity_data.getAttribute("type").read<std::string>(); + + if (type != "unstructured") { + throw UnexpectedError("invalid connectivity type: " + type); + } + + ConnectivityDescriptor descriptor; + descriptor.setCellTypeVector(read<CellType>(connectivity_data, "cell_type")); + descriptor.setCellNumberVector(read<int>(connectivity_data, "cell_numbers")); + descriptor.setNodeNumberVector(read<int>(connectivity_data, "node_numbers")); + + descriptor.setCellOwnerVector(read<int>(connectivity_data, "cell_owner")); + descriptor.setNodeOwnerVector(read<int>(connectivity_data, "node_owner")); + + using index_type = typename ConnectivityMatrix::IndexType; + + descriptor.setCellToNodeMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "cell_to_node_matrix_rowsMap"), + read<index_type>(connectivity_data, "cell_to_node_matrix_values")}); + + if (dimension > 1) { + descriptor.setFaceNumberVector(read<int>(connectivity_data, "face_numbers")); + descriptor.setFaceOwnerVector(read<int>(connectivity_data, "face_owner")); + + descriptor.setCellToFaceMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "cell_to_face_matrix_rowsMap"), + read<index_type>(connectivity_data, "cell_to_face_matrix_values")}); + + descriptor.setFaceToNodeMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "face_to_node_matrix_rowsMap"), + read<index_type>(connectivity_data, "face_to_node_matrix_values")}); + + descriptor.setNodeToFaceMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "node_to_face_matrix_rowsMap"), + read<index_type>(connectivity_data, "node_to_face_matrix_values")}); + + descriptor.setCellFaceIsReversed(read<bool>(connectivity_data, "cell_face_is_reversed")); + } + + if (dimension > 2) { + descriptor.setEdgeNumberVector(read<int>(connectivity_data, "edge_numbers")); + descriptor.setEdgeOwnerVector(read<int>(connectivity_data, "edge_owner")); + + descriptor.setCellToEdgeMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "cell_to_edge_matrix_rowsMap"), + read<index_type>(connectivity_data, "cell_to_edge_matrix_values")}); + + descriptor.setFaceToEdgeMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "face_to_edge_matrix_rowsMap"), + read<index_type>(connectivity_data, "face_to_edge_matrix_values")}); + + descriptor.setEdgeToNodeMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "edge_to_node_matrix_rowsMap"), + read<index_type>(connectivity_data, "edge_to_node_matrix_values")}); + + descriptor.setNodeToEdgeMatrix( + ConnectivityMatrix{read<index_type>(connectivity_data, "node_to_edge_matrix_rowsMap"), + read<index_type>(connectivity_data, "node_to_edge_matrix_values")}); + + descriptor.setFaceEdgeIsReversed(read<bool>(connectivity_data, "face_edge_is_reversed")); + } + + if (connectivity_data.exist("item_ref_list")) { + HighFive::Group item_group = connectivity_data.getGroup("item_ref_list"); + for (auto item_type_name : item_group.listObjectNames()) { + HighFive::Group item_ref_name_list = item_group.getGroup(item_type_name); + + for (auto item_ref_list_name : item_ref_name_list.listObjectNames()) { + HighFive::Group item_ref_list_data = item_ref_name_list.getGroup(item_ref_list_name); + + RefId ref_id(item_ref_list_data.getAttribute("tag_number").read<uint64_t>(), + item_ref_list_data.getAttribute("tag_name").read<std::string>()); + + RefItemListBase::Type ref_item_list_type = + item_ref_list_data.getAttribute("type").read<RefItemListBase::Type>(); + + if (item_type_name == "cell") { + descriptor.addRefItemList( + RefItemList<ItemType::cell>{ref_id, read<CellId>(item_ref_list_data, "list"), ref_item_list_type}); + } else if (item_type_name == "face") { + descriptor.addRefItemList( + RefItemList<ItemType::face>{ref_id, read<FaceId>(item_ref_list_data, "list"), ref_item_list_type}); + } else if (item_type_name == "edge") { + descriptor.addRefItemList( + RefItemList<ItemType::edge>{ref_id, read<EdgeId>(item_ref_list_data, "list"), ref_item_list_type}); + } else if (item_type_name == "node") { + descriptor.addRefItemList( + RefItemList<ItemType::node>{ref_id, read<NodeId>(item_ref_list_data, "list"), ref_item_list_type}); + } else { + throw UnexpectedError("invalid item type: " + item_type_name); + } + } + } + } + + while (id > GlobalVariableManager::instance().getConnectivityId()) { + GlobalVariableManager::instance().getAndIncrementConnectivityId(); + } + + if (m_id_to_iconnectivity_map.contains(id)) { + throw UnexpectedError("connectivity of id " + std::to_string(id) + " already defined!"); + } + + switch (dimension) { + case 1: { + m_id_to_iconnectivity_map.insert({id, Connectivity<1>::build(descriptor)}); + break; + } + case 2: { + m_id_to_iconnectivity_map.insert({id, Connectivity<2>::build(descriptor)}); + break; + } + case 3: { + m_id_to_iconnectivity_map.insert({id, Connectivity<3>::build(descriptor)}); + break; + } + default: { + throw UnexpectedError("invalid dimension " + std::to_string(dimension)); + } + } + } + } + + const size_t next_connectivity_id = + checkpoint.getGroup("singleton/global_variables").getAttribute("connectivity_id").read<size_t>(); + + while (next_connectivity_id > GlobalVariableManager::instance().getConnectivityId()) { + GlobalVariableManager::instance().getAndIncrementConnectivityId(); + } +} + +template <size_t Dimension> +std::shared_ptr<const MeshVariant> +ResumingData::_readPolygonalMesh(const HighFive::Group& mesh_data) +{ + const uint64_t connectivity_id = mesh_data.getAttribute("connectivity").read<uint64_t>(); + auto i_id_to_iconnectivity = m_id_to_iconnectivity_map.find(connectivity_id); + if (i_id_to_iconnectivity == m_id_to_iconnectivity_map.end()) { + throw UnexpectedError("cannot find connectivity " + std::to_string(connectivity_id)); + } + std::shared_ptr<const IConnectivity> i_connectivity = i_id_to_iconnectivity->second; + if (i_connectivity->dimension() != Dimension) { + throw UnexpectedError("invalid connectivity dimension " + std::to_string(i_connectivity->dimension())); + } + + std::shared_ptr<const Connectivity<Dimension>> connectivity = + std::dynamic_pointer_cast<const Connectivity<Dimension>>(i_connectivity); + + Array<const TinyVector<Dimension>> xr_array = read<TinyVector<Dimension>>(mesh_data, "xr"); + NodeValue<const TinyVector<Dimension>> xr{*connectivity, xr_array}; + + std::shared_ptr mesh = std::make_shared<const Mesh<Dimension>>(connectivity, xr); + return std::make_shared<const MeshVariant>(mesh); +} + +void +ResumingData::_getMeshVariantList(const HighFive::Group& checkpoint) +{ + if (checkpoint.exist("mesh")) { + HighFive::Group mesh_group = checkpoint.getGroup("mesh"); + + std::map<size_t, std::string> id_name_map; + for (auto mesh_id_name : mesh_group.listObjectNames()) { + HighFive::Group mesh_info = mesh_group.getGroup(mesh_id_name); + id_name_map[mesh_info.getAttribute("id").read<uint64_t>()] = mesh_id_name; + } + + for (auto [id, name] : id_name_map) { + HighFive::Group mesh_data = mesh_group.getGroup(name); + const std::string type = mesh_data.getAttribute("type").read<std::string>(); + const uint64_t dimension = mesh_data.getAttribute("dimension").read<uint64_t>(); + + if (type != "polygonal") { + throw UnexpectedError("invalid connectivity type"); + } + + while (id > GlobalVariableManager::instance().getMeshId()) { + GlobalVariableManager::instance().getAndIncrementMeshId(); + } + + switch (dimension) { + case 1: { + m_id_to_mesh_variant_map.insert({id, this->_readPolygonalMesh<1>(mesh_data)}); + break; + } + case 2: { + m_id_to_mesh_variant_map.insert({id, this->_readPolygonalMesh<2>(mesh_data)}); + break; + } + case 3: { + m_id_to_mesh_variant_map.insert({id, this->_readPolygonalMesh<3>(mesh_data)}); + break; + } + default: { + throw UnexpectedError("invalid mesh dimension " + std::to_string(dimension)); + } + } + } + } + + const size_t next_mesh_id = checkpoint.getGroup("singleton/global_variables").getAttribute("mesh_id").read<size_t>(); + + while (next_mesh_id > GlobalVariableManager::instance().getMeshId()) { + GlobalVariableManager::instance().getAndIncrementMeshId(); + } +} + +void +ResumingData::readData(const HighFive::Group& checkpoint) +{ + this->_getConnectivityList(checkpoint); + this->_getMeshVariantList(checkpoint); +} + +const std::shared_ptr<const IConnectivity>& +ResumingData::iConnectivity(const size_t connectivity_id) const +{ + auto i_id_to_connectivity = m_id_to_iconnectivity_map.find(connectivity_id); + if (i_id_to_connectivity == m_id_to_iconnectivity_map.end()) { + throw UnexpectedError("cannot find connectivity of id " + std::to_string(connectivity_id)); + } else { + return i_id_to_connectivity->second; + } +} + +const std::shared_ptr<const MeshVariant>& +ResumingData::meshVariant(const size_t mesh_id) const +{ + auto i_id_to_connectivity = m_id_to_mesh_variant_map.find(mesh_id); + if (i_id_to_connectivity == m_id_to_mesh_variant_map.end()) { + throw UnexpectedError("cannot find connectivity of id " + std::to_string(mesh_id)); + } else { + return i_id_to_connectivity->second; + } +} + +void +ResumingData::create() +{ + Assert(m_instance == nullptr, "instance already created"); + m_instance = new ResumingData; +} + +void +ResumingData::destroy() +{ + Assert(m_instance != nullptr, "instance not created"); + delete m_instance; + m_instance = nullptr; +} diff --git a/src/utils/checkpointing/ResumingData.hpp b/src/utils/checkpointing/ResumingData.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b7a25221e926ad2cb97ef8b61807c140ba5d8d5d --- /dev/null +++ b/src/utils/checkpointing/ResumingData.hpp @@ -0,0 +1,49 @@ +#ifndef RESUMING_DATA_HPP +#define RESUMING_DATA_HPP + +#include <utils/HighFivePugsUtils.hpp> +#include <utils/PugsAssert.hpp> + +#include <map> + +class IConnectivity; +class MeshVariant; + +class ResumingData +{ + private: + std::map<size_t, std::shared_ptr<const IConnectivity>> m_id_to_iconnectivity_map; + std::map<size_t, std::shared_ptr<const MeshVariant>> m_id_to_mesh_variant_map; + + ResumingData() = default; + ~ResumingData() = default; + + static ResumingData* m_instance; + + template <size_t Dimension> + std::shared_ptr<const MeshVariant> _readPolygonalMesh(const HighFive::Group& mesh_group); + + void _getConnectivityList(const HighFive::Group& checkpoint); + void _getMeshVariantList(const HighFive::Group& checkpoint); + + public: + void readData(const HighFive::Group& checkpoint); + + const std::shared_ptr<const IConnectivity>& iConnectivity(const size_t connectivity_id) const; + const std::shared_ptr<const MeshVariant>& meshVariant(const size_t mesh_id) const; + + static void create(); + static void destroy(); + + static ResumingData& + instance() + { + Assert(m_instance != nullptr, "instance not created"); + return *m_instance; + } + + ResumingData(const ResumingData&) = delete; + ResumingData(ResumingData&&) = delete; +}; + +#endif // RESUMING_DATA_HPP diff --git a/src/utils/checkpointing/ResumingManager.cpp b/src/utils/checkpointing/ResumingManager.cpp index a2937d3a9bbe3f39b278486b54f28175fe0050bb..0dfe83fd93bb126684193c15426dbac553762f4d 100644 --- a/src/utils/checkpointing/ResumingManager.cpp +++ b/src/utils/checkpointing/ResumingManager.cpp @@ -35,7 +35,7 @@ ResumingManager::checkpointId() HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint"); - m_checkpoint_id = std::make_unique<size_t>(checkpoint.getAttribute("checkpoint_id").read<uint64_t>()); + m_checkpoint_id = std::make_unique<size_t>(checkpoint.getAttribute("id").read<uint64_t>()); } #else // PUGS_HAS_HDF5 m_checkpoint_id = std::make_unique<uint64_t>(0); diff --git a/tests/test_BuiltinFunctionEmbedder.cpp b/tests/test_BuiltinFunctionEmbedder.cpp index be28e7d75a42240f2f2e248e8cade2261e5ef5e5..74128618ee1af9ce618a0412f0a000c36d39e8ca 100644 --- a/tests/test_BuiltinFunctionEmbedder.cpp +++ b/tests/test_BuiltinFunctionEmbedder.cpp @@ -13,6 +13,11 @@ template <> inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const uint64_t>> = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("shared_const_uint64_t"); +template <> +inline ASTNodeDataType ast_node_data_type_from<int> = ASTNodeDataType{}; +template <> +inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const int>> = ASTNodeDataType{}; + TEST_CASE("BuiltinFunctionEmbedder", "[language]") { rang::setControlMode(rang::control::Off);