#include <Connectivity.hpp>

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <map>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

/**
 * Builds every face-related connectivity of a 3D mesh from its cell->node
 * connectivity.
 *
 * Supported cells are tetrahedra (4 nodes, 4 triangular faces) and hexahedra
 * (8 nodes, 6 quadrangular faces). Any other number of nodes per cell prints an
 * error and terminates the program with a failure status.
 *
 * Faces are collected as keys of an ordered std::map. The iteration order of
 * that map defines the face numbering, which is used consistently by every
 * structure built here.
 *
 * Computes and stores:
 *  - the cell->face and face->node connectivity matrices,
 *  - m_cell_face_is_reversed: for each (cell, local face), the flag returned by
 *    Face::reversed() when the face was built from the cell's nodes,
 *  - m_face_number_map: Face -> face number,
 *  - the face->cell matrix (inverse of cell->face) and m_face_to_cell_local_face,
 *  - m_node_nb_faces / m_node_faces: number and list of faces containing each node.
 */
template<>
void Connectivity<3>::_computeFaceCellConnectivities()
{
  Kokkos::View<unsigned short*> cell_nb_faces("cell_nb_faces", this->numberOfCells());

  // (cell number, local face number in that cell, Face::reversed() flag)
  typedef std::tuple<unsigned int, unsigned short, bool> CellFaceInfo;

  const auto& cell_to_node_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];

  // For every distinct face, the list of cells that share it
  std::map<Face, std::vector<CellFaceInfo>> face_cells_map;
  for (unsigned int j=0; j<this->numberOfCells(); ++j) {
    const auto& cell_nodes = cell_to_node_matrix.rowConst(j);

    switch (cell_nodes.length) {
      case 4: { // tetrahedron
        cell_nb_faces[j] = 4;
        // face 0
        Face f0({cell_nodes(1),
                 cell_nodes(2),
                 cell_nodes(3)});
        face_cells_map[f0].emplace_back(std::make_tuple(j, 0, f0.reversed()));

        // face 1
        Face f1({cell_nodes(0),
                 cell_nodes(3),
                 cell_nodes(2)});
        face_cells_map[f1].emplace_back(std::make_tuple(j, 1, f1.reversed()));

        // face 2
        Face f2({cell_nodes(0),
                 cell_nodes(1),
                 cell_nodes(3)});
        face_cells_map[f2].emplace_back(std::make_tuple(j, 2, f2.reversed()));

        // face 3
        Face f3({cell_nodes(0),
                 cell_nodes(2),
                 cell_nodes(1)});
        face_cells_map[f3].emplace_back(std::make_tuple(j, 3, f3.reversed()));
        break;
      }
      case 8: { // hexahedron
        // face 0
        Face f0({cell_nodes(3),
                 cell_nodes(2),
                 cell_nodes(1),
                 cell_nodes(0)});
        face_cells_map[f0].emplace_back(std::make_tuple(j, 0, f0.reversed()));

        // face 1
        Face f1({cell_nodes(4),
                 cell_nodes(5),
                 cell_nodes(6),
                 cell_nodes(7)});
        face_cells_map[f1].emplace_back(std::make_tuple(j, 1, f1.reversed()));

        // face 2
        Face f2({cell_nodes(0),
                 cell_nodes(4),
                 cell_nodes(7),
                 cell_nodes(3)});
        face_cells_map[f2].emplace_back(std::make_tuple(j, 2, f2.reversed()));

        // face 3
        Face f3({cell_nodes(1),
                 cell_nodes(2),
                 cell_nodes(6),
                 cell_nodes(5)});
        face_cells_map[f3].emplace_back(std::make_tuple(j, 3, f3.reversed()));

        // face 4
        Face f4({cell_nodes(0),
                 cell_nodes(1),
                 cell_nodes(5),
                 cell_nodes(4)});
        face_cells_map[f4].emplace_back(std::make_tuple(j, 4, f4.reversed()));

        // face 5
        Face f5({cell_nodes(3),
                 cell_nodes(7),
                 cell_nodes(6),
                 cell_nodes(2)});
        face_cells_map[f5].emplace_back(std::make_tuple(j, 5, f5.reversed()));

        cell_nb_faces[j] = 6;
        break;
      }
      default: {
        std::cerr << "unexpected cell type!\n";
        // The mesh cannot be processed: report failure to the caller/shell
        // (the previous exit status 0 signalled success on a fatal error).
        std::exit(EXIT_FAILURE);
      }
    }
  }

  // cell -> face connectivity
  {
    auto& cell_to_face_matrix
        = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::face)];
    std::vector<std::vector<unsigned int>> cell_to_face_vector(this->numberOfCells());
    for (size_t j=0; j<cell_to_face_vector.size(); ++j) {
      cell_to_face_vector[j].resize(cell_nb_faces[j]);
    }
    unsigned int l=0;
    for (const auto& [face, cells_vector] : face_cells_map) {
      for (const auto& [cell_number, cell_local_face, reversed] : cells_vector) {
        cell_to_face_vector[cell_number][cell_local_face] = l;
      }
      ++l;
    }
    cell_to_face_matrix = cell_to_face_vector;
  }

  // Per-cell face orientation flags. This must be built after the cell->face
  // matrix, since the FaceValuePerCell is sized from this connectivity.
  FaceValuePerCell<bool> cell_face_is_reversed(*this);
  {
    for (const auto& [face, cells_vector] : face_cells_map) {
      for (const auto& [cell_number, cell_local_face, reversed] : cells_vector) {
        cell_face_is_reversed(cell_number, cell_local_face) = reversed;
      }
    }

    m_cell_face_is_reversed = cell_face_is_reversed;
  }

  // face -> node connectivity and Face -> face number map
  {
    auto& face_to_node_matrix
        = m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];

    std::vector<std::vector<unsigned int>> face_to_node_vector(face_cells_map.size());
    unsigned int l=0;
    for (const auto& [face, cells_vector] : face_cells_map) {
      face_to_node_vector[l] = face.nodeIdList();
      m_face_number_map[face] = l;
      ++l;
    }
    face_to_node_matrix = face_to_node_vector;
  }

  // face -> cell connectivity, obtained by inverting cell -> face
  const auto& cell_to_face_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::face)];
  auto& face_to_cell_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::cell)];
  m_connectivity_computer.computeInverseConnectivityMatrix(cell_to_face_matrix,
                                                           face_to_cell_matrix);

  m_face_to_cell_local_face
      = m_connectivity_computer.computeLocalItemNumberInChildItem<TypeOfItem::face,
                                                                  TypeOfItem::cell>(*this);

  const auto& face_to_node_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];

#warning check that the number of cell per faces is <=2
  // node -> faces containing that node
  std::unordered_map<unsigned int, std::vector<unsigned int>> node_faces_map;
  for (size_t l=0; l<face_to_node_matrix.numRows(); ++l) {
    const auto& face_nodes = face_to_node_matrix.rowConst(l);
    for (size_t lr=0; lr<face_nodes.length; ++lr) {
      const unsigned int r = face_nodes(lr);
      node_faces_map[r].emplace_back(l);
    }
  }
  Kokkos::View<unsigned short*> node_nb_faces("node_nb_faces", this->numberOfNodes());
  size_t max_nb_face_per_node = 0;
  // Iterate by reference: iterating by value would copy each face vector
  for (const auto& [r, faces_vector] : node_faces_map) {
    max_nb_face_per_node = std::max(faces_vector.size(), max_nb_face_per_node);
    node_nb_faces[r] = faces_vector.size();
  }
  m_node_nb_faces = node_nb_faces;

  Kokkos::View<unsigned int**> node_faces("node_faces", this->numberOfNodes(), max_nb_face_per_node);
  for (const auto& [r, faces_vector] : node_faces_map) {
    for (size_t l=0; l < faces_vector.size(); ++l) {
      node_faces(r, l) = faces_vector[l];
    }
  }
  m_node_faces = node_faces;
}

/**
 * Builds the face-related connectivities of a 2D mesh from its cell->node
 * connectivity.
 *
 * In 2D a face is the edge joining two consecutive nodes of a cell; the last
 * node is joined to the first one. Each edge is keyed by its node pair stored
 * in increasing node-id order, so an edge shared by two cells is registered
 * once.
 *
 * Faces are numbered following the iteration order of the ordered std::map used
 * to collect them. The function fills:
 *  - m_face_number_map: Face -> face number,
 *  - the face->node matrix (two nodes per face),
 *  - the face->cell matrix (numbers of the cells sharing each face).
 */
template<>
void Connectivity<2>::_computeFaceCellConnectivities()
{
  const auto& cell_to_node_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];

  // In 2D faces are simply defined by pairs of consecutive cell nodes.
  // CellFaceId = (cell number, local face number in that cell)
  typedef std::pair<unsigned int, unsigned short> CellFaceId;
  std::map<Face, std::vector<CellFaceId>> face_cells_map;
  for (unsigned int j=0; j<this->numberOfCells(); ++j) {
    const auto& cell_nodes = cell_to_node_matrix.rowConst(j);
    for (unsigned short r=0; r<cell_nodes.length; ++r) {
      unsigned int node0_id = cell_nodes(r);
      unsigned int node1_id = cell_nodes((r+1)%cell_nodes.length);
      if (node1_id<node0_id) {
        std::swap(node0_id, node1_id);
      }
      face_cells_map[Face({node0_id, node1_id})].push_back(std::make_pair(j, r));
    }
  }

  // A single pass over the ordered map numbers the faces and builds both
  // face->node and face->cell connectivities with the same numbering.
  std::vector<std::vector<unsigned int>> face_to_node_vector(face_cells_map.size());
  std::vector<std::vector<unsigned int>> face_to_cell_vector(face_cells_map.size());
  unsigned int l=0;
  for (const auto& [face, cell_info_vector] : face_cells_map) {
    m_face_number_map[face] = l;
    face_to_node_vector[l] = {face.m_node0_id, face.m_node1_id};
    for (const auto& cell_info : cell_info_vector) {
      // Store the cell number (first), not the local face number (second)
      face_to_cell_vector[l].push_back(cell_info.first);
    }
    ++l;
  }

  auto& face_to_node_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];
  face_to_node_matrix = face_to_node_vector;

  auto& face_to_cell_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::cell)];
  face_to_cell_matrix = face_to_cell_vector;
}


/**
 * Builds a connectivity from the list of nodes of each cell.
 *
 * @param cell_by_node_vector for each cell, the numbers of its nodes
 *
 * After storing the cell->node matrix (which must describe at least one cell),
 * the constructor computes:
 *  - m_inv_cell_nb_nodes: 1 / (number of nodes) for each cell,
 *  - the node->cell matrix, as the inverse of cell->node,
 *  - the local numbering tables m_node_to_cell_local_node and
 *    m_cell_to_node_local_cell.
 * When Dimension > 1, the face connectivities are computed as well.
 */
template<size_t Dimension>
Connectivity<Dimension>::
Connectivity(const std::vector<std::vector<unsigned int>>& cell_by_node_vector)
{
  auto& cell_node_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];
  cell_node_matrix = cell_by_node_vector;

  Assert(this->numberOfCells()>0);

  // Inverse of the number of nodes of every cell
  {
    Kokkos::View<double*> inverse_nb_nodes("inv_cell_nb_nodes", this->numberOfCells());
    Kokkos::parallel_for(this->numberOfCells(), KOKKOS_LAMBDA(const int& cell_id){
        inverse_nb_nodes[cell_id] = 1./cell_node_matrix.rowConst(cell_id).length;
      });
    m_inv_cell_nb_nodes = inverse_nb_nodes;
  }

  // node -> cell connectivity and local numberings between nodes and cells
  auto& node_cell_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::node)][itemId(TypeOfItem::cell)];
  m_connectivity_computer.computeInverseConnectivityMatrix(cell_node_matrix,
                                                           node_cell_matrix);

  m_node_to_cell_local_node
      = m_connectivity_computer
        .computeLocalItemNumberInChildItem<TypeOfItem::node, TypeOfItem::cell>(*this);
  m_cell_to_node_local_cell
      = m_connectivity_computer
        .computeLocalItemNumberInChildItem<TypeOfItem::cell, TypeOfItem::node>(*this);

  // Face connectivities are only computed for Dimension > 1
  if constexpr (Dimension>1) {
    this->_computeFaceCellConnectivities();
  }
}


// Explicit instantiations of the constructor for Connectivity1D, Connectivity2D
// and Connectivity3D, so that its definition can stay in this source file.
template Connectivity1D::Connectivity(const std::vector<std::vector<unsigned int>>&);
template Connectivity2D::Connectivity(const std::vector<std::vector<unsigned int>>&);
template Connectivity3D::Connectivity(const std::vector<std::vector<unsigned int>>&);