#include <catch2/catch_test_macros.hpp>

#include <utils/Messenger.hpp>
#include <utils/Partitioner.hpp>

#include <set>

// clazy:excludeall=non-pod-global-static

TEST_CASE("Partitioner", "[utils]")
{
  SECTION("one graph split to all")
  {
    Partitioner partitioner;

    // Compressed-row storage of the input graph: the neighbors of vertex v are
    // neighbors_vector[entries_vector[v]] .. neighbors_vector[entries_vector[v + 1] - 1].
    // The offsets always start at 0, so a rank that owns no vertex still
    // contributes a well-formed (empty) graph.
    std::vector<int> neighbors_vector;
    std::vector<int> entries_vector{0};

    // Only rank 0 provides vertices; every other rank keeps the empty graph above.
    if (parallel::rank() == 0) {
      // NOTE(review): several lists reference their own vertex (2, 5, 6, 7) and the
      // adjacency is not symmetric (e.g. 0 -> 4 but not 4 -> 0); presumably
      // intentional test data — confirm before "fixing" it.
      const std::vector<std::vector<int>> adjacency = {
        {1, 2, 4},   // vertex 0
        {0, 3, 5},   // vertex 1
        {0, 2, 5},   // vertex 2
        {0, 2, 5},   // vertex 3
        {3, 5, 7},   // vertex 4
        {3, 5, 6},   // vertex 5
        {5, 6, 7},   // vertex 6
        {3, 5, 7},   // vertex 7
        {4, 6, 7},   // vertex 8
      };

      for (const auto& vertex_neighbors : adjacency) {
        neighbors_vector.insert(neighbors_vector.end(), vertex_neighbors.begin(), vertex_neighbors.end());
        // Closing offset of this vertex == opening offset of the next one.
        entries_vector.push_back(static_cast<int>(neighbors_vector.size()));
      }
    }

    Array<int> entries   = convert_to_array(entries_vector);
    Array<int> neighbors = convert_to_array(neighbors_vector);

    CRSGraph graph{entries, neighbors};

    Array<int> partitioned = partitioner.partition(graph);

    // One partition value per local vertex (entries holds vertex count + 1 offsets).
    REQUIRE(entries.size() == (partitioned.size() + 1));

    if (parallel::rank() == 0) {
      // Collect the distinct rank ids assigned to rank 0's vertices and check
      // that their count matches the number of parallel ranks.
      std::set<int> used_ranks;
      for (size_t vertex = 0; vertex < partitioned.size(); ++vertex) {
        used_ranks.insert(partitioned[vertex]);
      }

      REQUIRE(used_ranks.size() == parallel::size());
    }
  }
}