diff --git a/src/utils/checkpointing/ResumingData.cpp b/src/utils/checkpointing/ResumingData.cpp index 39b88a6d892582527fdfd8a3280200dc0f8a2fc5..5743e98fd04185018b67a2f40ffb6f85caa7b70d 100644 --- a/src/utils/checkpointing/ResumingData.cpp +++ b/src/utils/checkpointing/ResumingData.cpp @@ -1,5 +1,6 @@ #include <utils/checkpointing/ResumingData.hpp> +#include <language/utils/SymbolTable.hpp> #include <mesh/ConnectivityDescriptor.hpp> #include <mesh/Mesh.hpp> #include <utils/Exceptions.hpp> @@ -231,10 +232,58 @@ ResumingData::_getMeshVariantList(const HighFive::Group& checkpoint) } void -ResumingData::readData(const HighFive::Group& checkpoint) +ResumingData::_getFunctionIds(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table) +{ + size_t symbol_table_id = 0; + const HighFive::Group function_group = checkpoint.getGroup("functions"); + while (p_symbol_table.use_count() > 0) { + for (auto symbol : p_symbol_table->symbolList()) { + if (symbol.attributes().dataType() == ASTNodeDataType::function_t) { + if (not function_group.exist(symbol.name())) { + std::ostringstream error_msg; + error_msg << "cannot find function " << rang::fgB::yellow << symbol.name() << rang::fg::reset << " in " + << rang::fgB::cyan << checkpoint.getFile().getName() << rang::fg::reset; + throw NormalError(error_msg.str()); + } else { + const HighFive::Group function = function_group.getGroup(symbol.name()); + const size_t stored_function_id = function.getAttribute("id").read<size_t>(); + const size_t function_id = std::get<size_t>(symbol.attributes().value()); + if (symbol_table_id != function.getAttribute("symbol_table_id").read<size_t>()) { + std::ostringstream error_msg; + error_msg << "symbol table of function " << rang::fgB::yellow << symbol.name() << rang::fg::reset + << " does not match the one stored in " << rang::fgB::cyan << checkpoint.getFile().getName() + << rang::fg::reset; + throw NormalError(error_msg.str()); + } else if (function_id != stored_function_id) { + std::ostringstream error_msg; + error_msg << "id (" << function_id << ") of function " << rang::fgB::yellow << symbol.name() + << rang::fg::reset << " does not match the one stored in " << rang::fgB::cyan + << checkpoint.getFile().getName() << rang::fg::reset << "(" << stored_function_id << ")"; + throw NormalError(error_msg.str()); + } else { + if (m_id_to_function_symbol_id_map.contains(function_id)) { + std::ostringstream error_msg; + error_msg << "id (" << function_id << ") of function " << rang::fgB::yellow << symbol.name() + << rang::fg::reset << " is duplicated"; + throw UnexpectedError(error_msg.str()); + } + m_id_to_function_symbol_id_map[function_id] = + std::make_shared<FunctionSymbolId>(function_id, p_symbol_table); + } + } + } + } + p_symbol_table = p_symbol_table->parentTable(); + ++symbol_table_id; + } +} + +void +ResumingData::readData(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table) { this->_getConnectivityList(checkpoint); this->_getMeshVariantList(checkpoint); + this->_getFunctionIds(checkpoint, p_symbol_table); } const std::shared_ptr<const IConnectivity>& @@ -259,6 +308,17 @@ ResumingData::meshVariant(const size_t mesh_id) const } } +const std::shared_ptr<const FunctionSymbolId>& +ResumingData::functionSymbolId(const size_t function_symbol_id) const +{ + auto i_id_to_function_symbol_id = m_id_to_function_symbol_id_map.find(function_symbol_id); + if (i_id_to_function_symbol_id == m_id_to_function_symbol_id_map.end()) { + throw UnexpectedError("cannot find function_symbol_id of id " + std::to_string(function_symbol_id)); + } else { + return i_id_to_function_symbol_id->second; + } +} + void ResumingData::create() { diff --git a/src/utils/checkpointing/ResumingData.hpp b/src/utils/checkpointing/ResumingData.hpp index b7a25221e926ad2cb97ef8b61807c140ba5d8d5d..9da8cd067e2896cd4f0428072a77c27f3c91833f 100644 --- a/src/utils/checkpointing/ResumingData.hpp +++ b/src/utils/checkpointing/ResumingData.hpp @@ -8,12 +8,15 @@ class IConnectivity; class MeshVariant; +class SymbolTable; +class FunctionSymbolId; class ResumingData { private: std::map<size_t, std::shared_ptr<const IConnectivity>> m_id_to_iconnectivity_map; std::map<size_t, std::shared_ptr<const MeshVariant>> m_id_to_mesh_variant_map; + std::map<size_t, std::shared_ptr<const FunctionSymbolId>> m_id_to_function_symbol_id_map; ResumingData() = default; ~ResumingData() = default; @@ -25,12 +28,14 @@ class ResumingData void _getConnectivityList(const HighFive::Group& checkpoint); void _getMeshVariantList(const HighFive::Group& checkpoint); + void _getFunctionIds(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table); public: - void readData(const HighFive::Group& checkpoint); + void readData(const HighFive::Group& checkpoint, std::shared_ptr<SymbolTable> p_symbol_table); const std::shared_ptr<const IConnectivity>& iConnectivity(const size_t connectivity_id) const; const std::shared_ptr<const MeshVariant>& meshVariant(const size_t mesh_id) const; + const std::shared_ptr<const FunctionSymbolId>& functionSymbolId(const size_t function_symbol_id) const; static void create(); static void destroy();