#ifndef SYNCHRONIZER_HPP
#define SYNCHRONIZER_HPP

#include <mesh/Connectivity.hpp>
#include <mesh/ItemArray.hpp>
#include <mesh/ItemValue.hpp>
#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>

#include <iostream>
#include <map>
#include <type_traits>
#include <utility>
#include <vector>

class Synchronizer
{
  template <ItemType item_type>
  using ExchangeItemTypeInfo = std::vector<Array<const ItemIdT<item_type>>>;

  ExchangeItemTypeInfo<ItemType::cell> m_requested_cell_info;
  ExchangeItemTypeInfo<ItemType::cell> m_provided_cell_info;

  ExchangeItemTypeInfo<ItemType::face> m_requested_face_info;
  ExchangeItemTypeInfo<ItemType::face> m_provided_face_info;

  ExchangeItemTypeInfo<ItemType::edge> m_requested_edge_info;
  ExchangeItemTypeInfo<ItemType::edge> m_provided_edge_info;

  ExchangeItemTypeInfo<ItemType::node> m_requested_node_info;
  ExchangeItemTypeInfo<ItemType::node> m_provided_node_info;

  template <ItemType item_type>
  PUGS_INLINE constexpr auto&
  _getRequestedItemInfo()
  {
    if constexpr (item_type == ItemType::cell) {
      return m_requested_cell_info;
    } else if constexpr (item_type == ItemType::face) {
      return m_requested_face_info;
    } else if constexpr (item_type == ItemType::edge) {
      return m_requested_edge_info;
    } else if constexpr (item_type == ItemType::node) {
      return m_requested_node_info;
    }
  }

  template <ItemType item_type>
  PUGS_INLINE constexpr auto&
  _getProvidedItemInfo()
  {
    if constexpr (item_type == ItemType::cell) {
      return m_provided_cell_info;
    } else if constexpr (item_type == ItemType::face) {
      return m_provided_face_info;
    } else if constexpr (item_type == ItemType::edge) {
      return m_provided_edge_info;
    } else if constexpr (item_type == ItemType::node) {
      return m_provided_node_info;
    }
  }

  template <typename ConnectivityType, ItemType item_type>
  void
  _buildSynchronizeInfo(const ConnectivityType& connectivity)
  {
    const auto& item_owner = connectivity.template owner<item_type>();
    using ItemId           = ItemIdT<item_type>;

    auto& requested_item_info = this->_getRequestedItemInfo<item_type>();
    requested_item_info       = [&]() {
      std::vector<std::vector<ItemId>> requested_item_vector_info(parallel::size());
      for (ItemId item_id = 0; item_id < item_owner.numberOfItems(); ++item_id) {
        if (const size_t owner = item_owner[item_id]; owner != parallel::rank()) {
          requested_item_vector_info[owner].emplace_back(item_id);
        }
      }
      std::vector<Array<const ItemId>> requested_item_info(parallel::size());
      for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
        const auto& requested_item_vector = requested_item_vector_info[i_rank];
        requested_item_info[i_rank]       = convert_to_array(requested_item_vector);
      }
      return requested_item_info;
    }();

    Array<unsigned int> local_number_of_requested_values(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      local_number_of_requested_values[i_rank] = requested_item_info[i_rank].size();
    }

    Array<unsigned int> local_number_of_values_to_send = parallel::allToAll(local_number_of_requested_values);

    std::vector<Array<const int>> requested_item_number_list_by_proc(parallel::size());
    const auto& item_number = connectivity.template number<item_type>();
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      Array<int> item_number_list{requested_item_info_from_rank.size()};
      parallel_for(
        requested_item_info_from_rank.size(),
        PUGS_LAMBDA(size_t i_item) { item_number_list[i_item] = item_number[requested_item_info_from_rank[i_item]]; });
      requested_item_number_list_by_proc[i_rank] = item_number_list;
    }

    std::vector<Array<int>> provided_item_number_list_by_rank(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      provided_item_number_list_by_rank[i_rank] = Array<int>{local_number_of_values_to_send[i_rank]};
    }

    parallel::exchange(requested_item_number_list_by_proc, provided_item_number_list_by_rank);

    std::map<int, ItemId> item_number_to_id_correspondance;
    for (ItemId item_id = 0; item_id < item_number.numberOfItems(); ++item_id) {
      item_number_to_id_correspondance[item_number[item_id]] = item_id;
    }

    auto& provided_item_info = this->_getProvidedItemInfo<item_type>();
    provided_item_info       = [&]() {
      std::vector<Array<const ItemId>> provided_item_info(parallel::size());
      for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
        Array<ItemId> provided_item_id_to_rank{local_number_of_values_to_send[i_rank]};
        const Array<int>& provided_item_number_to_rank = provided_item_number_list_by_rank[i_rank];
        for (size_t i = 0; i < provided_item_number_to_rank.size(); ++i) {
          provided_item_id_to_rank[i] = item_number_to_id_correspondance.find(provided_item_number_to_rank[i])->second;
        }
        provided_item_info[i_rank] = provided_item_id_to_rank;
      }
      return provided_item_info;
    }();
  }

  template <typename ConnectivityType, typename DataType, ItemType item_type, typename ConnectivityPtr>
  PUGS_INLINE void
  _synchronize(const ConnectivityType& connectivity, ItemValue<DataType, item_type, ConnectivityPtr>& item_value)
  {
    static_assert(not std::is_abstract_v<ConnectivityType>, "_synchronize must be called on a concrete connectivity");

    using ItemId = ItemIdT<item_type>;

    const auto& provided_item_info  = this->_getProvidedItemInfo<item_type>();
    const auto& requested_item_info = this->_getRequestedItemInfo<item_type>();

    Assert(requested_item_info.size() == provided_item_info.size());

    if (provided_item_info.size() == 0) {
      this->_buildSynchronizeInfo<ConnectivityType, item_type>(connectivity);
    }

    std::vector<Array<const DataType>> provided_data_list(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const Array<const ItemId>& provided_item_info_to_rank = provided_item_info[i_rank];
      Array<DataType> provided_data{provided_item_info_to_rank.size()};
      parallel_for(
        provided_item_info_to_rank.size(),
        PUGS_LAMBDA(size_t i) { provided_data[i] = item_value[provided_item_info_to_rank[i]]; });
      provided_data_list[i_rank] = provided_data;
    }

    std::vector<Array<DataType>> requested_data_list(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      requested_data_list[i_rank]               = Array<DataType>{requested_item_info_from_rank.size()};
    }

    parallel::exchange(provided_data_list, requested_data_list);

    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      const auto& requested_data                = requested_data_list[i_rank];
      parallel_for(
        requested_item_info_from_rank.size(),
        PUGS_LAMBDA(size_t i) { item_value[requested_item_info_from_rank[i]] = requested_data[i]; });
    }
  }

  template <typename ConnectivityType, typename DataType, ItemType item_type, typename ConnectivityPtr>
  PUGS_INLINE void
  _synchronize(const ConnectivityType& connectivity, ItemArray<DataType, item_type, ConnectivityPtr>& item_array)
  {
    static_assert(not std::is_abstract_v<ConnectivityType>, "_synchronize must be called on a concrete connectivity");

    using ItemId = ItemIdT<item_type>;

    const auto& provided_item_info  = this->_getProvidedItemInfo<item_type>();
    const auto& requested_item_info = this->_getRequestedItemInfo<item_type>();

    Assert(requested_item_info.size() == provided_item_info.size());

    if (provided_item_info.size() == 0) {
      this->_buildSynchronizeInfo<ConnectivityType, item_type>(connectivity);
    }

    const size_t size_of_arrays = item_array.sizeOfArrays();

    std::vector<Array<const DataType>> provided_data_list(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const Array<const ItemId>& provided_item_info_to_rank = provided_item_info[i_rank];
      Array<DataType> provided_data{provided_item_info_to_rank.size() * size_of_arrays};
      parallel_for(
        provided_item_info_to_rank.size(), PUGS_LAMBDA(size_t i) {
          const size_t j   = i * size_of_arrays;
          const auto array = item_array[provided_item_info_to_rank[i]];
          for (size_t k = 0; k < size_of_arrays; ++k) {
            provided_data[j + k] = array[k];
          }
        });
      provided_data_list[i_rank] = provided_data;
    }

    std::vector<Array<DataType>> requested_data_list(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      requested_data_list[i_rank] = Array<DataType>{requested_item_info_from_rank.size() * size_of_arrays};
    }

    parallel::exchange(provided_data_list, requested_data_list);

    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      const auto& requested_data                = requested_data_list[i_rank];
      parallel_for(
        requested_item_info_from_rank.size(), PUGS_LAMBDA(size_t i) {
          const size_t j = i * size_of_arrays;
          auto array     = item_array[requested_item_info_from_rank[i]];
          for (size_t k = 0; k < size_of_arrays; ++k) {
            array[k] = requested_data[j + k];
          }
        });
    }
  }

 public:
  template <typename DataType, ItemType item_type, typename ConnectivityPtr>
  PUGS_INLINE void
  synchronize(ItemValue<DataType, item_type, ConnectivityPtr>& item_value)
  {
    Assert(item_value.connectivity_ptr().use_count() > 0, "No connectivity is associated to this ItemValue");
    const IConnectivity& connectivity = *item_value.connectivity_ptr();

    switch (connectivity.dimension()) {
    case 1: {
      this->_synchronize(static_cast<const Connectivity1D&>(connectivity), item_value);
      break;
    }
    case 2: {
      this->_synchronize(static_cast<const Connectivity2D&>(connectivity), item_value);
      break;
    }
    case 3: {
      this->_synchronize(static_cast<const Connectivity3D&>(connectivity), item_value);
      break;
    }
      // LCOV_EXCL_START
    default: {
      throw UnexpectedError("unexpected dimension");
    }
      // LCOV_EXCL_STOP
    }
  }

  template <typename DataType, ItemType item_type, typename ConnectivityPtr>
  PUGS_INLINE void
  synchronize(ItemArray<DataType, item_type, ConnectivityPtr>& item_value)
  {
    Assert(item_value.connectivity_ptr().use_count() > 0, "No connectivity is associated to this ItemValue");
    const IConnectivity& connectivity = *item_value.connectivity_ptr();

    switch (connectivity.dimension()) {
    case 1: {
      this->_synchronize(static_cast<const Connectivity1D&>(connectivity), item_value);
      break;
    }
    case 2: {
      this->_synchronize(static_cast<const Connectivity2D&>(connectivity), item_value);
      break;
    }
    case 3: {
      this->_synchronize(static_cast<const Connectivity3D&>(connectivity), item_value);
      break;
    }
      // LCOV_EXCL_START
    default: {
      throw UnexpectedError("unexpected dimension");
    }
      // LCOV_EXCL_STOP
    }
  }

  PUGS_INLINE
  Synchronizer()
  {
    ;
  }
};

#endif   // SYNCHRONIZER_HPP