#ifndef MESSENGER_HPP
#define MESSENGER_HPP

#include <PastisMacros.hpp>
#include <PastisAssert.hpp>

#include <Array.hpp>
#include <CastArray.hpp>
#include <ArrayUtils.hpp>

#include <type_traits>
#include <vector>

#include <pastis_config.hpp>
#ifdef PASTIS_HAS_MPI
#include <mpi.h>
#endif // PASTIS_HAS_MPI

namespace parallel
{

class Messenger
{
 private:
  // Compile-time helpers used by Messenger: mapping of C++ types to MPI
  // datatypes, and selection of an integer type whose size divides the size
  // of a given type.
  struct helper
  {
#ifdef PASTIS_HAS_MPI
    // Returns the MPI datatype associated with DataType.
    //
    // const-qualified types are forwarded to their non-const counterpart.
    // Supported arithmetic types are provided through explicit
    // specializations of this function. Instantiating this generic version
    // for a non-const type is therefore always a compile-time error: the two
    // complementary static_asserts below only select which diagnostic is
    // reported.
    template<typename DataType>
    static PASTIS_INLINE
    MPI_Datatype mpiType()
    {
      if constexpr (std::is_const_v<DataType>) {
        return mpiType<std::remove_const_t<DataType>>();
      } else {
        // Fails when DataType is not arithmetic.
        static_assert(std::is_arithmetic_v<DataType>,
                      "MPI_Datatype are only defined for arithmetic types!");
        // Fails when DataType is arithmetic but has no specialization.
        static_assert(not std::is_arithmetic_v<DataType>,
                      "Unexpected arithmetic type! Should not occur!");
        return MPI_Datatype();
      }
    }
#endif // PASTIS_HAS_MPI

   private:
    // split_cast<T>::type is the widest type among int64_t, int32_t, int16_t
    // and int8_t whose size divides sizeof(T). Exactly one of the partial
    // specializations below is enabled for a given T.
    template <typename T,
              typename Allowed = void>
    struct split_cast {};

    // sizeof(T) is a multiple of sizeof(int64_t).
    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int64_t))>> {
      using type = int64_t;
      static_assert(not(sizeof(T) % sizeof(int64_t)));
    };

    // sizeof(T) is a multiple of sizeof(int32_t) but not of sizeof(int64_t).
    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int32_t))
                                         and(sizeof(T) % sizeof(int64_t))>> {
      using type = int32_t;
      static_assert(not(sizeof(T) % sizeof(int32_t)));
    };

    // sizeof(T) is a multiple of sizeof(int16_t) but of no wider type.
    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int16_t))
                                         and(sizeof(T) % sizeof(int32_t))
                                         and(sizeof(T) % sizeof(int64_t))>> {
      using type = int16_t;
      static_assert(not(sizeof(T) % sizeof(int16_t)));
    };

    // Remaining sizes: only sizeof(int8_t) divides sizeof(T).
    template <typename T>
    struct split_cast<T,std::enable_if_t<not(sizeof(T) % sizeof(int8_t))
                                         and(sizeof(T) % sizeof(int16_t))
                                         and(sizeof(T) % sizeof(int32_t))
                                         and(sizeof(T) % sizeof(int64_t))>> {
      using type = int8_t;
      static_assert(not(sizeof(T) % sizeof(int8_t)));
    };

   public:
    // Shorthand for split_cast<T>::type.
    template <typename T>
    using split_cast_t = typename split_cast<T>::type;
  };

  // Pointer dereferenced by getInstance(), which asserts that it is non-null.
  static Messenger* m_instance;
  // Private constructor; only declared in this header.
  Messenger(int& argc, char* argv[]);

  // Value returned by rank(); default-initialized to 0.
  size_t m_rank{0};
  // Value returned by size(); default-initialized to 1.
  size_t m_size{1};

  // Gathers a single arithmetic value into `gather`.
  //
  // `gather` must contain exactly m_size entries. In an MPI build the values
  // are collected with MPI_Allgather over MPI_COMM_WORLD (one entry per
  // call); otherwise the local value is stored in entry 0.
  template <typename DataType>
  void _allGather(const DataType& data,
                  Array<DataType> gather) const
  {
    static_assert(std::is_arithmetic_v<DataType>);
    Assert(gather.size() == m_size); // LCOV_EXCL_LINE

#ifdef PASTIS_HAS_MPI
    const MPI_Datatype mpi_type = helper::mpiType<DataType>();
    auto* const send_buffer = &data;
    auto* const recv_buffer = &(gather[0]);

    MPI_Allgather(send_buffer, 1, mpi_type,
                  recv_buffer, 1, mpi_type,
                  MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
    gather[0] = data;
#endif // PASTIS_HAS_MPI
  }



  // Gathers an array of arithmetic values into `gather_array`.
  //
  // `gather_array` must contain data_array.size()*m_size entries and its
  // element type must be the non-const version of the sent element type.
  // In an MPI build the data are collected with MPI_Allgather over
  // MPI_COMM_WORLD; otherwise `data_array` is copied with value_copy.
  //
  // Empty arrays are supported: element 0 is never accessed when the size is
  // zero, a null buffer with a zero count is handed to MPI instead.
  template <template <typename ...SendT> typename SendArrayType,
            template <typename ...RecvT> typename RecvArrayType,
            typename ...SendT, typename ...RecvT>
  void _allGather(const SendArrayType<SendT...>& data_array,
                  RecvArrayType<RecvT...> gather_array) const
  {
    Assert(gather_array.size() == data_array.size()*m_size); // LCOV_EXCL_LINE

    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PASTIS_HAS_MPI
    MPI_Datatype mpi_datatype
        = Messenger::helper::mpiType<RecvDataType>();

    // Do not index an empty array to obtain its address.
    auto* const send_address
        = (data_array.size() > 0) ? &(data_array[0]) : nullptr;
    auto* const recv_address
        = (gather_array.size() > 0) ? &(gather_array[0]) : nullptr;

    MPI_Allgather(send_address, data_array.size(),  mpi_datatype,
                  recv_address, data_array.size(),  mpi_datatype,
                  MPI_COMM_WORLD);
#else // PASTIS_HAS_MPI
    value_copy(data_array, gather_array);
#endif // PASTIS_HAS_MPI
  }

  // Broadcasts one non-const arithmetic value from `root_rank`.
  //
  // In an MPI build this calls MPI_Bcast on MPI_COMM_WORLD with a count of
  // one; without MPI the function does nothing.
  template <typename DataType>
  void _broadcast_value(DataType& data, const size_t& root_rank) const
  {
    static_assert(std::is_arithmetic_v<DataType>);
    static_assert(not std::is_const_v<DataType>);

#ifdef PASTIS_HAS_MPI
    const MPI_Datatype mpi_type = helper::mpiType<DataType>();
    MPI_Bcast(&data, 1, mpi_type, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
  }

  // Broadcasts the content of `array` from `root_rank`.
  //
  // The element type (ArrayType::data_type) must be a non-const arithmetic
  // type. In an MPI build this calls MPI_Bcast on MPI_COMM_WORLD with
  // array.size() elements; without MPI the function does nothing.
  //
  // Empty arrays are supported: element 0 is never accessed when the size is
  // zero, a null buffer with a zero count is handed to MPI instead.
  template <typename ArrayType>
  void _broadcast_array(ArrayType& array, const size_t& root_rank) const
  {
    using DataType = typename ArrayType::data_type;
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_arithmetic_v<DataType>);

#ifdef PASTIS_HAS_MPI
    MPI_Datatype mpi_datatype
        = Messenger::helper::mpiType<DataType>();

    // Do not index an empty array to obtain its address.
    auto* const array_address = (array.size() > 0) ? &(array[0]) : nullptr;

    MPI_Bcast(array_address, array.size(), mpi_datatype, root_rank, MPI_COMM_WORLD);
#endif // PASTIS_HAS_MPI
  }

  template <template <typename ...SendT> typename SendArrayType,
            template <typename ...RecvT> typename RecvArrayType,
            typename ...SendT, typename ...RecvT>
  RecvArrayType<RecvT...> _allToAll(const SendArrayType<SendT...>& sent_array,
                                    RecvArrayType<RecvT...>& recv_array) const
  {
#ifdef PASTIS_HAS_MPI
    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

    Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE
    Assert(sent_array.size() == recv_array.size()); // LCOV_EXCL_LINE

    const size_t count = sent_array.size()/m_size;

    MPI_Datatype mpi_datatype
        = Messenger::helper::mpiType<SendDataType>();

    MPI_Alltoall(&(sent_array[0]), count, mpi_datatype,
                 &(recv_array[0]), count, mpi_datatype,
                 MPI_COMM_WORLD);
#else  // PASTIS_HAS_MPI
    value_copy(sent_array, recv_array);
#endif // PASTIS_HAS_MPI
    return recv_array;
  }

  // Point-to-point exchange of arrays of arithmetic values.
  //
  // In an MPI build, entry i of `sent_array_list` is sent to rank i with
  // MPI_Isend and entry i of `recv_array_list` is received from rank i with
  // MPI_Irecv (tag 0, MPI_COMM_WORLD). Empty arrays are skipped. The function
  // returns once MPI_Waitall has completed all posted requests; on failure a
  // message is written to std::cerr and the program exits with status 1.
  //
  // Without MPI both lists must contain exactly one array and the sent array
  // is copied into the received one with value_copy.
  template <template <typename ...SendT> typename SendArrayType,
            template <typename ...RecvT> typename RecvArrayType,
            typename ...SendT, typename ...RecvT>
  void _exchange(const std::vector<SendArrayType<SendT...>>& sent_array_list,
                 std::vector<RecvArrayType<RecvT...>>& recv_array_list) const
  {
    using SendDataType = typename SendArrayType<SendT...>::data_type;
    using RecvDataType = typename RecvArrayType<RecvT...>::data_type;

    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>);
    static_assert(std::is_arithmetic_v<SendDataType>);

#ifdef PASTIS_HAS_MPI
    std::vector<MPI_Request> request_list;
    request_list.reserve(sent_array_list.size() + recv_array_list.size());

    MPI_Datatype mpi_datatype
        = Messenger::helper::mpiType<SendDataType>();

    // Arrays are bound by reference so that the buffers handed to the
    // non-blocking calls are the ones stored in the lists, which outlive the
    // requests, and not loop-local copies.
    for (size_t i_send=0; i_send<sent_array_list.size(); ++i_send) {
      const auto& sent_array = sent_array_list[i_send];
      if (sent_array.size()>0) {
        MPI_Request request;
        MPI_Isend(&(sent_array[0]), sent_array.size(), mpi_datatype, i_send, 0, MPI_COMM_WORLD, &request);
        request_list.push_back(request);
      }
    }

    for (size_t i_recv=0; i_recv<recv_array_list.size(); ++i_recv) {
      auto& recv_array = recv_array_list[i_recv];
      if (recv_array.size()>0) {
        MPI_Request request;
        MPI_Irecv(&(recv_array[0]), recv_array.size(), mpi_datatype, i_recv, 0, MPI_COMM_WORLD, &request);
        request_list.push_back(request);
      }
    }

    // Nothing was posted (every array is empty): there is nothing to wait
    // for, and indexing the empty request/status vectors would be undefined.
    if (request_list.empty()) {
      return;
    }

    std::vector<MPI_Status> status_list(request_list.size());
    if (MPI_SUCCESS != MPI_Waitall(static_cast<int>(request_list.size()),
                                   request_list.data(), status_list.data())) {
      // LCOV_EXCL_START
      std::cerr << "Communication error!\n";
      std::exit(1);
      // LCOV_EXCL_STOP
    }

#else // PASTIS_HAS_MPI
    Assert(sent_array_list.size() == 1);
    Assert(recv_array_list.size() == 1);

    value_copy(sent_array_list[0], recv_array_list[0]);
#endif // PASTIS_HAS_MPI
  }

  // Exchanges arrays of a non-arithmetic type by viewing them as arrays of
  // CastDataType.
  //
  // Each sent array is wrapped through cast_array_to<const CastDataType> and
  // each received array is wrapped in a CastArray<MutableDataType,
  // CastDataType>; the wrapped lists are then passed to _exchange.
  //
  // Each list is traversed over its own size, so the received list is never
  // indexed using the size of the sent list.
  template <typename DataType,
            typename CastDataType>
  void _exchange_through_cast(const std::vector<Array<DataType>>& sent_array_list,
                              std::vector<Array<std::remove_const_t<DataType>>>& recv_array_list) const
  {
    std::vector<CastArray<DataType, const CastDataType>> sent_cast_array_list;
    sent_cast_array_list.reserve(sent_array_list.size());
    for (const auto& sent_array : sent_array_list) {
      sent_cast_array_list.emplace_back(cast_array_to<const CastDataType>::from(sent_array));
    }

    using MutableDataType = std::remove_const_t<DataType>;
    std::vector<CastArray<MutableDataType, CastDataType>> recv_cast_array_list;
    recv_cast_array_list.reserve(recv_array_list.size());
    for (auto& recv_array : recv_array_list) {
      recv_cast_array_list.emplace_back(recv_array);
    }

    _exchange(sent_cast_array_list, recv_cast_array_list);
  }

 public:
  // Static entry points; only declared in this header.
  // NOTE(review): presumably create() builds the instance returned by
  // getInstance() and destroy() releases it — confirm in the implementation.
  static void create(int& argc, char* argv[]);
  static void destroy();

  // Returns a reference to the object pointed to by m_instance.
  // The pointer is asserted to be non-null.
  PASTIS_INLINE
  static Messenger& getInstance()
  {
    Assert(m_instance != nullptr); // LCOV_EXCL_LINE
    Messenger& instance = *m_instance;
    return instance;
  }

  // Read-only access to m_rank.
  PASTIS_INLINE
  const size_t& rank() const
  {
    return this->m_rank;
  }

  // Read-only access to m_size.
  PASTIS_INLINE
  const size_t& size() const
  {
    return this->m_size;
  }

  // Only declared in this header.
  // NOTE(review): presumably a synchronization point across all ranks —
  // confirm in the implementation.
  void barrier() const;

  // Returns the minimum of `data`.
  //
  // In an MPI build the value is reduced with MPI_Allreduce (MPI_MIN) over
  // MPI_COMM_WORLD and DataType must be a non-const arithmetic type; without
  // MPI the input value is returned unchanged.
  template <typename DataType>
  DataType allReduceMin(const DataType& data) const
  {
#ifdef PASTIS_HAS_MPI
    static_assert(std::is_arithmetic_v<DataType>);
    static_assert(not std::is_const_v<DataType>);

    const MPI_Datatype mpi_type = helper::mpiType<DataType>();

    DataType reduced_value = data;
    MPI_Allreduce(&data, &reduced_value, 1, mpi_type, MPI_MIN, MPI_COMM_WORLD);

    return reduced_value;
#else // PASTIS_HAS_MPI
    return data;
#endif // PASTIS_HAS_MPI
  }

  // Returns the maximum of `data`.
  //
  // In an MPI build the value is reduced with MPI_Allreduce (MPI_MAX) over
  // MPI_COMM_WORLD and DataType must be a non-const arithmetic type; without
  // MPI the input value is returned unchanged.
  template <typename DataType>
  DataType allReduceMax(const DataType& data) const
  {
#ifdef PASTIS_HAS_MPI
    static_assert(std::is_arithmetic_v<DataType>);
    static_assert(not std::is_const_v<DataType>);

    const MPI_Datatype mpi_type = helper::mpiType<DataType>();

    DataType reduced_value = data;
    MPI_Allreduce(&data, &reduced_value, 1, mpi_type, MPI_MAX, MPI_COMM_WORLD);

    return reduced_value;
#else // PASTIS_HAS_MPI
    return data;
#endif // PASTIS_HAS_MPI
  }

  // Gathers one value into a newly created array of m_size entries.
  //
  // DataType must be non-const and trivial. Arithmetic values are gathered
  // directly; other trivial values are viewed through helper::split_cast_t
  // (both the value and the result array) before being gathered.
  template <typename DataType>
  PASTIS_INLINE
  Array<DataType>
  allGather(const DataType& data) const
  {
    static_assert(not std::is_const_v<DataType>);
    static_assert(std::is_trivial_v<DataType>, "unexpected type of data");

    Array<DataType> gathered_values(m_size);

    if constexpr(std::is_arithmetic_v<DataType>) {
      _allGather(data, gathered_values);
    } else {
      using CastType = helper::split_cast_t<DataType>;

      CastArray value_view = cast_value_to<const CastType>::from(data);
      CastArray gathered_view = cast_array_to<CastType>::from(gathered_values);

      _allGather(value_view, gathered_view);
    }
    return gathered_values;
  }

  // Gathers an array into a newly created array of m_size*array.size()
  // entries whose element type is the non-const version of DataType.
  //
  // DataType must be trivial. Arrays of arithmetic values are gathered
  // directly; arrays of other trivial values are viewed through
  // helper::split_cast_t (input and result) before being gathered.
  template <typename DataType>
  PASTIS_INLINE
  Array<std::remove_const_t<DataType>>
  allGather(const Array<DataType>& array) const
  {
    static_assert(std::is_trivial_v<DataType>, "unexpected type of data");

    using MutableDataType = std::remove_const_t<DataType>;
    Array<MutableDataType> gathered_values(m_size*array.size());

    if constexpr(std::is_arithmetic_v<DataType>) {
      _allGather(array, gathered_values);
    } else {
      using CastType = helper::split_cast_t<DataType>;
      using MutableCastType = helper::split_cast_t<MutableDataType>;

      CastArray array_view = cast_array_to<CastType>::from(array);
      CastArray gathered_view = cast_array_to<MutableCastType>::from(gathered_values);

      _allGather(array_view, gathered_view);
    }
    return gathered_values;
  }

  // All-to-all exchange of `sent_array`; returns a newly created array of the
  // same size holding the received values.
  //
  // The size of `sent_array` must be a multiple of m_size. When NDEBUG is not
  // defined, the minimum and maximum of the sizes are additionally reduced
  // and asserted to be equal. The element type must be trivial; non
  // arithmetic elements are viewed through helper::split_cast_t.
  template <typename SendDataType>
  PASTIS_INLINE
  Array<std::remove_const_t<SendDataType>>
  allToAll(const Array<SendDataType>& sent_array) const
  {
    using DataType = std::remove_const_t<SendDataType>;
    static_assert(std::is_trivial_v<DataType>, "unexpected type of data");

#ifndef NDEBUG
    const size_t min_size = allReduceMin(sent_array.size());
    const size_t max_size = allReduceMax(sent_array.size());
    Assert(max_size == min_size); // LCOV_EXCL_LINE
#endif // NDEBUG
    Assert((sent_array.size() % m_size) == 0); // LCOV_EXCL_LINE

    Array<DataType> received_values(sent_array.size());

    if constexpr(std::is_arithmetic_v<DataType>) {
      _allToAll(sent_array, received_values);
    } else {
      using CastType = helper::split_cast_t<DataType>;

      auto sent_view = cast_array_to<const CastType>::from(sent_array);
      auto received_view = cast_array_to<CastType>::from(received_values);
      _allToAll(sent_view, received_view);
    }
    return received_values;
  }

  // Broadcasts a single non-const trivial value from `root_rank`.
  //
  // Arithmetic values are broadcast directly. Other trivial values are
  // reinterpreted as one helper::split_cast_t value when the sizes match,
  // and viewed as an array of such values otherwise.
  template <typename DataType>
  PASTIS_INLINE
  void broadcast(DataType& data, const size_t& root_rank) const
  {
    static_assert(not std::is_const_v<DataType>,
                  "cannot broadcast const data");
    static_assert(std::is_trivial_v<DataType>,
                  "unexpected non trivial type of data");

    if constexpr(std::is_arithmetic_v<DataType>) {
      _broadcast_value(data, root_rank);
    } else {
      using CastType = helper::split_cast_t<DataType>;
      if constexpr(sizeof(CastType) != sizeof(DataType)) {
        CastArray data_view = cast_value_to<CastType>::from(data);
        _broadcast_array(data_view, root_rank);
      } else {
        CastType& data_as_cast = reinterpret_cast<CastType&>(data);
        _broadcast_value(data_as_cast, root_rank);
      }
    }
  }

  // Broadcasts an array of non-const trivial values from `root_rank`.
  //
  // The size is broadcast first; every rank other than `root_rank` then
  // replaces `array` by a new array of that size. Finally the content is
  // broadcast, directly for arithmetic elements and through a
  // helper::split_cast_t view otherwise.
  template <typename DataType>
  PASTIS_INLINE
  void broadcast(Array<DataType>& array,
                 const size_t& root_rank) const
  {
    static_assert(not std::is_const_v<DataType>,
                  "cannot broadcast array of const");
    static_assert(std::is_trivial_v<DataType>,
                  "unexpected non trivial type of data");

    // Size negotiation is common to both element kinds.
    size_t size = array.size();
    _broadcast_value(size, root_rank);
    if (m_rank != root_rank) {
      array = Array<DataType>(size); // LCOV_EXCL_LINE
    }

    if constexpr(std::is_arithmetic_v<DataType>) {
      _broadcast_array(array, root_rank);
    } else {
      using CastType = helper::split_cast_t<DataType>;
      auto array_view = cast_array_to<CastType>::from(array);
      _broadcast_array(array_view, root_rank);
    }
  }

  // Exchanges one array per rank.
  //
  // Both lists must contain m_size arrays. The received element type must be
  // the non-const version of the sent one and must be trivial. When NDEBUG
  // is not defined, the sizes of the sent arrays are exchanged with allToAll
  // and compared with the sizes of the arrays of `recv_array_list`.
  // Arithmetic elements are exchanged directly, other trivial elements
  // through a helper::split_cast_t view.
  template <typename SendDataType,
            typename RecvDataType>
  PASTIS_INLINE
  void exchange(const std::vector<Array<SendDataType>>& send_array_list,
                std::vector<Array<RecvDataType>>& recv_array_list) const
  {
    static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>,
                  "send and receive data type do not match");
    static_assert(not std::is_const_v<RecvDataType>,
                  "receive data type cannot be const");
    static_assert(std::is_trivial_v<RecvDataType>,
                  "unexpected non trivial type of data");
    using DataType = std::remove_const_t<SendDataType>;

    Assert(send_array_list.size() == m_size); // LCOV_EXCL_LINE
    Assert(recv_array_list.size() == m_size); // LCOV_EXCL_LINE
#ifndef NDEBUG
    Array<size_t> sent_sizes(m_size);
    for (size_t i_rank=0; i_rank<m_size; ++i_rank) {
      sent_sizes[i_rank] = send_array_list[i_rank].size();
    }
    Array<size_t> expected_sizes = allToAll(sent_sizes);
    bool correct_sizes = true;
    for (size_t i_rank=0; i_rank<m_size; ++i_rank) {
      if (expected_sizes[i_rank] != recv_array_list[i_rank].size()) {
        correct_sizes = false;
      }
    }
    Assert(correct_sizes); // LCOV_EXCL_LINE
#endif // NDEBUG

    if constexpr(std::is_arithmetic_v<DataType>) {
      _exchange(send_array_list, recv_array_list);
    } else {
      using CastType = helper::split_cast_t<DataType>;
      _exchange_through_cast<SendDataType, CastType>(send_array_list, recv_array_list);
    }
  }

  // Copy construction is disabled.
  Messenger(const Messenger&) = delete;
  // Destructor; only declared in this header.
  ~Messenger();
};

// Read-only access to the object returned by Messenger::getInstance().
PASTIS_INLINE
const Messenger& messenger()
{
  const Messenger& instance = Messenger::getInstance();
  return instance;
}

// Forwards to Messenger::rank() of the current instance.
PASTIS_INLINE
const size_t& rank()
{
  const Messenger& instance = messenger();
  return instance.rank();
}

// Forwards to Messenger::size() of the current instance.
PASTIS_INLINE
const size_t& size()
{
  const Messenger& instance = messenger();
  return instance.size();
}

// Forwards to Messenger::barrier() of the current instance.
PASTIS_INLINE
void barrier()
{
  messenger().barrier();
}

// Forwards to Messenger::allReduceMax() of the current instance.
template <typename DataType>
PASTIS_INLINE
DataType allReduceMax(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceMax(data);
}

// Forwards to Messenger::allReduceMin() of the current instance.
template <typename DataType>
PASTIS_INLINE
DataType allReduceMin(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allReduceMin(data);
}

// Forwards a single value to Messenger::allGather() of the current instance.
template <typename DataType>
PASTIS_INLINE
Array<DataType>
allGather(const DataType& data)
{
  const Messenger& instance = messenger();
  return instance.allGather(data);
}

// Forwards an array to Messenger::allGather() of the current instance.
template <typename DataType>
PASTIS_INLINE
Array<std::remove_const_t<DataType>>
allGather(const Array<DataType>& array)
{
  const Messenger& instance = messenger();
  return instance.allGather(array);
}

// Forwards to Messenger::allToAll() of the current instance.
template <typename DataType>
PASTIS_INLINE
Array<std::remove_const_t<DataType>>
allToAll(const Array<DataType>& array)
{
  const Messenger& instance = messenger();
  return instance.allToAll(array);
}

// Forwards a single value to Messenger::broadcast() of the current instance.
template <typename DataType>
PASTIS_INLINE
void broadcast(DataType& data, const size_t& root_rank)
{
  const Messenger& instance = messenger();
  instance.broadcast(data, root_rank);
}

// Forwards an array to Messenger::broadcast() of the current instance.
template <typename DataType>
PASTIS_INLINE
void broadcast(Array<DataType>& array, const size_t& root_rank)
{
  const Messenger& instance = messenger();
  instance.broadcast(array, root_rank);
}

// Forwards to Messenger::exchange() of the current instance.
//
// The received element type must be the non-const version of the sent one;
// this is checked here before forwarding.
template <typename SendDataType,
          typename RecvDataType>
PASTIS_INLINE
void exchange(const std::vector<Array<SendDataType>>& sent_array_list,
              std::vector<Array<RecvDataType>>& recv_array_list)
{
  static_assert(std::is_same_v<std::remove_const_t<SendDataType>,RecvDataType>,
                "send and receive data type do not match");
  static_assert(not std::is_const_v<RecvDataType>,
                "receive data type cannot be const");

  const Messenger& instance = messenger();
  instance.exchange(sent_array_list, recv_array_list);
}

#ifdef PASTIS_HAS_MPI

// Explicit specializations of Messenger::helper::mpiType: each one returns
// the MPI datatype constant matching its C++ type.
//
// NOTE(review): on a platform where int64_t/uint64_t are aliases of
// (unsigned) long long, the fixed-width and long long specializations below
// would designate the same function — confirm the supported platforms.

// Character types
template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<char>()
{
  return MPI_CHAR;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<wchar_t>()
{
  return MPI_WCHAR;
}

// Boolean
template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<bool>()
{
  return MPI_CXX_BOOL;
}

// Fixed-width signed integers
template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int8_t>()
{
  return MPI_INT8_T;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int16_t>()
{
  return MPI_INT16_T;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int32_t>()
{
  return MPI_INT32_T;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<int64_t>()
{
  return MPI_INT64_T;
}

// Fixed-width unsigned integers
template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint8_t>()
{
  return MPI_UINT8_T;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint16_t>()
{
  return MPI_UINT16_T;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint32_t>()
{
  return MPI_UINT32_T;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<uint64_t>()
{
  return MPI_UINT64_T;
}

// long long integers
template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<signed long long int>()
{
  return MPI_LONG_LONG_INT;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<unsigned long long int>()
{
  return MPI_UNSIGNED_LONG_LONG;
}

// Floating point types
template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<float>()
{
  return MPI_FLOAT;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<double>()
{
  return MPI_DOUBLE;
}

template <>
PASTIS_INLINE MPI_Datatype Messenger::helper::mpiType<long double>()
{
  return MPI_LONG_DOUBLE;
}

#endif // PASTIS_HAS_MPI

} // namespace parallel

#endif // MESSENGER_HPP