#include <Messenger.hpp>
#include <PastisOStream.hpp>

// Process-wide singleton slot; set by Messenger::create(), cleared by Messenger::destroy().
Messenger* Messenger::m_instance{nullptr};

// Builds the unique Messenger instance, forwarding the program arguments to
// its constructor. Calling this a second time (while an instance exists) is a
// fatal usage error: a message is printed on std::cerr and the process exits
// with status 1.
void Messenger::create(int& argc, char* argv[])
{
  const bool already_created = (Messenger::m_instance != nullptr);
  if (already_created) {
    std::cerr << "Messenger already created\n";
    std::exit(1);
  }

  Messenger::m_instance = new Messenger(argc, argv);
}

// Deletes the singleton instance (if any) and resets the slot to nullptr.
// Calling it repeatedly is deliberately harmless, so that cleanup paths taken on
// unexpected exits may call it without first checking whether it already ran.
void Messenger::destroy()
{
  if (Messenger::m_instance == nullptr) {
    return;  // nothing left to tear down
  }

  delete Messenger::m_instance;
  Messenger::m_instance = nullptr;
}

// In MPI builds: initializes MPI with the program arguments, records this
// process's rank and the size of MPI_COMM_WORLD, and redirects pout/perr to
// null_stream on every rank except 0 so that console output is not duplicated.
// In non-MPI builds the constructor body is empty.
//
// NOTE(review): argv is received by value, so &argv points at this local
// copy. If MPI_Init ever reassigns the argv pointer, the caller will not see
// it, while the argc change (taken by reference) does propagate. Consider
// `char**& argv` in the header if that matters.
Messenger::Messenger(int& argc, char* argv[])
{
#ifdef PASTIS_HAS_MPI
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &m_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &m_size);

  const bool is_root = (m_rank == 0);
  if (!is_root) {
    // Non-root ranks are silenced; only rank 0 keeps its output streams.
    pout.setOutput(null_stream);
    perr.setOutput(null_stream);
  }
#endif // PASTIS_HAS_MPI
}

// Shuts MPI down in MPI builds (counterpart of the MPI_Init performed by the
// constructor); does nothing otherwise.
Messenger::~Messenger()
{
#ifdef PASTIS_HAS_MPI
  MPI_Finalize();
#endif // PASTIS_HAS_MPI
}

// Synchronization point: in MPI builds, blocks until every process of
// MPI_COMM_WORLD has entered the barrier. In non-MPI builds it returns at once.
void Messenger::
barrier() const
{
#ifdef PASTIS_HAS_MPI
  MPI_Barrier(MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
}

// Collects one int from every process into an Array of m_size entries, which
// every process receives. Entry i holds the value contributed by rank i
// (MPI_Allgather ordering). In non-MPI builds only entry 0 is written, with the
// local value.
Array<int> Messenger::
_allGather(int& data) const
{
  Array<int> gathered(m_size);

#ifdef PASTIS_HAS_MPI
  auto* const recv_buffer = &(gathered[0]);
  MPI_Allgather(&data, 1, MPI_INT,
                recv_buffer, 1, MPI_INT,
                MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
  gathered[0] = data;
#endif // PASTIS_HAS_MPI

  return gathered;
}

// Personalized all-to-all exchange of one int per pair of processes: entry i
// of sent_array goes to rank i, and entry j of recv_array receives what rank j
// sent to this process. Returns recv_array.
// In MPI builds both arrays must already hold exactly m_size entries (asserted).
// In non-MPI builds recv_array is reassigned from copy(sent_array).
// NOTE(review): that copy is presumably a deep copy.
Array<int> Messenger::
_allToAll(const Array<int>& sent_array, Array<int>& recv_array) const
{
#ifndef PASTIS_HAS_MPI
  recv_array = copy(sent_array);
#else  // PASTIS_HAS_MPI
  Assert(sent_array.size() == m_size);
  Assert(recv_array.size() == m_size);

  auto* const send_buffer = &(sent_array[0]);
  auto* const recv_buffer = &(recv_array[0]);
  MPI_Alltoall(send_buffer, 1, MPI_INT,
               recv_buffer, 1, MPI_INT,
               MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI

  return recv_array;
}


// Broadcasts a single int from root_rank. In MPI builds, on return `data` holds
// the root's value on every process. In non-MPI builds `data` is left as-is.
// The (possibly updated) value is also returned.
int Messenger::
_broadcast(int& data, int root_rank) const
{
#ifdef PASTIS_HAS_MPI
  int* const buffer = &data;
  const int count = 1;
  MPI_Bcast(buffer, count, MPI_INT, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI

  return data;
}


// Broadcasts an Array<int> from root_rank to every process, returning `array`.
// In MPI builds the root's size is broadcast first. Non-root processes then
// reallocate `array` to that size before the contents are broadcast. In
// non-MPI builds `array` is returned unchanged.
Array<int> Messenger::
_broadcast(Array<int>& array, int root_rank) const
{
#ifdef PASTIS_HAS_MPI
  // MPI counts are plain ints; make the narrowing explicit.
  int size = static_cast<int>(array.size());
  _broadcast(size, root_rank);

  if (commRank() != root_rank) {
    array = Array<int>(size);
  }

  // Every rank now agrees on `size`, so skipping the data broadcast for an
  // empty array is a collective decision (no rank is left waiting). This also
  // avoids taking &array[0] on an empty array, which would index out of bounds.
  if (size > 0) {
    MPI_Bcast(&(array[0]), size, MPI_INT, root_rank, MPI_COMM_WORLD);
  }
#endif // PASTIS_HAS_MPI
  return array;
}
