#include <utils/PTScotchPartitioner.hpp>

#include <utils/pugs_config.hpp>

#ifdef PUGS_HAS_PTSCOTCH

#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>

#include <ptscotch.h>

#include <iostream>
#include <vector>

// Partitions the distributed graph into parallel::size() parts using PTScotch
// and returns, for each local graph node, the index of the part it is
// assigned to.
//
// Only the processes owning at least one graph node take part in the PTScotch
// computation; they are gathered into a dedicated communicator. Processes
// owning no node return a default-constructed (empty) array.
//
// Throws UnexpectedError if PTScotch fails to build or to partition the graph.
Array<int>
PTScotchPartitioner::partition(const CRSGraph& graph)
{
  std::cout << "Partitioning graph into " << rang::style::bold << parallel::size() << rang::style::reset
            << " parts using " << rang::fgB::green << "PTScotch" << rang::fg::reset << '\n';

  MPI_Group world_group;
  MPI_Comm_group(parallel::Messenger::getInstance().comm(), &world_group);

  // Ranks of the processes that own at least one graph node
  std::vector<int> group_ranks = [&]() {
    Array<int> graph_node_owners = parallel::allGather(static_cast<int>(graph.numberOfNodes()));
    std::vector<int> grp_ranks;
    grp_ranks.reserve(graph_node_owners.size());
    for (size_t i = 0; i < graph_node_owners.size(); ++i) {
      if (graph_node_owners[i] > 0) {
        grp_ranks.push_back(i);
      }
    }
    return grp_ranks;
  }();

  MPI_Group mesh_group;
  // data() stays well-defined even if the vector is empty, unlike &v[0]
  MPI_Group_incl(world_group, group_ranks.size(), group_ranks.data(), &mesh_group);

  // Processes that are not members of mesh_group (i.e. owning no graph node)
  // receive MPI_COMM_NULL here
  MPI_Comm partitioning_comm;
  MPI_Comm_create_group(parallel::Messenger::getInstance().comm(), mesh_group, 1, &partitioning_comm);

  Array<int> partition;
  if (graph.numberOfNodes() > 0) {
    SCOTCH_Strat scotch_strategy;
    SCOTCH_Dgraph scotch_graph;

    SCOTCH_stratInit(&scotch_strategy);   // use default strategy
    SCOTCH_dgraphInit(&scotch_graph, partitioning_comm);

    const Array<const int>& entries   = graph.entries();
    const Array<const int>& neighbors = graph.neighbors();

    // The graph arrays are handed to PTScotch without conversion
    static_assert(std::is_same_v<int, SCOTCH_Num>);

    // The PTScotch API takes non-const pointers
    SCOTCH_Num* entries_ptr   = const_cast<int*>(&(entries[0]));
    SCOTCH_Num* neighbors_ptr = const_cast<int*>(&(neighbors[0]));

    if (SCOTCH_dgraphBuild(&scotch_graph,
                           0,                       // 0 for C-like arrays
                           graph.numberOfNodes(),   //
                           graph.numberOfNodes(),   //
                           entries_ptr,             //
                           nullptr,                 //
                           nullptr,                 //
                           nullptr,                 // optional local node label array
                           neighbors.size(),        //
                           neighbors.size(),        //
                           neighbors_ptr,           //
                           nullptr,                 //
                           nullptr) != 0) {
      //  LCOV_EXCL_START
      throw UnexpectedError("PTScotch invalid graph");
      //   LCOV_EXCL_STOP
    }

    partition = Array<int>(graph.numberOfNodes());

    SCOTCH_Num* partition_ptr = static_cast<SCOTCH_Num*>(&(partition[0]));

    if (SCOTCH_dgraphPart(&scotch_graph,      //
                          parallel::size(),   //
                          &scotch_strategy,   //
                          partition_ptr) != 0) {
      // LCOV_EXCL_START
      throw UnexpectedError("PTScotch Error");
      // LCOV_EXCL_STOP
    }

    SCOTCH_dgraphExit(&scotch_graph);
    SCOTCH_stratExit(&scotch_strategy);
  }

  // Freeing MPI_COMM_NULL is erroneous: only members of mesh_group obtained a
  // real communicator from MPI_Comm_create_group
  if (partitioning_comm != MPI_COMM_NULL) {
    MPI_Comm_free(&partitioning_comm);
  }
  MPI_Group_free(&mesh_group);
  MPI_Group_free(&world_group);

  return partition;
}

#else   // PUGS_HAS_PTSCOTCH

// Fallback used when pugs is built without PTScotch: every local graph node is
// assigned to part 0.
//
// Returns an array with one entry per graph node, all set to 0.
Array<int>
PTScotchPartitioner::partition(const CRSGraph& graph)
{
  // The number of nodes is taken as entries().size() - 1 (CRS offsets carry a
  // trailing end offset). Guard against an empty offsets array so that the
  // unsigned subtraction cannot wrap around to a huge size.
  const size_t number_of_entries = graph.entries().size();
  const size_t number_of_nodes   = (number_of_entries > 0) ? number_of_entries - 1 : 0;

  Array<int> partition{number_of_nodes};
  partition.fill(0);
  return partition;
}

#endif   // PUGS_HAS_PTSCOTCH