#include <utils/checkpointing/Checkpoint.hpp>

#include <utils/pugs_config.hpp>

#ifdef PUGS_HAS_HDF5

#include <utils/HighFivePugsUtils.hpp>

#include <language/ast/ASTExecutionStack.hpp>
#include <language/utils/SymbolTable.hpp>

#include <iostream>
#include <map>
#endif   // PUGS_HAS_HDF5

#include <language/utils/ASTCheckpointsInfo.hpp>
#include <utils/Exceptions.hpp>
#include <utils/ExecutionStatManager.hpp>
#include <utils/checkpointing/ResumingManager.hpp>

#ifdef PUGS_HAS_HDF5

#include <algebra/LinearSolverOptions.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <language/utils/DataHandler.hpp>
#include <mesh/MeshVariant.hpp>
#include <utils/GlobalVariableManager.hpp>
#include <utils/RandomEngine.hpp>

#include <utils/checkpointing/LinearSolverOptionsHFType.hpp>

// Writes a new checkpoint into the HDF5 file "checkpoint.h5".
//
// The file is truncated when the checkpoint number held by ResumingManager
// is 0 and opened in read/write mode otherwise. Everything is stored in a
// group named "checkpoint_<number>", which receives:
//  - the attributes "creation_date", "name", "id" and "data.pgs",
//  - the "singleton/..." groups (global variable ids, random seed,
//    execution statistics, default linear solver options),
//  - one "symbol table" group per symbol table, nested into each other,
//    starting from the current node's table and following parentTable().
//
// Once written, the hard links "last_checkpoint" and "resuming_checkpoint"
// and the file attribute "checkpoint_number" are replaced, and the
// checkpoint number is incremented.
//
// Throws NormalError (forwarding the message) when HighFive reports an
// error, and UnexpectedError when a symbol holds a value of a type that is
// not handled below.
void
checkpoint()
{
  try {
    auto create_props = HighFive::FileCreateProps{};
    create_props.add(HighFive::FileSpaceStrategy(H5F_FSPACE_STRATEGY_FSM_AGGR, true, 0));

    HighFive::FileAccessProps fapl;
    fapl.add(HighFive::MPIOFileAccess{MPI_COMM_WORLD, MPI_INFO_NULL});
    fapl.add(HighFive::MPIOCollectiveMetadata{});

    // Taken by reference: the counter is incremented at the very end, only
    // if no exception interrupted the writing.
    uint64_t& checkpoint_number = ResumingManager::getInstance().checkpointNumber();

    const auto file_openmode = (checkpoint_number == 0) ? HighFive::File::Truncate : HighFive::File::ReadWrite;

    HighFive::File file("checkpoint.h5", file_openmode, create_props, fapl);

    const std::string checkpoint_name = "checkpoint_" + std::to_string(checkpoint_number);

    HighFive::Group checkpoint = file.createGroup(checkpoint_name);

    // The node being executed does not depend on the symbol being written,
    // so it is looked up once and reused below.
    const auto& current_node = ASTExecutionStack::getInstance().currentNode();

    const uint64_t checkpoint_id = ASTCheckpointsInfo::getInstance().getCheckpointId(current_node);

    const std::string creation_date = [] {
      std::ostringstream os;
      auto t = std::time(nullptr);
      os << std::put_time(std::localtime(&t), "%c");
      return os.str();
    }();

    checkpoint.createAttribute("creation_date", creation_date);
    checkpoint.createAttribute("name", checkpoint_name);
    checkpoint.createAttribute("id", checkpoint_id);
    checkpoint.createAttribute("data.pgs", ASTExecutionStack::getInstance().fileContent());

    // State of the singletons
    {
      HighFive::Group global_variables_group = checkpoint.createGroup("singleton/global_variables");
      global_variables_group.createAttribute("connectivity_id", GlobalVariableManager::instance().getConnectivityId());
      global_variables_group.createAttribute("mesh_id", GlobalVariableManager::instance().getMeshId());
    }
    {
      HighFive::Group random_seed_group = checkpoint.createGroup("singleton/random_seed");
      random_seed_group.createAttribute("current_seed", RandomEngine::instance().getCurrentSeed());
    }
    {
      HighFive::Group execution_info_group = checkpoint.createGroup("singleton/execution_info");
      execution_info_group.createAttribute("run_number", ExecutionStatManager::getInstance().runNumber());
      execution_info_group.createAttribute("cumulative_elapse_time",
                                           ExecutionStatManager::getInstance().getCumulativeElapseTime());
      execution_info_group.createAttribute("cumulative_total_cpu_time",
                                           ExecutionStatManager::getInstance().getCumulativeTotalCPUTime());
    }
    {
      HighFive::Group linear_solver_options_default_group =
        checkpoint.createGroup("singleton/linear_solver_options_default");

      const LinearSolverOptions& default_options = LinearSolverOptions::default_options;

      linear_solver_options_default_group.createAttribute("epsilon", default_options.epsilon());
      linear_solver_options_default_group.createAttribute("maximum_iteration", default_options.maximumIteration());
      linear_solver_options_default_group.createAttribute("verbose", default_options.verbose());

      linear_solver_options_default_group.createAttribute("library", default_options.library());
      linear_solver_options_default_group.createAttribute("method", default_options.method());
      linear_solver_options_default_group.createAttribute("precond", default_options.precond());
    }
    {
      // Not implemented yet: only a notice is printed for these managers.
      std::cout << rang::fgB::magenta << "Checkpoint DualConnectivityManager NIY" << rang::fg::reset << '\n';
      std::cout << rang::fgB::magenta << "Checkpoint DualMeshManager NIY" << rang::fg::reset << '\n';
    }

    // Position handed to SymbolTable::has() for every symbol; it is the
    // same for the whole traversal.
    const auto checkpoint_position = current_node.begin();

    std::shared_ptr<const SymbolTable> p_symbol_table = current_node.m_symbol_table;
    auto symbol_table_group                           = checkpoint;
    size_t symbol_table_id                            = 0;

    // Each table gets its own "symbol table" group, created inside the
    // group of the previously visited (child) table. The loop stops as
    // soon as the pointer is null, which is what makes the dereference
    // below valid.
    while (p_symbol_table) {
      symbol_table_group = symbol_table_group.createGroup("symbol table");

      const SymbolTable& symbol_table = *p_symbol_table;

      const auto& symbol_list = symbol_table.symbolList();

      for (const auto& symbol : symbol_list) {
        switch (symbol.attributes().dataType()) {
        case ASTNodeDataType::builtin_function_t:
        case ASTNodeDataType::type_name_id_t: {
          // nothing is stored for these symbols
          break;
        }
        case ASTNodeDataType::function_t: {
          // Functions are gathered under the checkpoint group itself, with
          // the index of the table they were found in.
          HighFive::Group function_group = checkpoint.createGroup("functions/" + symbol.name());
          function_group.createAttribute("id", std::get<size_t>(symbol.attributes().value()));
          function_group.createAttribute("symbol_table_id", symbol_table_id);
          break;
        }
        default: {
          // Only symbols for which has() succeeds at the checkpoint position
          // and that are not module variables are written.
          if ((symbol_table.has(symbol.name(), checkpoint_position)) and
              (not symbol.attributes().isModuleVariable())) {
            std::visit(
              [&](auto&& data) {
                using DataT = std::decay_t<decltype(data)>;
                if constexpr (std::is_same_v<DataT, std::monostate>) {
                  // no value: nothing to write
                } else if constexpr ((std::is_arithmetic_v<DataT>) or (std::is_same_v<DataT, std::string>) or
                                     (is_tiny_vector_v<DataT>) or (is_tiny_matrix_v<DataT>)) {
                  // values written directly as an attribute of the table group
                  symbol_table_group.createAttribute(symbol.name(), data);
                } else if constexpr (std::is_same_v<DataT, EmbeddedData>) {
                  // embedded data are delegated to the repository
                  CheckpointResumeRepository::instance().checkpoint(symbol.attributes().dataType(), symbol.name(), data,
                                                                    file, checkpoint, symbol_table_group);
                } else if constexpr (is_std_vector_v<DataT>) {
                  using value_type = typename DataT::value_type;
                  if constexpr ((std::is_arithmetic_v<value_type>) or (std::is_same_v<value_type, std::string>) or
                                (is_tiny_vector_v<value_type>) or (is_tiny_matrix_v<value_type>)) {
                    symbol_table_group.createAttribute(symbol.name(), data);
                  } else if constexpr (std::is_same_v<value_type, EmbeddedData>) {
                    // The list type is recorded on the "embedded/<name>"
                    // group, then each item is checkpointed as "<name>/<i>".
                    symbol_table_group.createGroup("embedded/" + symbol.name())
                      .createAttribute("type", dataTypeName(symbol.attributes().dataType()));
                    for (size_t i = 0; i < data.size(); ++i) {
                      CheckpointResumeRepository::instance().checkpoint(symbol.attributes().dataType().contentType(),
                                                                        symbol.name() + "/" + std::to_string(i),
                                                                        data[i], file, checkpoint, symbol_table_group);
                    }
                  } else {
                    throw UnexpectedError("unexpected data type");
                  }
                } else {
                  throw UnexpectedError("unexpected data type");
                }
              },
              symbol.attributes().value());
          }
          break;
        }
        }
      }

      p_symbol_table = symbol_table.parentTable();
      ++symbol_table_id;
    }

    // Links and attribute are removed first when present, since they are
    // re-created on every checkpoint.
    if (file.exist("last_checkpoint")) {
      file.unlink("last_checkpoint");
    }
    file.createHardLink("last_checkpoint", checkpoint);

    if (file.exist("resuming_checkpoint")) {
      file.unlink("resuming_checkpoint");
    }
    file.createHardLink("resuming_checkpoint", checkpoint);

    if (file.hasAttribute("checkpoint_number")) {
      file.deleteAttribute("checkpoint_number");
    }
    file.createAttribute("checkpoint_number", checkpoint_number);

    ++checkpoint_number;
  }
  catch (const HighFive::Exception& e) {
    throw NormalError(e.what());
  }
}

#else   // PUGS_HAS_HDF5

// Variant compiled when PUGS_HAS_HDF5 is not defined: no checkpoint can be
// written, so every call ends with a NormalError.
void
checkpoint()
{
  const char* const error_message = "checkpoint/resume mechanism requires HDF5";
  throw NormalError(error_message);
}

#endif   // PUGS_HAS_HDF5