#include <Partitioner.hpp>
#include <Messenger.hpp>
#include <pastis_config.hpp>

#include <PastisOStream.hpp>

#ifdef PASTIS_HAS_MPI

#define IDXTYPEWIDTH 64
#define REALTYPEWIDTH 64
#include <parmetis.h>


#include <type_traits>
#include <vector>


/// Partition the connectivity graph into commSize() parts with ParMETIS.
///
/// ParMETIS is run on a communicator that contains only world rank 0, and
/// vtxdist assigns every vertex to that single process, so the partition is
/// computed on rank 0 from the graph it holds.
///
/// @param graph CSR graph; on rank 0, entries() is the CSR row-pointer array
///              (one more entry than vertices) and neighbors() the adjacency.
/// @return On rank 0, an array giving the target part of each vertex; on all
///         other ranks, an empty array.
Array<int> Partitioner::partition(const CSRGraph& graph)
{
  pout() << "Partitioning graph into "
         << rang::style::bold << commSize() << rang::style::reset
         << " parts\n";

  // entries/neighbors/part are Array<int> buffers handed to ParMETIS as idx_t*,
  // so whatever width parmetis.h finally gives idx_t (IDXTYPEWIDTH is requested
  // above), it must be int for this to be correct.
  static_assert(std::is_same<idx_t, int>::value,
                "ParMETIS idx_t must be int: Array<int> buffers are passed as idx_t*");

  idx_t wgtflag = 0;  // no vertex or edge weights supplied
  idx_t numflag = 0;  // 0-based (C-style) numbering
  idx_t ncon    = 1;  // one balance constraint per vertex
  idx_t npart   = commSize();

  // Same target weight for every part; ParMETIS expects ncon*npart entries.
  std::vector<real_t> tpwgts(ncon * npart, static_cast<real_t>(1. / npart));
  // Allowed load imbalance, one value per constraint.
  std::vector<real_t> ubvec(ncon, static_cast<real_t>(1.05));
  // options[0]=1: use the following values; options[1]: debug level; options[2]: seed.
  std::vector<idx_t> options{1, 1, 0};
  idx_t edgecut = 0;
  Array<int> part(0);

  // Build a communicator that contains world rank 0 only.
  MPI_Group world_group;
  MPI_Comm_group(MPI_COMM_WORLD, &world_group);

  std::vector<int> group_ranks{0};
  MPI_Group mesh_group;
  MPI_Group_incl(world_group, static_cast<int>(group_ranks.size()),
                 group_ranks.data(), &mesh_group);

  // Processes that are not members of mesh_group receive MPI_COMM_NULL.
  MPI_Comm parmetis_comm;
  MPI_Comm_create_group(MPI_COMM_WORLD, mesh_group, 1, &parmetis_comm);

  // The groups are not needed once the communicator exists; freeing both
  // (world_group was previously leaked) does not affect parmetis_comm.
  MPI_Group_free(&mesh_group);
  MPI_Group_free(&world_group);

  if (commRank() == 0) {
    const Array<int>& entries   = graph.entries();
    const Array<int>& neighbors = graph.neighbors();

    // A CSR row-pointer array has one more entry than there are vertices.
    // NOTE(review): assumes entries is non-empty on rank 0.
    const int local_number_of_cells = entries.size() - 1;
    part = Array<int>(local_number_of_cells);

    // Every vertex is owned by the single process of parmetis_comm.
    std::vector<idx_t> vtxdist{0, local_number_of_cells};

    const int result
        = ParMETIS_V3_PartKway(vtxdist.data(), &(entries[0]), &(neighbors[0]),
                               nullptr, nullptr, &wgtflag, &numflag,
                               &ncon, &npart, tpwgts.data(), ubvec.data(),
                               options.data(), &edgecut, &(part[0]), &parmetis_comm);
    if (result == METIS_ERROR) {
      perr() << "Metis Error\n";
      // Exiting on rank 0 alone may leave the other ranks blocked; abort the
      // whole MPI job instead.
      MPI_Abort(MPI_COMM_WORLD, 1);
    }

    MPI_Comm_free(&parmetis_comm);
  }

  return part;
}

#else // PASTIS_HAS_MPI

/// Sequential build (no MPI): no partitioning is performed.
///
/// @return An empty partition array, whatever graph is passed in.
Array<int> Partitioner::partition(const CSRGraph& /* graph */)
{
  Array<int> empty_partition(0);
  return empty_partition;
}

#endif // PASTIS_HAS_MPI