#ifndef SYNCHRONIZER_HPP
#define SYNCHRONIZER_HPP

#include <Connectivity.hpp>
#include <ItemValue.hpp>

#include <map>

class Synchronizer
{
  template <ItemType item_type>
  using ExchangeItemTypeInfo = std::vector<Array<const ItemIdT<item_type>>>;

  ExchangeItemTypeInfo<ItemType::cell> m_requested_cell_info;
  ExchangeItemTypeInfo<ItemType::cell> m_provided_cell_info;

  ExchangeItemTypeInfo<ItemType::face> m_requested_face_info;
  ExchangeItemTypeInfo<ItemType::face> m_provided_face_info;

  ExchangeItemTypeInfo<ItemType::edge> m_requested_edge_info;
  ExchangeItemTypeInfo<ItemType::edge> m_provided_edge_info;

  ExchangeItemTypeInfo<ItemType::node> m_requested_node_info;
  ExchangeItemTypeInfo<ItemType::node> m_provided_node_info;

  template <ItemType item_type>
  PUGS_INLINE constexpr auto&
  _getRequestedItemInfo()
  {
    if constexpr (item_type == ItemType::cell) {
      return m_requested_cell_info;
    } else if constexpr (item_type == ItemType::face) {
      return m_requested_face_info;
    } else if constexpr (item_type == ItemType::edge) {
      return m_requested_edge_info;
    } else if constexpr (item_type == ItemType::node) {
      return m_requested_node_info;
    }
  }

  template <ItemType item_type>
  PUGS_INLINE constexpr auto&
  _getProvidedItemInfo()
  {
    if constexpr (item_type == ItemType::cell) {
      return m_provided_cell_info;
    } else if constexpr (item_type == ItemType::face) {
      return m_provided_face_info;
    } else if constexpr (item_type == ItemType::edge) {
      return m_provided_edge_info;
    } else if constexpr (item_type == ItemType::node) {
      return m_provided_node_info;
    }
  }

  template <typename ConnectivityType, ItemType item_type>
  void
  _buildSynchronizeInfo(const ConnectivityType& connectivity)
  {
    const auto& item_owner = connectivity.template owner<item_type>();
    using ItemId           = ItemIdT<item_type>;

    auto& requested_item_info = this->_getRequestedItemInfo<item_type>();
    requested_item_info       = [&]() {
      std::vector<std::vector<ItemId>> requested_item_vector_info(
        parallel::size());
      for (ItemId item_id = 0; item_id < item_owner.size(); ++item_id) {
        if (const size_t owner = item_owner[item_id];
            owner != parallel::rank()) {
          requested_item_vector_info[owner].emplace_back(item_id);
        }
      }
      std::vector<Array<const ItemId>> requested_item_info(parallel::size());
      for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
        const auto& requested_item_vector = requested_item_vector_info[i_rank];
        requested_item_info[i_rank] = convert_to_array(requested_item_vector);
      }
      return requested_item_info;
    }();

    Array<unsigned int> local_number_of_requested_values(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      local_number_of_requested_values[i_rank] =
        requested_item_info[i_rank].size();
    }

    Array<unsigned int> local_number_of_values_to_send =
      parallel::allToAll(local_number_of_requested_values);

    std::vector<Array<const int>> requested_item_number_list_by_proc(
      parallel::size());
    const auto& item_number = connectivity.template number<item_type>();
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      Array<int> item_number_list{requested_item_info_from_rank.size()};
      parallel_for(requested_item_info_from_rank.size(),
                   PUGS_LAMBDA(size_t i_item) {
                     item_number_list[i_item] =
                       item_number[requested_item_info_from_rank[i_item]];
                   });
      requested_item_number_list_by_proc[i_rank] = item_number_list;
    }

    std::vector<Array<int>> provided_item_number_list_by_rank(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      provided_item_number_list_by_rank[i_rank] =
        Array<int>{local_number_of_values_to_send[i_rank]};
    }

    parallel::exchange(requested_item_number_list_by_proc,
                       provided_item_number_list_by_rank);

    std::map<int, ItemId> item_number_to_id_correspondance;
    for (ItemId item_id = 0; item_id < item_number.size(); ++item_id) {
      item_number_to_id_correspondance[item_number[item_id]] = item_id;
    }

    auto& provided_item_info = this->_getProvidedItemInfo<item_type>();
    provided_item_info       = [&]() {
      std::vector<Array<const ItemId>> provided_item_info(parallel::size());
      for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
        Array<ItemId> provided_item_id_to_rank{
          local_number_of_values_to_send[i_rank]};
        const Array<int>& provided_item_number_to_rank =
          provided_item_number_list_by_rank[i_rank];
        for (size_t i = 0; i < provided_item_number_to_rank.size(); ++i) {
          provided_item_id_to_rank[i] = item_number_to_id_correspondance
                                          .find(provided_item_number_to_rank[i])
                                          ->second;
        }
        provided_item_info[i_rank] = provided_item_id_to_rank;
      }
      return provided_item_info;
    }();
  }

  template <typename ConnectivityType,
            typename DataType,
            ItemType item_type,
            typename ConnectivityPtr>
  PUGS_INLINE void
  _synchronize(const ConnectivityType& connectivity,
               ItemValue<DataType, item_type, ConnectivityPtr>& item_value)
  {
    static_assert(not std::is_abstract_v<ConnectivityType>,
                  "_synchronize must be called on a concrete connectivity");

    using ItemId = ItemIdT<item_type>;

    const auto& provided_item_info  = this->_getProvidedItemInfo<item_type>();
    const auto& requested_item_info = this->_getRequestedItemInfo<item_type>();

    Assert(requested_item_info.size() == provided_item_info.size());

    if (provided_item_info.size() == 0) {
      this->_buildSynchronizeInfo<ConnectivityType, item_type>(connectivity);
    }

    std::vector<Array<const DataType>> provided_data_list(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const Array<const ItemId>& provided_item_info_to_rank =
        provided_item_info[i_rank];
      Array<DataType> provided_data{provided_item_info_to_rank.size()};
      parallel_for(provided_item_info_to_rank.size(), PUGS_LAMBDA(size_t i) {
        provided_data[i] = item_value[provided_item_info_to_rank[i]];
      });
      provided_data_list[i_rank] = provided_data;
    }

    std::vector<Array<DataType>> requested_data_list(parallel::size());
    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      requested_data_list[i_rank] =
        Array<DataType>{requested_item_info_from_rank.size()};
    }

    parallel::exchange(provided_data_list, requested_data_list);

    for (size_t i_rank = 0; i_rank < parallel::size(); ++i_rank) {
      const auto& requested_item_info_from_rank = requested_item_info[i_rank];
      const auto& requested_data                = requested_data_list[i_rank];
      parallel_for(requested_item_info_from_rank.size(), PUGS_LAMBDA(size_t i) {
        item_value[requested_item_info_from_rank[i]] = requested_data[i];
      });
    }
  }

 public:
  template <typename DataType, ItemType item_type, typename ConnectivityPtr>
  PUGS_INLINE void
  synchronize(ItemValue<DataType, item_type, ConnectivityPtr>& item_value)
  {
    Assert(item_value.connectivity_ptr().use_count() > 0,
           "No connectivity is associated to this ItemValue");
    const IConnectivity& connectivity = *item_value.connectivity_ptr();

    switch (connectivity.dimension()) {
      case 1: {
        this->_synchronize(static_cast<const Connectivity1D&>(connectivity),
                           item_value);
        break;
      }
      case 2: {
        this->_synchronize(static_cast<const Connectivity2D&>(connectivity),
                           item_value);
        break;
      }
      case 3: {
        this->_synchronize(static_cast<const Connectivity3D&>(connectivity),
                           item_value);
        break;
      }
      default: {
        perr() << __FILE__ << ':' << __LINE__ << ": unexpected dimension\n";
        std::terminate();
      }
    }
  }

  PUGS_INLINE
  Synchronizer()
  {
    ;
  }
};

#endif   // SYNCHRONIZER_HPP