#include <utils/checkpointing/Resume.hpp>

#include <utils/pugs_config.hpp>

#ifdef PUGS_HAS_HDF5

#include <utils/HighFivePugsUtils.hpp>

#include <language/ast/ASTExecutionStack.hpp>
#include <language/utils/SymbolTable.hpp>

#include <iostream>
#endif   // PUGS_HAS_HDF5

#include <language/utils/ASTCheckpointsInfo.hpp>
#include <utils/Exceptions.hpp>

#ifdef PUGS_HAS_HDF5

#include <mesh/Connectivity.hpp>
#include <utils/RandomEngine.hpp>
#include <utils/checkpointing/ResumeUtils.hpp>
#include <utils/checkpointing/ResumingData.hpp>
#include <utils/checkpointing/ResumingManager.hpp>

#include <language/utils/CheckpointResumeRepository.hpp>

#include <map>

void
resume()
{
  try {
    ResumingData::create();
    HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly);

    HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");

    HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table");

    const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode();
    auto p_symbol_table   = p_node->m_symbol_table;

    ResumingManager& resuming_manager = ResumingManager::getInstance();

    resuming_manager.checkpointNumber() = file.getAttribute("checkpoint_number").read<uint64_t>() + 1;

    std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line "
              << rang::fgB::yellow << p_node->begin().line << rang::fg::reset << " [using " << rang::fgB::cyan
              << checkpoint.getAttribute("name").read<std::string>() << rang::fg::reset << "]\n";

    {
      HighFive::Group random_seed = checkpoint.getGroup("singleton/random_seed");
      RandomEngine::instance().setRandomSeed(random_seed.getAttribute("current_seed").read<uint64_t>());
    }

    {
      std::cout << rang::fgB::magenta << "Resume DualConnectivityManager NIY" << rang::fg::reset << '\n';
      std::cout << rang::fgB::magenta << "Resume DualMeshManager NIY" << rang::fg::reset << '\n';
    }

    ResumingData::instance().readData(checkpoint, p_symbol_table);

    bool finished = true;
    do {
      finished = true;

      for (auto symbol_name : saved_symbol_table.listAttributeNames()) {
        auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
        auto& attribute        = p_symbol->attributes();
        switch (attribute.dataType()) {
        case ASTNodeDataType::bool_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>();
          break;
        }
        case ASTNodeDataType::unsigned_int_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>();
          break;
        }
        case ASTNodeDataType::int_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>();
          break;
        }
        case ASTNodeDataType::double_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double_t>();
          break;
        }
        case ASTNodeDataType::string_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>();
          break;
        }
        case ASTNodeDataType::vector_t: {
          switch (attribute.dataType().dimension()) {
          case 1: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>();
            break;
          }
          case 2: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>();
            break;
          }
          case 3: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>();
            break;
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
          }
            // LCOV_EXCL_STOP
          }
          break;
        }
        case ASTNodeDataType::matrix_t: {
          // LCOV_EXCL_START
          if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) {
            throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
          }
          // LCOV_EXCL_STOP
          switch (attribute.dataType().numberOfRows()) {
          case 1: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>();
            break;
          }
          case 2: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>();
            break;
          }
          case 3: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>();
            break;
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
          }
            // LCOV_EXCL_STOP
          }
          break;
        }
        case ASTNodeDataType::tuple_t: {
          switch (attribute.dataType().contentType()) {
          case ASTNodeDataType::bool_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>();
            break;
          }
          case ASTNodeDataType::unsigned_int_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>();
            break;
          }
          case ASTNodeDataType::int_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>();
            break;
          }
          case ASTNodeDataType::double_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double_t>>();
            break;
          }
          case ASTNodeDataType::string_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>();
            break;
          }
          case ASTNodeDataType::vector_t: {
            switch (attribute.dataType().contentType().dimension()) {
            case 1: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>();
              break;
            }
            case 2: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>();
              break;
            }
            case 3: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>();
              break;
            }
              // LCOV_EXCL_START
            default: {
              throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
            }
              // LCOV_EXCL_STOP
            }
            break;
          }
          case ASTNodeDataType::matrix_t: {
            // LCOV_EXCL_START
            if (attribute.dataType().contentType().numberOfRows() !=
                attribute.dataType().contentType().numberOfColumns()) {
              throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
            }
            // LCOV_EXCL_STOP
            switch (attribute.dataType().contentType().numberOfRows()) {
            case 1: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>();
              break;
            }
            case 2: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>();
              break;
            }
            case 3: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>();
              break;
            }
              // LCOV_EXCL_START
            default: {
              throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
            }
              // LCOV_EXCL_STOP
            }
            break;
          }
          default: {
            throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType()));
          }
          }
          break;
        }
        default: {
          throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType()));
        }
        }
      }

      if (saved_symbol_table.exist("embedded")) {
        HighFive::Group embedded = saved_symbol_table.getGroup("embedded");

        for (auto symbol_name : embedded.listObjectNames()) {
          auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
          if (p_symbol->attributes().dataType() == ASTNodeDataType::tuple_t) {
            HighFive::Group embedded_tuple_group = embedded.getGroup(symbol_name);
            const size_t number_of_components    = embedded_tuple_group.getNumberObjects();
            std::vector<EmbeddedData> embedded_tuple(number_of_components);

            for (size_t i_component = 0; i_component < number_of_components; ++i_component) {
              embedded_tuple[i_component] =
                CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType().contentType(),
                                                              p_symbol->name() + "/" + std::to_string(i_component),
                                                              saved_symbol_table);
            }
            p_symbol->attributes().value() = embedded_tuple;
          } else {
            p_symbol->attributes().value() =
              CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType(), p_symbol->name(),
                                                            saved_symbol_table);
          }
        }
      }

      const bool symbol_table_has_parent       = p_symbol_table->hasParentTable();
      const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table");

      Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent));

      if (symbol_table_has_parent and saved_symbol_table_has_parent) {
        p_symbol_table     = p_symbol_table->parentTable();
        saved_symbol_table = saved_symbol_table.getGroup("symbol table");

        finished = false;
      }

    } while (not finished);

    ResumingData::destroy();
  }
  catch (HighFive::Exception& e) {
    throw NormalError(e.what());
  }
}

#else   // PUGS_HAS_HDF5

// Fallback used when pugs is built without HDF5 support: checkpoint files cannot
// be read, so any attempt to resume fails immediately with a NormalError.
void
resume()
{
  constexpr const char* missing_hdf5_message = "checkpoint/resume mechanism requires HDF5";
  throw NormalError(missing_hdf5_message);
}

#endif   // PUGS_HAS_HDF5
