#include <utils/checkpointing/Resume.hpp>

#include <utils/pugs_config.hpp>

#ifdef PUGS_HAS_HDF5

#include <algebra/EigenvalueSolverOptions.hpp>
#include <algebra/LinearSolverOptions.hpp>
#include <language/ast/ASTExecutionStack.hpp>
#include <language/utils/ASTCheckpointsInfo.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <language/utils/SymbolTable.hpp>
#include <mesh/Connectivity.hpp>
#include <utils/Exceptions.hpp>
#include <utils/ExecutionStatManager.hpp>
#include <utils/HighFivePugsUtils.hpp>
#include <utils/RandomEngine.hpp>
#include <utils/checkpointing/EigenvalueSolverOptionsHFType.hpp>
#include <utils/checkpointing/LinearSolverOptionsHFType.hpp>
#include <utils/checkpointing/ParallelCheckerHFType.hpp>
#include <utils/checkpointing/PartitionerOptionsHFType.hpp>
#include <utils/checkpointing/ResumingData.hpp>
#include <utils/checkpointing/ResumingManager.hpp>

#include <iostream>

void
resume()
{
  try {
    HighFive::SilenceHDF5 m_silence_hdf5{true};
    checkpointing::ResumingData::create();
    HighFive::File file(ResumingManager::getInstance().filename(), HighFive::File::ReadOnly);

    HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");

    HighFive::Group saved_symbol_table = checkpoint.getGroup("symbol table");

    const ASTNode* p_node = &ASTExecutionStack::getInstance().currentNode();
    auto p_symbol_table   = p_node->m_symbol_table;

    ResumingManager& resuming_manager = ResumingManager::getInstance();

    resuming_manager.checkpointNumber() = checkpoint.getAttribute("checkpoint_number").read<uint64_t>() + 1;

    std::cout << " * " << rang::fgB::green << "Resuming " << rang::fg::reset << "execution at line "
              << rang::fgB::yellow << p_node->begin().line << rang::fg::reset << " [using " << rang::fgB::cyan
              << checkpoint.getAttribute("name").read<std::string>() << rang::fg::reset << "]\n";

    {
      HighFive::Group random_seed_group = checkpoint.getGroup("singleton/random_seed");
      RandomEngine::instance().setRandomSeed(random_seed_group.getAttribute("current_seed").read<uint64_t>());
    }
    {
      HighFive::Group global_variables_group = checkpoint.getGroup("singleton/execution_info");
      const size_t run_number                = global_variables_group.getAttribute("run_number").read<size_t>();
      const double cumulative_elapse_time =
        global_variables_group.getAttribute("cumulative_elapse_time").read<double>();
      const double cumulative_total_cpu_time =
        global_variables_group.getAttribute("cumulative_total_cpu_time").read<double>();

      ExecutionStatManager::getInstance().setRunNumber(run_number + 1);
      ExecutionStatManager::getInstance().setPreviousCumulativeElapseTime(cumulative_elapse_time);
      ExecutionStatManager::getInstance().setPreviousCumulativeTotalCPUTime(cumulative_total_cpu_time);
    }
    {
      HighFive::Group random_seed_group = checkpoint.getGroup("singleton/parallel_checker");
      // Ordering is important! Must set mode before changing the tag (changing mode is not allowed if tag!=0)
      ParallelChecker::instance().setMode(random_seed_group.getAttribute("mode").read<ParallelChecker::Mode>());
      ParallelChecker::instance().setTag(random_seed_group.getAttribute("tag").read<size_t>());
    }
    {
      HighFive::Group linear_solver_options_default_group =
        checkpoint.getGroup("singleton/linear_solver_options_default");

      LinearSolverOptions& default_options = LinearSolverOptions::default_options;

      default_options.epsilon() = linear_solver_options_default_group.getAttribute("epsilon").read<double>();
      default_options.maximumIteration() =
        linear_solver_options_default_group.getAttribute("maximum_iteration").read<size_t>();
      default_options.verbose() = linear_solver_options_default_group.getAttribute("verbose").read<bool>();

      default_options.library() = linear_solver_options_default_group.getAttribute("library").read<LSLibrary>();
      default_options.method()  = linear_solver_options_default_group.getAttribute("method").read<LSMethod>();
      default_options.precond() = linear_solver_options_default_group.getAttribute("precond").read<LSPrecond>();
    }
    {
      HighFive::Group eigenvalue_solver_options_default_group =
        checkpoint.getGroup("singleton/eigenvalue_solver_options_default");

      EigenvalueSolverOptions& default_options = EigenvalueSolverOptions::default_options;

      default_options.library() = eigenvalue_solver_options_default_group.getAttribute("library").read<ESLibrary>();
    }
    {
      HighFive::Group partitioner_options_default_group = checkpoint.getGroup("singleton/partitioner_options_default");

      PartitionerOptions& default_options = PartitionerOptions::default_options;

      default_options.library() = partitioner_options_default_group.getAttribute("library").read<PartitionerLibrary>();
    }

    checkpointing::ResumingData::instance().readData(checkpoint, p_symbol_table);

    bool finished = true;
    do {
      finished = true;

      for (auto symbol_name : saved_symbol_table.listAttributeNames()) {
        auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
        auto& attribute        = p_symbol->attributes();
        switch (attribute.dataType()) {
        case ASTNodeDataType::bool_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<bool>();
          break;
        }
        case ASTNodeDataType::unsigned_int_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<uint64_t>();
          break;
        }
        case ASTNodeDataType::int_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<int64_t>();
          break;
        }
        case ASTNodeDataType::double_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<double_t>();
          break;
        }
        case ASTNodeDataType::string_t: {
          attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::string>();
          break;
        }
        case ASTNodeDataType::vector_t: {
          switch (attribute.dataType().dimension()) {
          case 1: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<1>>();
            break;
          }
          case 2: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<2>>();
            break;
          }
          case 3: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyVector<3>>();
            break;
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
          }
            // LCOV_EXCL_STOP
          }
          break;
        }
        case ASTNodeDataType::matrix_t: {
          // LCOV_EXCL_START
          if (attribute.dataType().numberOfRows() != attribute.dataType().numberOfColumns()) {
            throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
          }
          // LCOV_EXCL_STOP
          switch (attribute.dataType().numberOfRows()) {
          case 1: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<1>>();
            break;
          }
          case 2: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<2>>();
            break;
          }
          case 3: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<TinyMatrix<3>>();
            break;
          }
            // LCOV_EXCL_START
          default: {
            throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected matrix dimension");
          }
            // LCOV_EXCL_STOP
          }
          break;
        }
        case ASTNodeDataType::tuple_t: {
          switch (attribute.dataType().contentType()) {
          case ASTNodeDataType::bool_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<bool>>();
            break;
          }
          case ASTNodeDataType::unsigned_int_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<uint64_t>>();
            break;
          }
          case ASTNodeDataType::int_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<int64_t>>();
            break;
          }
          case ASTNodeDataType::double_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<double_t>>();
            break;
          }
          case ASTNodeDataType::string_t: {
            attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<std::string>>();
            break;
          }
          case ASTNodeDataType::vector_t: {
            switch (attribute.dataType().contentType().dimension()) {
            case 1: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<1>>>();
              break;
            }
            case 2: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<2>>>();
              break;
            }
            case 3: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyVector<3>>>();
              break;
            }
              // LCOV_EXCL_START
            default: {
              throw UnexpectedError(dataTypeName(attribute.dataType()) + " unexpected vector dimension");
            }
              // LCOV_EXCL_STOP
            }
            break;
          }
          case ASTNodeDataType::matrix_t: {
            // LCOV_EXCL_START
            if (attribute.dataType().contentType().numberOfRows() !=
                attribute.dataType().contentType().numberOfColumns()) {
              throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
            }
            // LCOV_EXCL_STOP
            switch (attribute.dataType().contentType().numberOfRows()) {
            case 1: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<1>>>();
              break;
            }
            case 2: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<2>>>();
              break;
            }
            case 3: {
              attribute.value() = saved_symbol_table.getAttribute(symbol_name).read<std::vector<TinyMatrix<3>>>();
              break;
            }
              // LCOV_EXCL_START
            default: {
              throw UnexpectedError(dataTypeName(attribute.dataType().contentType()) + " unexpected matrix dimension");
            }
              // LCOV_EXCL_STOP
            }
            break;
          }
            // LCOV_EXCL_START
          default: {
            throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType().contentType()));
          }
            // LCOV_EXCL_STOP
          }
          break;
        }
          // LCOV_EXCL_START
        default: {
          throw NotImplementedError(symbol_name + " of type " + dataTypeName(attribute.dataType()));
        }
          // LCOV_EXCL_STOP
        }
      }

      if (saved_symbol_table.exist("embedded")) {
        HighFive::Group embedded = saved_symbol_table.getGroup("embedded");

        for (auto symbol_name : embedded.listObjectNames()) {
          auto [p_symbol, found] = p_symbol_table->find(symbol_name, p_node->begin());
          if (p_symbol->attributes().dataType() == ASTNodeDataType::tuple_t) {
            HighFive::Group embedded_tuple_group = embedded.getGroup(symbol_name);
            const size_t number_of_components    = embedded_tuple_group.getNumberObjects();
            std::vector<EmbeddedData> embedded_tuple(number_of_components);

            for (size_t i_component = 0; i_component < number_of_components; ++i_component) {
              embedded_tuple[i_component] =
                CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType().contentType(),
                                                              p_symbol->name() + "/" + std::to_string(i_component),
                                                              saved_symbol_table);
            }
            p_symbol->attributes().value() = embedded_tuple;
          } else {
            p_symbol->attributes().value() =
              CheckpointResumeRepository::instance().resume(p_symbol->attributes().dataType(), p_symbol->name(),
                                                            saved_symbol_table);
          }
        }
      }

      const bool symbol_table_has_parent       = p_symbol_table->hasParentTable();
      const bool saved_symbol_table_has_parent = saved_symbol_table.exist("symbol table");

      Assert(not(symbol_table_has_parent xor saved_symbol_table_has_parent));

      if (symbol_table_has_parent and saved_symbol_table_has_parent) {
        p_symbol_table     = p_symbol_table->parentTable();
        saved_symbol_table = saved_symbol_table.getGroup("symbol table");

        finished = false;
      }

    } while (not finished);

    checkpointing::ResumingData::destroy();
  }
  // LCOV_EXCL_START
  catch (HighFive::Exception& e) {
    throw NormalError(e.what());
  }
  // LCOV_EXCL_STOP
}

#else   // PUGS_HAS_HDF5

#include <utils/Exceptions.hpp>

void
resume()
{
  // This build has no HDF5 support, so there is no way to read a checkpoint
  // file: every resume request is reported to the user as a normal error.
  constexpr const char* missing_hdf5_message = "checkpoint/resume mechanism requires HDF5";
  throw NormalError(missing_hdf5_message);
}

#endif   // PUGS_HAS_HDF5