#include <mesh/Connectivity.hpp>

#include <mesh/ConnectivityDescriptor.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <utils/Messenger.hpp>

#include <map>

// Defaulted constructor, defined out of class. It is explicitly instantiated
// for dimensions 1, 2 and 3 at the end of this file.
template <size_t Dimension>
Connectivity<Dimension>::Connectivity() = default;

// Fills this connectivity from the data stored in `descriptor`.
//
// The function sets, in this order: the item counts, the cell-to-node matrix,
// the per-cell values (type, number, global index, owner, is-owned), the
// per-node values (number, owner, is-owned), the node and cell reference
// lists, and finally the face and edge data, whose source depends on the
// dimension:
//  - Dimension == 1: face and edge values are built on the node arrays and
//    their reference lists are converted from the node reference lists;
//  - Dimension == 2: face data come from the descriptor, edge values are built
//    on the face arrays and edge reference lists are converted from the face
//    reference lists;
//  - Dimension == 3: face and edge data both come from the descriptor.
//
// The order of the statements matters: item values are constructed on `*this`
// after the item counts are set, and the owner values are assigned before the
// loops that read them.
template <size_t Dimension>
void
Connectivity<Dimension>::_buildFrom(const ConnectivityDescriptor& descriptor)
{
  // Consistency checks on the descriptor: one entry per cell in each per-cell
  // vector and, when faces are described, one entry per face in each per-face
  // vector
  Assert(descriptor.cell_to_node_vector.size() == descriptor.cell_type_vector.size());
  Assert(descriptor.cell_number_vector.size() == descriptor.cell_type_vector.size());
  if constexpr (Dimension > 1) {
    Assert(descriptor.cell_to_face_vector.size() == descriptor.cell_type_vector.size());
    Assert(descriptor.face_to_node_vector.size() == descriptor.face_number_vector.size());
    Assert(descriptor.face_owner_vector.size() == descriptor.face_number_vector.size());
  }

  m_number_of_cells = descriptor.cell_number_vector.size();
  m_number_of_nodes = descriptor.node_number_vector.size();

  // Edge and face counts: equal to the node count in 1D; in 2D the edge count
  // equals the face count; in 3D each count is read from the descriptor
  if constexpr (Dimension == 1) {
    m_number_of_edges = m_number_of_nodes;
    m_number_of_faces = m_number_of_nodes;
  } else {
    m_number_of_faces = descriptor.face_number_vector.size();
    if constexpr (Dimension == 2) {
      m_number_of_edges = m_number_of_faces;
    } else {
      static_assert(Dimension == 3, "unexpected dimension");
      m_number_of_edges = descriptor.edge_number_vector.size();
    }
  }

  auto& cell_to_node_matrix = m_item_to_item_matrix[itemTId(ItemType::cell)][itemTId(ItemType::node)];
  cell_to_node_matrix       = descriptor.cell_to_node_vector;

  // Cell types are copied entry by entry from the descriptor
  {
    WeakCellValue<CellType> cell_type(*this);
    parallel_for(
      this->numberOfCells(), PUGS_LAMBDA(CellId j) { cell_type[j] = descriptor.cell_type_vector[j]; });
    m_cell_type = cell_type;
  }

  m_cell_number = WeakCellValue<int>(*this, convert_to_array(descriptor.cell_number_vector));

  // Kept in a named variable: also used below for faces and edges when
  // Dimension == 1
  Array node_number_array = convert_to_array(descriptor.node_number_vector);
  m_node_number           = WeakNodeValue<int>(*this, node_number_array);

  // Cell global index: the cell's own index shifted by first_index, which is 0
  // here
  {
    WeakCellValue<int> cell_global_index(*this);
    int first_index = 0;
    parallel_for(
      this->numberOfCells(), PUGS_LAMBDA(CellId j) { cell_global_index[j] = first_index + j; });
    m_cell_global_index = cell_global_index;
  }

  m_cell_owner = WeakCellValue<int>(*this, convert_to_array(descriptor.cell_owner_vector));

  // A cell is flagged as owned when its owner equals parallel::rank()
  {
    const int rank = parallel::rank();
    WeakCellValue<bool> cell_is_owned(*this);
    parallel_for(
      this->numberOfCells(), PUGS_LAMBDA(CellId j) { cell_is_owned[j] = (m_cell_owner[j] == rank); });
    m_cell_is_owned = cell_is_owned;
  }

  // Kept in a named variable: also used below for faces and edges when
  // Dimension == 1
  Array node_owner_array = convert_to_array(descriptor.node_owner_vector);
  m_node_owner           = WeakNodeValue<int>{*this, node_owner_array};

  // Same ownership rule for nodes; the array is declared outside of the scope
  // below because it is also used for faces and edges when Dimension == 1
  Array<bool> node_is_owned_array(this->numberOfNodes());
  {
    const int rank = parallel::rank();
    WeakNodeValue<bool> node_is_owned(*this, node_is_owned_array);
    parallel_for(
      this->numberOfNodes(), PUGS_LAMBDA(NodeId r) { node_is_owned[r] = (m_node_owner[r] == rank); });
    m_node_is_owned = node_is_owned;
  }

  m_ref_node_list_vector = descriptor.template refItemListVector<ItemType::node>();
  m_ref_cell_list_vector = descriptor.template refItemListVector<ItemType::cell>();

  if constexpr (Dimension == 1) {
    // faces are similar to nodes
    m_face_number   = WeakFaceValue<int>(*this, node_number_array);
    m_face_owner    = WeakFaceValue<int>(*this, node_owner_array);
    m_face_is_owned = WeakFaceValue<bool>(*this, node_is_owned_array);

    // edges are similar to nodes
    m_edge_number   = WeakEdgeValue<int>(*this, node_number_array);
    m_edge_owner    = WeakEdgeValue<int>(*this, node_owner_array);
    m_edge_is_owned = WeakEdgeValue<bool>(*this, node_is_owned_array);

    // edge and face references are set equal to node references
    m_ref_edge_list_vector.reserve(descriptor.template refItemListVector<ItemType::node>().size());
    m_ref_face_list_vector.reserve(descriptor.template refItemListVector<ItemType::node>().size());
    for (const auto& ref_node_list : descriptor.template refItemListVector<ItemType::node>()) {
      const RefId ref_id            = ref_node_list.refId();
      Array<const NodeId> node_list = ref_node_list.list();
      Array<EdgeId> edge_list(node_list.size());
      Array<FaceId> face_list(node_list.size());
      // Each node id is converted to an edge id and a face id through the
      // ids' base_type
      for (size_t i = 0; i < node_list.size(); ++i) {
        edge_list[i] = EdgeId::base_type{node_list[i]};
        face_list[i] = FaceId::base_type{node_list[i]};
      }

      // The new lists keep the reference id and the boundary flag of the node
      // list they are built from
      m_ref_edge_list_vector.emplace_back(RefItemList<ItemType::edge>(ref_id, edge_list, ref_node_list.isBoundary()));
      m_ref_face_list_vector.emplace_back(RefItemList<ItemType::face>(ref_id, face_list, ref_node_list.isBoundary()));
    }

  } else {
    m_item_to_item_matrix[itemTId(ItemType::face)][itemTId(ItemType::node)] = descriptor.face_to_node_vector;

    m_item_to_item_matrix[itemTId(ItemType::cell)][itemTId(ItemType::face)] = descriptor.cell_to_face_vector;

    // Per-cell, per-local-face "is reversed" flags, copied from the descriptor
    {
      FaceValuePerCell<bool> cell_face_is_reversed(*this);
      for (CellId j = 0; j < descriptor.cell_face_is_reversed_vector.size(); ++j) {
        const auto& face_cells_vector = descriptor.cell_face_is_reversed_vector[j];
        for (unsigned short lj = 0; lj < face_cells_vector.size(); ++lj) {
          cell_face_is_reversed(j, lj) = face_cells_vector[lj];
        }
      }
      m_cell_face_is_reversed = cell_face_is_reversed;
    }

    // The face arrays are kept in named variables: they are also used for
    // edges when Dimension == 2
    Array face_number_array = convert_to_array(descriptor.face_number_vector);
    m_face_number           = WeakFaceValue<int>(*this, face_number_array);

    Array face_owner_array = convert_to_array(descriptor.face_owner_vector);
    m_face_owner           = WeakFaceValue<int>(*this, face_owner_array);

    // A face is flagged as owned when its owner equals parallel::rank()
    Array<bool> face_is_owned_array(this->numberOfFaces());
    {
      const int rank = parallel::rank();
      WeakFaceValue<bool> face_is_owned(*this, face_is_owned_array);
      parallel_for(
        this->numberOfFaces(), PUGS_LAMBDA(FaceId l) { face_is_owned[l] = (m_face_owner[l] == rank); });
      m_face_is_owned = face_is_owned;
    }

    m_ref_face_list_vector = descriptor.template refItemListVector<ItemType::face>();

    if constexpr (Dimension == 2) {
      // edges are similar to faces
      m_edge_number   = WeakEdgeValue<int>(*this, face_number_array);
      m_edge_owner    = WeakEdgeValue<int>(*this, face_owner_array);
      m_edge_is_owned = WeakEdgeValue<bool>(*this, face_is_owned_array);

      // edge references are set equal to face references
      m_ref_edge_list_vector.reserve(descriptor.template refItemListVector<ItemType::face>().size());
      for (const auto& ref_face_list : descriptor.template refItemListVector<ItemType::face>()) {
        const RefId ref_id            = ref_face_list.refId();
        Array<const FaceId> face_list = ref_face_list.list();
        Array<EdgeId> edge_list(face_list.size());
        // Each face id is converted to an edge id through EdgeId::base_type
        for (size_t i = 0; i < face_list.size(); ++i) {
          edge_list[i] = EdgeId::base_type{face_list[i]};
        }

        m_ref_edge_list_vector.emplace_back(RefItemList<ItemType::edge>(ref_id, edge_list, ref_face_list.isBoundary()));
      }

    } else {
      // Remaining case (Dimension == 3, see the static_assert above): edge
      // data are read from the descriptor
      m_item_to_item_matrix[itemTId(ItemType::edge)][itemTId(ItemType::node)] = descriptor.edge_to_node_vector;

      m_item_to_item_matrix[itemTId(ItemType::face)][itemTId(ItemType::edge)] = descriptor.face_to_edge_vector;

      m_item_to_item_matrix[itemTId(ItemType::cell)][itemTId(ItemType::edge)] = descriptor.cell_to_edge_vector;

      // Per-face, per-local-edge "is reversed" flags, copied from the
      // descriptor
      {
        EdgeValuePerFace<bool> face_edge_is_reversed(*this);
        for (FaceId l = 0; l < descriptor.face_edge_is_reversed_vector.size(); ++l) {
          const auto& edge_faces_vector = descriptor.face_edge_is_reversed_vector[l];
          for (unsigned short el = 0; el < edge_faces_vector.size(); ++el) {
            face_edge_is_reversed(l, el) = edge_faces_vector[el];
          }
        }
        m_face_edge_is_reversed = face_edge_is_reversed;
      }

      m_edge_number = WeakEdgeValue<int>(*this, convert_to_array(descriptor.edge_number_vector));
      m_edge_owner  = WeakEdgeValue<int>(*this, convert_to_array(descriptor.edge_owner_vector));

      // An edge is flagged as owned when its owner equals parallel::rank()
      {
        const int rank = parallel::rank();
        WeakEdgeValue<bool> edge_is_owned(*this);
        parallel_for(
          this->numberOfEdges(), PUGS_LAMBDA(EdgeId e) { edge_is_owned[e] = (m_edge_owner[e] == rank); });
        m_edge_is_owned = edge_is_owned;
      }

      m_ref_edge_list_vector = descriptor.template refItemListVector<ItemType::edge>();
    }
  }
}

// Computes the "is boundary" flag of every face and stores it in
// m_is_boundary_face.
//
// A face is flagged when faceIsOwned() is true for it and it is connected to
// exactly one cell; the flags are then passed to synchronize(). When
// Dimension <= 2 the same flag array is also used to set m_is_boundary_edge,
// and when Dimension == 1 to set m_is_boundary_node as well.
//
// The method is const: the cached members are written through const_cast.
template <size_t Dimension>
void
Connectivity<Dimension>::_buildIsBoundaryFace() const
{
  Array<bool> boundary_flag_array(this->numberOfFaces());
  WeakFaceValue<bool> boundary_flag(*this, boundary_flag_array);

  const auto& cells_of_face = this->faceToCellMatrix();
  const auto& owned_face    = this->faceIsOwned();

  parallel_for(
    this->numberOfFaces(), PUGS_LAMBDA(const FaceId face_id) {
      if (owned_face[face_id]) {
        // an owned face is on the boundary when a single cell is attached to it
        boundary_flag[face_id] = (cells_of_face[face_id].size() == 1);
      } else {
        boundary_flag[face_id] = false;
      }
    });
  synchronize(boundary_flag);

  auto& face_flag_member = const_cast<WeakFaceValue<const bool>&>(m_is_boundary_face);
  face_flag_member       = boundary_flag;

  if constexpr (Dimension <= 2) {
    static_assert((Dimension == 1) or (Dimension == 2), "unexpected dimension");

    // edges reuse the face flags
    auto& edge_flag_member = const_cast<WeakEdgeValue<const bool>&>(m_is_boundary_edge);
    edge_flag_member       = WeakEdgeValue<bool>(*this, boundary_flag_array);

    if constexpr (Dimension == 1) {
      // nodes reuse the face flags too
      auto& node_flag_member = const_cast<WeakNodeValue<const bool>&>(m_is_boundary_node);
      node_flag_member       = WeakNodeValue<bool>(*this, boundary_flag_array);
    }
  }
}

// Computes the "is boundary" flag of every edge and stores it in
// m_is_boundary_edge.
//
// When Dimension < 3 the work is delegated to _buildIsBoundaryFace(), which
// also sets the edge flags. Otherwise every edge of a face flagged by
// isBoundaryFace() is marked, and the result is passed to synchronize().
//
// The method is const: the cached member is written through const_cast.
template <size_t Dimension>
void
Connectivity<Dimension>::_buildIsBoundaryEdge() const
{
  if constexpr (Dimension >= 3) {
    auto face_on_boundary = this->isBoundaryFace();

    WeakEdgeValue<bool> edge_on_boundary(*this);
    edge_on_boundary.fill(false);

    const auto& edges_of_face = this->faceToEdgeMatrix();
    for (FaceId face_id = 0; face_id < this->numberOfFaces(); ++face_id) {
      if (not face_on_boundary[face_id]) {
        continue;
      }

      auto edge_list = edges_of_face[face_id];
      for (size_t i = 0; i < edge_list.size(); ++i) {
        edge_on_boundary[edge_list[i]] = true;
      }
    }
    synchronize(edge_on_boundary);

    auto& edge_flag_member = const_cast<WeakEdgeValue<const bool>&>(m_is_boundary_edge);
    edge_flag_member       = edge_on_boundary;
  } else {
    this->_buildIsBoundaryFace();
  }
}

// Computes the "is boundary" flag of every node and stores it in
// m_is_boundary_node.
//
// When Dimension == 1 the work is delegated to _buildIsBoundaryFace(), which
// also sets the node flags. Otherwise every node of a face flagged by
// isBoundaryFace() is marked, and the result is passed to synchronize().
//
// The method is const: the cached member is written through const_cast.
template <size_t Dimension>
void
Connectivity<Dimension>::_buildIsBoundaryNode() const
{
  if constexpr (Dimension != 1) {
    auto face_on_boundary = this->isBoundaryFace();

    WeakNodeValue<bool> node_on_boundary(*this);
    node_on_boundary.fill(false);

    const auto& nodes_of_face = this->faceToNodeMatrix();
    for (FaceId face_id = 0; face_id < this->numberOfFaces(); ++face_id) {
      if (not face_on_boundary[face_id]) {
        continue;
      }

      auto node_list = nodes_of_face[face_id];
      for (size_t i = 0; i < node_list.size(); ++i) {
        node_on_boundary[node_list[i]] = true;
      }
    }
    synchronize(node_on_boundary);

    auto& node_flag_member = const_cast<WeakNodeValue<const bool>&>(m_is_boundary_node);
    node_flag_member       = node_on_boundary;
  } else {
    this->_buildIsBoundaryFace();
  }
}

// Writes to `os` the summary of the items of type `item_type`: the number of
// owned items (reduced with parallel::allReduceSum) followed by the reference
// lists attached to this item type.
//
// A reference list is printed only if its tag name is not yet in
// `already_printed`; the names printed here are added to that set, so that a
// later call for another item type skips them.
template <ItemType item_type, size_t Dimension>
inline void
_printReference(std::ostream& os, const Connectivity<Dimension>& connectivity, std::set<std::string>& already_printed)
{
  // Sum of the "is owned" flags over all the items, then reduced
  auto count_owned_items = [](const auto& is_owned) -> size_t {
    using ItemId       = typename std::decay_t<decltype(is_owned)>::index_type;
    size_t local_count = 0;
    for (ItemId id = 0; id < is_owned.numberOfItems(); ++id) {
      local_count += is_owned[id];
    }
    return parallel::allReduceSum(local_count);
  };

  // Same sum, restricted to the items of a given list
  auto count_owned_items_in_list = [](const auto& is_owned, const auto& id_list) -> size_t {
    size_t local_count = 0;
    for (size_t i = 0; i < id_list.size(); ++i) {
      local_count += is_owned[id_list[i]];
    }
    return parallel::allReduceSum(local_count);
  };

  os << "- number of " << itemName(item_type) << "s: ";
  os << rang::fgB::yellow << count_owned_items(connectivity.template isOwned<item_type>());
  os << rang::style::reset << '\n';

  // This is done to avoid printing deduced references on subitems: only the
  // lists whose tag name is inserted for the first time are selected
  std::vector<size_t> selected_ref_indices;
  for (size_t i_ref = 0; i_ref < connectivity.template numberOfRefItemList<item_type>(); ++i_ref) {
    auto ref_list           = connectivity.template refItemList<item_type>(i_ref);
    const bool is_new_entry = already_printed.insert(ref_list.refId().tagName()).second;
    if (is_new_entry) {
      selected_ref_indices.push_back(i_ref);
    }
  }

  os << "  " << rang::fgB::yellow << selected_ref_indices.size() << rang::style::reset << " references\n";
  for (const size_t i_ref : selected_ref_indices) {
    auto ref_list = connectivity.template refItemList<item_type>(i_ref);

    os << "  - " << rang::fgB::green << ref_list.refId().tagName() << rang::style::reset;
    os << " (" << rang::fgB::green << ref_list.refId().tagNumber() << rang::style::reset << ") number ";
    os << rang::fgB::yellow << count_owned_items_in_list(connectivity.template isOwned<item_type>(), ref_list.list());
    os << rang::style::reset << '\n';
  }
}

// Writes a description of this connectivity to `os` and returns `os`.
//
// After a header line giving the dimension, the items are reported through
// _printReference in the order: cells, faces (Dimension > 1 only), edges
// (Dimension > 2 only), nodes. One set of tag names is shared by all the calls
// so that a reference name is reported once only.
template <size_t Dimension>
std::ostream&
Connectivity<Dimension>::_write(std::ostream& os) const
{
  os << "connectivity of dimension " << Dimension << '\n';

  std::set<std::string> printed_tag_names;

  _printReference<ItemType::cell>(os, *this, printed_tag_names);
  if constexpr (Dimension > 1) {
    _printReference<ItemType::face>(os, *this, printed_tag_names);
    if constexpr (Dimension > 2) {
      _printReference<ItemType::edge>(os, *this, printed_tag_names);
    }
  }
  _printReference<ItemType::node>(os, *this, printed_tag_names);

  return os;
}

// Explicit instantiations of the member functions defined in this file, for
// dimensions 1, 2 and 3.

// construction
template Connectivity<1>::Connectivity();
template Connectivity<2>::Connectivity();
template Connectivity<3>::Connectivity();

template void Connectivity<1>::_buildFrom(const ConnectivityDescriptor&);
template void Connectivity<2>::_buildFrom(const ConnectivityDescriptor&);
template void Connectivity<3>::_buildFrom(const ConnectivityDescriptor&);

// boundary flags
template void Connectivity<1>::_buildIsBoundaryFace() const;
template void Connectivity<2>::_buildIsBoundaryFace() const;
template void Connectivity<3>::_buildIsBoundaryFace() const;

template void Connectivity<1>::_buildIsBoundaryEdge() const;
template void Connectivity<2>::_buildIsBoundaryEdge() const;
template void Connectivity<3>::_buildIsBoundaryEdge() const;

template void Connectivity<1>::_buildIsBoundaryNode() const;
template void Connectivity<2>::_buildIsBoundaryNode() const;
template void Connectivity<3>::_buildIsBoundaryNode() const;

// output
template std::ostream& Connectivity<1>::_write(std::ostream&) const;
template std::ostream& Connectivity<2>::_write(std::ostream&) const;
template std::ostream& Connectivity<3>::_write(std::ostream&) const;
