#include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_predicate.hpp>

#include <utils/pugs_config.hpp>

#ifdef PUGS_HAS_HDF5

#include <dev/ParallelChecker.hpp>
#include <language/ast/ASTBuilder.hpp>
#include <language/ast/ASTExecutionStack.hpp>
#include <language/ast/ASTModulesImporter.hpp>
#include <language/ast/ASTNodeDataTypeBuilder.hpp>
#include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp>
#include <language/ast/ASTNodeExpressionBuilder.hpp>
#include <language/ast/ASTNodeTypeCleaner.hpp>
#include <language/ast/ASTSymbolTableBuilder.hpp>
#include <language/modules/MathModule.hpp>
#include <language/utils/ASTCheckpointsInfo.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <mesh/DualMeshType.hpp>
#include <utils/ExecutionStatManager.hpp>
#include <utils/checkpointing/DualMeshTypeHFType.hpp>

// Test helper that owns an ASTCheckpointsInfo built from an AST root node, so
// that the checkpoint information lives exactly as long as this object.
// NOTE(review): presumably ASTCheckpointsInfo registers itself for use by
// checkpoint() during execution — confirm against ASTCheckpointsInfo.hpp.
class ASTCheckpointsInfoTester
{
 private:
  ASTCheckpointsInfo m_ast_checkpoint_info;

 public:
  // explicit: an ASTNode must never silently convert into a tester object
  explicit ASTCheckpointsInfoTester(const ASTNode& root_node) : m_ast_checkpoint_info(root_node) {}
};

// Parses and executes the pugs script `data` (named "test.pgs") with the
// singletons required by the checkpointing machinery, then tears them down.
//
// Steps: create singletons -> build AST -> import modules -> build symbol
// table and data types -> convert declarations to affectations -> build
// expression nodes -> execute (with checkpoint info attached) -> clear
// symbol-table values -> destroy singletons.
//
// Wrapped in do { ... } while (0) so that `RUN_AST(x);` behaves as a single
// statement (e.g. inside an unbraced if/else).
// NOTE(review): if execute() throws, the destroy() calls are skipped and the
// singletons stay alive for subsequent tests.
#define RUN_AST(data)                                          \
  do {                                                         \
    ExecutionStatManager::create();                            \
    ParallelChecker::create();                                 \
    CheckpointResumeRepository::create();                      \
                                                               \
    TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \
    auto ast = ASTBuilder::build(input);                       \
                                                               \
    ASTModulesImporter{*ast};                                  \
    ASTNodeTypeCleaner<language::import_instruction>{*ast};    \
                                                               \
    ASTSymbolTableBuilder{*ast};                               \
    ASTNodeDataTypeBuilder{*ast};                              \
                                                               \
    ASTNodeDeclarationToAffectationConverter{*ast};            \
    ASTNodeTypeCleaner<language::var_declaration>{*ast};       \
    ASTNodeTypeCleaner<language::fct_declaration>{*ast};       \
                                                               \
    ASTNodeExpressionBuilder{*ast};                            \
    ExecutionPolicy exec_policy;                               \
    ASTExecutionStack::create();                               \
    ASTCheckpointsInfoTester ast_cp_info_tester{*ast};         \
    ast->execute(exec_policy);                                 \
    ASTExecutionStack::destroy();                              \
    ast->m_symbol_table->clearValues();                        \
                                                               \
    CheckpointResumeRepository::destroy();                     \
    ParallelChecker::destroy();                                \
    ExecutionStatManager::destroy();                           \
  } while (0)

#else   // PUGS_HAS_HDF5

#include <utils/checkpointing/Checkpoint.hpp>

#endif   // PUGS_HAS_HDF5

// clazy:excludeall=non-pod-global-static

TEST_CASE("checkpointing_Checkpoint_sequential", "[utils/checkpointing]")
{
#ifdef PUGS_HAS_HDF5

  // Start from a freshly created ResumingManager
  ResumingManager::destroy();
  ResumingManager::create();

  // Create a unique temporary directory on rank 0, broadcast its name to all
  // ranks, and direct the checkpoint output to <tmp_dirname>/checkpoint.h5
  std::string tmp_dirname;
  {
    {
      if (parallel::rank() == 0) {
        tmp_dirname = [&]() -> std::string {
          std::string temp_filename = std::filesystem::temp_directory_path() / "pugs_checkpointing_XXXXXX";
          // mkdtemp() returns nullptr on failure; building a std::string from
          // a null pointer would be undefined behavior, so fail the test instead
          const char* created_dirname = mkdtemp(temp_filename.data());
          REQUIRE(created_dirname != nullptr);
          return std::string{created_dirname};
        }();
      }
      parallel::broadcast(tmp_dirname, 0);
    }
    std::filesystem::path path = tmp_dirname;
    const std::string filename = path / "checkpoint.h5";

    ResumingManager::getInstance().setFilename(filename);
  }

  std::string data = R"(
import math;
import mesh;
import scheme;

let f:R*R^3 -> R^3, (a, v) -> a * v;

let alpha:R, alpha = 3.2;
let u:R^3, u = [1,2,3];

let m:mesh, m = cartesianMesh(0, [1,1], (10,10));

let n:(N), n = (1,2,3,4);

let duals:(mesh), duals = (diamondDual(m), medianDual(m));

for(let i:N, i=0; i<3; ++i) {
  checkpoint();

  let g:R -> R^3, x -> [x, 1.2*x, 3];
  u = f(alpha,u);

  checkpoint();
}
)";

  // Mesh and connectivity ids are global counters: remember them so that the
  // expected ids below can be expressed relative to the starting values, and
  // restore them afterwards so other tests are not affected
  const size_t initial_mesh_id         = GlobalVariableManager::instance().getMeshId();
  const size_t initial_connectivity_id = GlobalVariableManager::instance().getConnectivityId();

  RUN_AST(data);

  GlobalVariableManager::instance().setMeshId(initial_mesh_id);
  GlobalVariableManager::instance().setConnectivityId(initial_connectivity_id);

  {   // Check checkpoint file
    std::filesystem::path path = tmp_dirname;
    const std::string filename = path / "checkpoint.h5";

    HighFive::File file(filename, HighFive::File::ReadOnly);

    HighFive::Group checkpoint = file.getGroup("/resuming_checkpoint");

    // The loop runs 3 times with 2 checkpoint() calls each: the last one is number 5
    REQUIRE(checkpoint.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(checkpoint.getAttribute("checkpoint_number").read<uint64_t>() == 5);
    REQUIRE(checkpoint.getAttribute("name").read<std::string>() == "checkpoint_5");

    // Innermost (loop) symbol table: the last checkpoint happens when i == 2
    HighFive::Group symbol_table0 = checkpoint.getGroup("symbol table");
    REQUIRE(symbol_table0.getAttribute("i").read<uint64_t>() == 2);

    // Enclosing symbol table: u = [1,2,3] scaled by alpha = 3.2 three times (3.2^3 = 32.768)
    HighFive::Group symbol_table1 = symbol_table0.getGroup("symbol table");
    REQUIRE(symbol_table1.getAttribute("alpha").read<double>() == 3.2);
    REQUIRE(symbol_table1.getAttribute("n").read<std::vector<uint64_t>>() == std::vector<uint64_t>{1, 2, 3, 4});
    REQUIRE(l2Norm(symbol_table1.getAttribute("u").read<TinyVector<3>>() - TinyVector<3>{32.768, 65.536, 98.304}) ==
            Catch::Approx(0).margin(1E-12));

    HighFive::Group embedded1 = symbol_table1.getGroup("embedded");

    // duals = (diamondDual(m), medianDual(m)): created after m, hence ids +1 and +2
    HighFive::Group duals = embedded1.getGroup("duals");
    REQUIRE(duals.getAttribute("type").read<std::string>() == "(mesh)");
    HighFive::Group duals_0 = duals.getGroup("0");
    REQUIRE(duals_0.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_0.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
    HighFive::Group duals_1 = duals.getGroup("1");
    REQUIRE(duals_1.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(duals_1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);

    HighFive::Group m = embedded1.getGroup("m");
    REQUIRE(m.getAttribute("type").read<std::string>() == "mesh");
    REQUIRE(m.getAttribute("id").read<uint64_t>() == initial_mesh_id);

    // Three meshes and three connectivities were created, so the next ids are +3
    HighFive::Group singleton        = checkpoint.getGroup("singleton");
    HighFive::Group global_variables = singleton.getGroup("global_variables");
    REQUIRE(global_variables.getAttribute("connectivity_id").read<uint64_t>() == initial_connectivity_id + 3);
    REQUIRE(global_variables.getAttribute("mesh_id").read<uint64_t>() == initial_mesh_id + 3);
    HighFive::Group execution_info = singleton.getGroup("execution_info");
    REQUIRE(execution_info.getAttribute("run_number").read<uint64_t>() == 1);

    HighFive::Group connectivity  = checkpoint.getGroup("connectivity");
    HighFive::Group connectivity0 = connectivity.getGroup(std::to_string(initial_connectivity_id));
    REQUIRE(connectivity0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(connectivity0.getAttribute("id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity0.getAttribute("type").read<std::string>() == "unstructured");

    HighFive::Group connectivity1 = connectivity.getGroup(std::to_string(initial_connectivity_id + 1));
    REQUIRE(connectivity1.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 1);
    REQUIRE(connectivity1.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity1.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);

    HighFive::Group connectivity2 = connectivity.getGroup(std::to_string(initial_connectivity_id + 2));
    REQUIRE(connectivity2.getAttribute("id").read<uint64_t>() == initial_connectivity_id + 2);
    REQUIRE(connectivity2.getAttribute("primal_connectivity_id").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(connectivity2.getAttribute("type").read<std::string>() == "dual_connectivity");
    REQUIRE(connectivity2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    HighFive::Group mesh  = checkpoint.getGroup("mesh");
    HighFive::Group mesh0 = mesh.getGroup(std::to_string(initial_mesh_id));
    REQUIRE(mesh0.getAttribute("connectivity").read<uint64_t>() == initial_connectivity_id);
    REQUIRE(mesh0.getAttribute("dimension").read<uint64_t>() == 2);
    REQUIRE(mesh0.getAttribute("id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh0.getAttribute("type").read<std::string>() == "polygonal");

    HighFive::Group mesh1 = mesh.getGroup(std::to_string(initial_mesh_id + 1));
    REQUIRE(mesh1.getAttribute("id").read<uint64_t>() == initial_mesh_id + 1);
    REQUIRE(mesh1.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh1.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh1.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Diamond);

    HighFive::Group mesh2 = mesh.getGroup(std::to_string(initial_mesh_id + 2));
    REQUIRE(mesh2.getAttribute("id").read<uint64_t>() == initial_mesh_id + 2);
    REQUIRE(mesh2.getAttribute("primal_mesh_id").read<uint64_t>() == initial_mesh_id);
    REQUIRE(mesh2.getAttribute("type").read<std::string>() == "dual_mesh");
    REQUIRE(mesh2.getAttribute("type_of_dual").read<DualMeshType>() == DualMeshType::Median);

    HighFive::Group functions = checkpoint.getGroup("functions");
    HighFive::Group f         = functions.getGroup("f");
    REQUIRE(f.getAttribute("id").read<uint64_t>() == 0);
    REQUIRE(f.getAttribute("symbol_table_id").read<uint64_t>() == 1);
    HighFive::Group g = functions.getGroup("g");
    REQUIRE(g.getAttribute("id").read<uint64_t>() == 1);
    REQUIRE(g.getAttribute("symbol_table_id").read<uint64_t>() == 0);
  }

  // Remove the temporary directory once every rank is done reading the file
  parallel::barrier();
  if (parallel::rank() == 0) {
    std::filesystem::remove_all(std::filesystem::path{tmp_dirname});
  }

  // Revert to default value
  ResumingManager::getInstance().setFilename("checkpoint.h5");

#else    // PUGS_HAS_HDF5
  REQUIRE_THROWS_WITH(checkpoint(), "error: checkpoint/resume mechanism requires HDF5");
#endif   // PUGS_HAS_HDF5
}