diff --git a/src/language/ast/ASTModulesImporter.cpp b/src/language/ast/ASTModulesImporter.cpp index eb03cf56ad57eb53395a47f34f66dc0b4e9782ab..7b9670160f035a09e0bc2fb023834ac70bcaf36e 100644 --- a/src/language/ast/ASTModulesImporter.cpp +++ b/src/language/ast/ASTModulesImporter.cpp @@ -46,7 +46,7 @@ ASTModulesImporter::ASTModulesImporter(ASTNode& root_node) : m_symbol_table{*roo { Assert(root_node.is_root()); OperatorRepository::instance().reset(); - m_module_repository.populateMandatorySymbolTable(root_node, m_symbol_table); + m_module_repository.populateMandatoryData(root_node, m_symbol_table); this->_importAllModules(root_node); diff --git a/src/language/modules/CoreModule.cpp b/src/language/modules/CoreModule.cpp index 05643f6566aca6728b9cffb8bdd497322c12d313..a21ad4130bbdff345d0b40bf67fa799be80e66c1 100644 --- a/src/language/modules/CoreModule.cpp +++ b/src/language/modules/CoreModule.cpp @@ -19,6 +19,7 @@ #include <language/utils/BinaryOperatorRegisterForString.hpp> #include <language/utils/BinaryOperatorRegisterForZ.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/CheckpointResumeRepository.hpp> #include <language/utils/Exit.hpp> #include <language/utils/IncDecOperatorRegisterForN.hpp> #include <language/utils/IncDecOperatorRegisterForZ.hpp> @@ -34,7 +35,9 @@ #include <utils/PugsUtils.hpp> #include <utils/RandomEngine.hpp> #include <utils/checkpointing/Checkpoint.hpp> +#include <utils/checkpointing/CheckpointUtils.hpp> #include <utils/checkpointing/Resume.hpp> +#include <utils/checkpointing/ResumeUtils.hpp> #include <utils/checkpointing/ResumingManager.hpp> #include <random> @@ -195,4 +198,14 @@ CoreModule::registerOperators() const void CoreModule::registerCheckpointResume() const -{} +{ + CheckpointResumeRepository::instance() + .addCheckpointResume(ast_node_data_type_from<std::shared_ptr<const OStream>>, + std::function([](const std::string& symbol_name, const EmbeddedData& embedded_data, + HighFive::File& file, HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group) { + writeOStream(symbol_name, embedded_data, file, checkpoint_group, symbol_table_group); + }), + std::function([](const std::string& symbol_name, const HighFive::Group& symbol_table_group) + -> EmbeddedData { return readOStream(symbol_name, symbol_table_group); })); +} diff --git a/src/language/modules/ModuleRepository.cpp b/src/language/modules/ModuleRepository.cpp index e291ca761ae2bd2721110f645bb7b93bc395b7b8..53404f533d293d3dad153f839cb1e999f68ced88 100644 --- a/src/language/modules/ModuleRepository.cpp +++ b/src/language/modules/ModuleRepository.cpp @@ -150,7 +150,7 @@ ModuleRepository::populateSymbolTable(const ASTNode& module_name_node, SymbolTab } void -ModuleRepository::populateMandatorySymbolTable(const ASTNode& root_node, SymbolTable& symbol_table) +ModuleRepository::populateMandatoryData(const ASTNode& root_node, SymbolTable& symbol_table) { for (auto&& [module_name, i_module] : m_module_set) { if (i_module->isMandatory()) { @@ -164,6 +164,8 @@ ModuleRepository::populateMandatorySymbolTable(const ASTNode& root_node, SymbolT this->_populateSymbolTable(root_node, module_name, i_module->getNameValueMap(), symbol_table); + i_module->registerCheckpointResume(); + for (const auto& [symbol_name, embedded] : i_module->getNameTypeMap()) { BasicAffectationRegisterFor<EmbeddedData>(ASTNodeDataType::build<ASTNodeDataType::type_id_t>(symbol_name)); } diff --git a/src/language/modules/ModuleRepository.hpp b/src/language/modules/ModuleRepository.hpp index 3e3d178c37b9b24977599d1a80d1e6f794fb5a96..407d7b55bd9b7c8454159a315b9c2f7a256eecb6 100644 --- a/src/language/modules/ModuleRepository.hpp +++ b/src/language/modules/ModuleRepository.hpp @@ -33,7 +33,7 @@ class ModuleRepository public: void populateSymbolTable(const ASTNode& module_name_node, SymbolTable& symbol_table); - void populateMandatorySymbolTable(const ASTNode& root_node, SymbolTable& symbol_table); + void populateMandatoryData(const ASTNode& root_node, SymbolTable& symbol_table); void registerOperators(const std::string& module_name); void registerCheckpointResume(const std::string& module_name); diff --git a/src/language/utils/OFStream.cpp b/src/language/utils/OFStream.cpp index ca7c77fad8a69fd79ae2b512635acbb7f03277fd..59ebeb7e2544bf41fde37388573fd3df58d5e17c 100644 --- a/src/language/utils/OFStream.cpp +++ b/src/language/utils/OFStream.cpp @@ -3,11 +3,16 @@ #include <utils/Filesystem.hpp> #include <utils/Messenger.hpp> -OFStream::OFStream(const std::string& filename) +OFStream::OFStream(const std::string& filename, bool append) + : OStream(OStream::Type::std_ofstream), m_filename{filename} { if (parallel::rank() == 0) { createDirectoryIfNeeded(filename); - m_fstream.open(filename); + if (append) { + m_fstream.open(filename, std::ios_base::app); + } else { + m_fstream.open(filename); + } if (m_fstream.is_open()) { m_ostream = &m_fstream; } else { diff --git a/src/language/utils/OFStream.hpp b/src/language/utils/OFStream.hpp index 8421906fca8746192b0d24730f7a731fe2640894..ebd9f5b4c2c09aa52c777ad1d758ec8b11351e20 100644 --- a/src/language/utils/OFStream.hpp +++ b/src/language/utils/OFStream.hpp @@ -8,10 +8,17 @@ class OFStream final : public OStream { private: + std::string m_filename; std::ofstream m_fstream; public: - OFStream(const std::string& filename); + const std::string& + filename() const + { + return m_filename; + } + + OFStream(const std::string& filename, bool append = false); OFStream() = delete; ~OFStream() = default; diff --git a/src/language/utils/OStream.hpp b/src/language/utils/OStream.hpp index 24b1e62dd0a2232829ab896af036a1bdcc95edfb..9d79db82df64471f162f5ed27e5ae779c444b7d9 100644 --- a/src/language/utils/OStream.hpp +++ b/src/language/utils/OStream.hpp @@ -8,10 +8,24 @@ class OStream { + public: + enum class Type + { + std_ostream, + std_ofstream + }; + protected: mutable std::ostream* m_ostream = nullptr; + Type m_type; public: + const Type& + type() const + { + return m_type; + } + template <typename DataT> friend std::shared_ptr<const OStream> operator<<(const std::shared_ptr<const OStream>& os, const DataT& t) @@ -34,14 +48,12 @@ class OStream return os; } - OStream(std::ostream& os) : m_ostream(&os) {} - OStream() = default; + OStream(std::ostream& os, Type type = Type::std_ostream) : m_ostream(&os), m_type{type} {} + OStream(Type type) : m_type{type} {} virtual ~OStream() = default; }; -void checkpointStore(const OStream&); - template <> inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const OStream>> = ASTNodeDataType::build<ASTNodeDataType::type_id_t>("ostream"); diff --git a/src/utils/checkpointing/CheckpointUtils.cpp b/src/utils/checkpointing/CheckpointUtils.cpp index 86aad96ae87961b4934972f35b348c9924584eb3..afb542d7a9477d876751566396eb69b1e6b0839d 100644 --- a/src/utils/checkpointing/CheckpointUtils.cpp +++ b/src/utils/checkpointing/CheckpointUtils.cpp @@ -5,6 +5,8 @@ #include <language/modules/SchemeModuleTypes.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/DataHandler.hpp> +#include <language/utils/OFStream.hpp> +#include <language/utils/OStream.hpp> #include <mesh/ItemType.hpp> #include <mesh/Mesh.hpp> #include <mesh/MeshVariant.hpp> @@ -23,6 +25,7 @@ #include <utils/checkpointing/IInterfaceDescriptorHFType.hpp> #include <utils/checkpointing/IZoneDescriptorHFType.hpp> #include <utils/checkpointing/ItemTypeHFType.hpp> +#include <utils/checkpointing/OStreamTypeHFType.hpp> #include <utils/checkpointing/QuadratureTypeHFType.hpp> #include <utils/checkpointing/RefItemListHFType.hpp> @@ -459,3 +462,32 @@ writeMesh(const std::string& symbol_name, }, mesh_v->variant()); } + +void +writeOStream(const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File&, + HighFive::Group&, + HighFive::Group& symbol_table_group) +{ + HighFive::Group variable_group = symbol_table_group.createGroup("embedded/" + symbol_name); + + std::shared_ptr<const OStream> ostream_p = + dynamic_cast<const DataHandler<const OStream>&>(embedded_data.get()).data_ptr(); + + const OStream& ostream = *ostream_p; + + variable_group.createAttribute("type", dataTypeName(ast_node_data_type_from<decltype(ostream_p)>)); + variable_group.createAttribute("ostream_type", ostream.type()); + + switch (ostream.type()) { + case OStream::Type::std_ofstream: { + const OFStream& ofstream = dynamic_cast<const OFStream&>(ostream); + variable_group.createAttribute("filename", ofstream.filename()); + break; + } + case OStream::Type::std_ostream: { + throw NotImplementedError("std::ostream checkpoint"); + } + } +} diff --git a/src/utils/checkpointing/CheckpointUtils.hpp b/src/utils/checkpointing/CheckpointUtils.hpp index 77de3fc3298f86fb9caee8518e5539e3637484a8..0870184203c36bce615818d138cf29ff3e215ed3 100644 --- a/src/utils/checkpointing/CheckpointUtils.hpp +++ b/src/utils/checkpointing/CheckpointUtils.hpp @@ -90,4 +90,10 @@ void writeMesh(const std::string& symbol_name, HighFive::Group& checkpoint_group, HighFive::Group& symbol_table_group); +void writeOStream(const std::string& symbol_name, + const EmbeddedData& embedded_data, + HighFive::File& file, + HighFive::Group& checkpoint_group, + HighFive::Group& symbol_table_group); + #endif // CHECKPOINT_UTILS_HPP diff --git a/src/utils/checkpointing/OStreamTypeHFType.hpp b/src/utils/checkpointing/OStreamTypeHFType.hpp new file mode 100644 index 0000000000000000000000000000000000000000..59bf5ddd06033b0e23153937549c3582928feb24 --- /dev/null +++ b/src/utils/checkpointing/OStreamTypeHFType.hpp @@ -0,0 +1,14 @@ +#ifndef OSTREAM_TYPE_HF_TYPE_HPP +#define OSTREAM_TYPE_HF_TYPE_HPP + +#include <language/utils/OStream.hpp> +#include <utils/checkpointing/CheckpointUtils.hpp> + +HighFive::EnumType<OStream::Type> PUGS_INLINE +create_enum_ostream_type() +{ + return {{"std_ostream", OStream::Type::std_ostream}, {"std_ofstream", OStream::Type::std_ofstream}}; +} +HIGHFIVE_REGISTER_TYPE(OStream::Type, create_enum_ostream_type); + +#endif // OSTREAM_TYPE_HF_TYPE_HPP diff --git a/src/utils/checkpointing/ResumeUtils.cpp b/src/utils/checkpointing/ResumeUtils.cpp index 6f76e8cc09d5248cf5d516420b1d5a0aef29ab19..125fe1f9c753f8038e013a12ab334d560d24a147 100644 --- a/src/utils/checkpointing/ResumeUtils.cpp +++ b/src/utils/checkpointing/ResumeUtils.cpp @@ -4,6 +4,7 @@ #include <analysis/GaussLobattoQuadratureDescriptor.hpp> #include <analysis/GaussQuadratureDescriptor.hpp> #include <language/utils/DataHandler.hpp> +#include <language/utils/OFStream.hpp> #include <language/utils/SymbolTable.hpp> #include <mesh/NamedBoundaryDescriptor.hpp> #include <mesh/NamedInterfaceDescriptor.hpp> @@ -20,6 +21,7 @@ #include <utils/checkpointing/IInterfaceDescriptorHFType.hpp> #include <utils/checkpointing/IZoneDescriptorHFType.hpp> #include <utils/checkpointing/ItemTypeHFType.hpp> +#include <utils/checkpointing/OStreamTypeHFType.hpp> #include <utils/checkpointing/QuadratureTypeHFType.hpp> #include <utils/checkpointing/ResumingData.hpp> @@ -264,5 +266,29 @@ readMesh(const std::string& symbol_name, const HighFive::Group& symbol_table_gro const size_t mesh_id = mesh_group.getAttribute("id").read<uint64_t>(); - return EmbeddedData{std::make_shared<DataHandler<const MeshVariant>>(ResumingData::instance().meshVariant(mesh_id))}; + return {std::make_shared<DataHandler<const MeshVariant>>(ResumingData::instance().meshVariant(mesh_id))}; +} + +EmbeddedData +readOStream(const std::string& symbol_name, const HighFive::Group& symbol_table_group) +{ + const HighFive::Group ostream_group = symbol_table_group.getGroup("embedded/" + symbol_name); + + const OStream::Type ostream_type = ostream_group.getAttribute("ostream_type").read<OStream::Type>(); + + std::shared_ptr<const OStream> p_ostream; + + switch (ostream_type) { + case OStream::Type::std_ofstream: { + std::string filename = ostream_group.getAttribute("filename").read<std::string>(); + + p_ostream = std::make_shared<OFStream>(filename, true); + break; + } + case OStream::Type::std_ostream: { + throw NotImplementedError("std::ostream resume"); + } + } + + return {std::make_shared<DataHandler<const OStream>>(p_ostream)}; } diff --git a/src/utils/checkpointing/ResumeUtils.hpp b/src/utils/checkpointing/ResumeUtils.hpp index 2419cfb5bc69cccd1d10778b0f7937f8eb06dadc..c66f48207eaf073f5b5ab09f66c5715837f0de82 100644 --- a/src/utils/checkpointing/ResumeUtils.hpp +++ b/src/utils/checkpointing/ResumeUtils.hpp @@ -37,5 +37,6 @@ EmbeddedData readIQuadratureDescriptor(const std::string& symbol_name, const Hig EmbeddedData readItemType(const std::string& symbol_name, const HighFive::Group& symbol_table_group); EmbeddedData readIZoneDescriptor(const std::string& symbol_name, const HighFive::Group& symbol_table_group); EmbeddedData readMesh(const std::string& symbol_name, const HighFive::Group& symbol_table_group); +EmbeddedData readOStream(const std::string& symbol_name, const HighFive::Group& symbol_table_group); #endif // RESUME_UTILS_HPP diff --git a/tests/test_BinaryExpressionProcessor_shift.cpp b/tests/test_BinaryExpressionProcessor_shift.cpp index b324735f26d5892c2164f393012abe930c8b06b1..1766e111ea12ef11f0d05f2d0d46db5270dcd4e6 100644 --- a/tests/test_BinaryExpressionProcessor_shift.cpp +++ b/tests/test_BinaryExpressionProcessor_shift.cpp @@ -53,6 +53,7 @@ fout << createSocketServer(0) << "\n";)"; TAO_PEGTL_NAMESPACE::string_input input{data.str(), "test.pgs"}; auto ast = ASTBuilder::build(input); + CheckpointResumeRepository::create(); ASTModulesImporter{*ast}; ASTNodeTypeCleaner<language::import_instruction>{*ast}; @@ -67,6 +68,7 @@ fout << createSocketServer(0) << "\n";)"; ASTExecutionStack::create(); ast->execute(exec_policy); ASTExecutionStack::destroy(); + CheckpointResumeRepository::destroy(); } REQUIRE(std::filesystem::exists(filename)); diff --git a/tests/test_BinaryExpressionProcessor_utils.hpp b/tests/test_BinaryExpressionProcessor_utils.hpp index 5a98115c387e885a893119bab7703f48cc4bbb08..459115866b4a54d75044aa88f44a4172ec195e15 100644 --- a/tests/test_BinaryExpressionProcessor_utils.hpp +++ b/tests/test_BinaryExpressionProcessor_utils.hpp @@ -9,6 +9,7 @@ #include <language/ast/ASTNodeExpressionBuilder.hpp> #include <language/ast/ASTNodeTypeCleaner.hpp> #include <language/ast/ASTSymbolTableBuilder.hpp> +#include <language/utils/CheckpointResumeRepository.hpp> #include <utils/Demangle.hpp> #include <pegtl/string_input.hpp> @@ -23,6 +24,7 @@ ASTModulesImporter{*ast}; \ ASTNodeTypeCleaner<language::import_instruction>{*ast}; \ \ + CheckpointResumeRepository::create(); \ ASTSymbolTableBuilder{*ast}; \ ASTNodeDataTypeBuilder{*ast}; \ \ @@ -43,6 +45,7 @@ \ auto attributes = symbol->attributes(); \ auto value = std::get<decltype(expected_value)>(attributes.value()); \ + CheckpointResumeRepository::destroy(); \ \ REQUIRE(value == expected_value); \ } diff --git a/tests/test_OStream.cpp b/tests/test_OStream.cpp index 417db374ee366531623af397e2d62e78d1c83475..bc532e0d4963b0f33719ef942a2ed4aace8cc13c 100644 --- a/tests/test_OStream.cpp +++ b/tests/test_OStream.cpp @@ -11,7 +11,7 @@ TEST_CASE("OStream", "[language]") { SECTION("null ostream") { - std::shared_ptr os = std::make_shared<OStream>(); + std::shared_ptr os = std::make_shared<OStream>(OStream::Type::std_ostream); REQUIRE_NOTHROW(os << "foo" << 3 << " bar"); } @@ -20,7 +20,7 @@ TEST_CASE("OStream", "[language]") { std::stringstream sstr; - std::shared_ptr os = std::make_shared<OStream>(sstr); + std::shared_ptr os = std::make_shared<OStream>(sstr, OStream::Type::std_ofstream); os << "foo" << 3 << " bar"; REQUIRE(sstr.str() == "foo3 bar");