#include <Connectivity.hpp>

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <map>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

/**
 * Builds the face-based connectivities of a 3D mesh from its cell-to-node
 * connectivity.
 *
 * The faces of each cell are enumerated according to its number of nodes:
 * 4 nodes (tetrahedron) give 4 triangular faces, and 8 nodes (hexahedron)
 * give 6 quadrangular faces. Identical faces are merged through an ordered
 * std::map keyed on Face. Faces are then numbered in the map's iteration
 * order. NOTE(review): this assumes Face compares equal for the same face
 * seen from two neighbouring cells, with the orientation kept in
 * Face::reversed(). Confirm against Face.
 *
 * Fills:
 *  - m_cell_to_face_matrix     : face numbers of each cell, by local face index
 *  - m_cell_face_is_reversed   : Face::reversed() for each (cell, local face)
 *  - m_face_to_node_matrix     : Face::nodeIdList() of each face
 *  - m_face_number_map         : Face -> face number
 *  - m_face_to_cell_matrix     : inverse of m_cell_to_face_matrix
 *  - m_face_to_cell_local_face : local index of each face in its cells
 *  - m_node_nb_faces / m_node_faces : faces touching each node
 *
 * A cell that is neither a tetrahedron nor a hexahedron is reported on
 * std::cerr, and the program terminates with EXIT_FAILURE.
 */
template<>
void Connectivity<3>::_computeFaceCellConnectivities()
{
  Kokkos::View<unsigned short*> cell_nb_faces("cell_nb_faces", this->numberOfCells());

  // (cell number, local face index in that cell, Face::reversed() flag)
  typedef std::tuple<unsigned int, unsigned short, bool> CellFaceInfo;

  const auto& cell_to_node_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];

  // Every distinct face, together with the list of cells containing it.
  std::map<Face, std::vector<CellFaceInfo>> face_cells_map;
  for (unsigned int j=0; j<this->numberOfCells(); ++j) {
    const auto& cell_nodes = cell_to_node_matrix.rowConst(j);

    switch (cell_nodes.length) {
      case 4: { // tetrahedron
        cell_nb_faces[j] = 4;
        // face 0
        Face f0({cell_nodes(1),
                 cell_nodes(2),
                 cell_nodes(3)});
        face_cells_map[f0].emplace_back(std::make_tuple(j, 0, f0.reversed()));

        // face 1
        Face f1({cell_nodes(0),
                 cell_nodes(3),
                 cell_nodes(2)});
        face_cells_map[f1].emplace_back(std::make_tuple(j, 1, f1.reversed()));

        // face 2
        Face f2({cell_nodes(0),
                 cell_nodes(1),
                 cell_nodes(3)});
        face_cells_map[f2].emplace_back(std::make_tuple(j, 2, f2.reversed()));

        // face 3
        Face f3({cell_nodes(0),
                 cell_nodes(2),
                 cell_nodes(1)});
        face_cells_map[f3].emplace_back(std::make_tuple(j, 3, f3.reversed()));
        break;
      }
      case 8: { // hexahedron
        cell_nb_faces[j] = 6;
        // face 0
        Face f0({cell_nodes(3),
                 cell_nodes(2),
                 cell_nodes(1),
                 cell_nodes(0)});
        face_cells_map[f0].emplace_back(std::make_tuple(j, 0, f0.reversed()));

        // face 1
        Face f1({cell_nodes(4),
                 cell_nodes(5),
                 cell_nodes(6),
                 cell_nodes(7)});
        face_cells_map[f1].emplace_back(std::make_tuple(j, 1, f1.reversed()));

        // face 2
        Face f2({cell_nodes(0),
                 cell_nodes(4),
                 cell_nodes(7),
                 cell_nodes(3)});
        face_cells_map[f2].emplace_back(std::make_tuple(j, 2, f2.reversed()));

        // face 3
        Face f3({cell_nodes(1),
                 cell_nodes(2),
                 cell_nodes(6),
                 cell_nodes(5)});
        face_cells_map[f3].emplace_back(std::make_tuple(j, 3, f3.reversed()));

        // face 4
        Face f4({cell_nodes(0),
                 cell_nodes(1),
                 cell_nodes(5),
                 cell_nodes(4)});
        face_cells_map[f4].emplace_back(std::make_tuple(j, 4, f4.reversed()));

        // face 5
        Face f5({cell_nodes(3),
                 cell_nodes(7),
                 cell_nodes(6),
                 cell_nodes(2)});
        face_cells_map[f5].emplace_back(std::make_tuple(j, 5, f5.reversed()));
        break;
      }
      default: {
        std::cerr << "unexpected cell type!\n";
        // This is a fatal error: exiting with status 0 would report success.
        std::exit(EXIT_FAILURE);
      }
    }
  }

  // Cell -> face: each face number is its rank in face_cells_map.
  {
    std::vector<std::vector<unsigned int>> cell_to_face_vector(this->numberOfCells());
    for (size_t j=0; j<cell_to_face_vector.size(); ++j) {
      cell_to_face_vector[j].resize(cell_nb_faces[j]);
    }
    unsigned int l=0;
    for (const auto& face_cells : face_cells_map) {
      for (const auto& cell_face_info : face_cells.second) {
        const unsigned int cell_number = std::get<0>(cell_face_info);
        const unsigned short cell_local_face = std::get<1>(cell_face_info);
        cell_to_face_vector[cell_number][cell_local_face] = l;
      }
      ++l;
    }
    m_cell_to_face_matrix = cell_to_face_vector;
  }

  // Built only after m_cell_to_face_matrix is set.
  // NOTE(review): presumably FaceValuePerCell is sized from that matrix; confirm.
  FaceValuePerCell<bool> cell_face_is_reversed(*this);
  {
    for (const auto& face_cells : face_cells_map) {
      for (const auto& [cell_number, cell_local_face, reversed] : face_cells.second) {
        cell_face_is_reversed(cell_number, cell_local_face) = reversed;
      }
    }

    m_cell_face_is_reversed = cell_face_is_reversed;
  }

  // Face -> node, using the same face numbering as above.
  {
    std::vector<std::vector<unsigned int>> face_to_node_vector(face_cells_map.size());
    int l=0;
    for (const auto& face_info : face_cells_map) {
      const Face& face = face_info.first;
      face_to_node_vector[l] = face.nodeIdList();
      ++l;
    }
    m_face_to_node_matrix = face_to_node_vector;
  }

  // Face -> face number lookup.
  {
    int l=0;
    for (const auto& face_cells_vector : face_cells_map) {
      const Face& face = face_cells_vector.first;
      m_face_number_map[face] = l;
      ++l;
    }
  }

  m_connectivity_computer.computeInverseConnectivityMatrix(m_cell_to_face_matrix,
                                                           m_face_to_cell_matrix);

  m_face_to_cell_local_face = CellValuePerFace<unsigned short>(*this);

  m_connectivity_computer.computeLocalChildItemNumberInItem(m_cell_to_face_matrix,
                                                            m_face_to_cell_matrix,
                                                            m_face_to_cell_local_face);

#warning check that the number of cell per faces is <=2
  // Node -> faces, gathered from the face -> node matrix.
  std::unordered_map<unsigned int, std::vector<unsigned int>> node_faces_map;
  for (size_t l=0; l<m_face_to_node_matrix.numRows(); ++l) {
    const auto& face_nodes = m_face_to_node_matrix.rowConst(l);
    for (size_t lr=0; lr<face_nodes.length; ++lr) {
      const unsigned int r = face_nodes(lr);
      node_faces_map[r].emplace_back(l);
    }
  }

  Kokkos::View<unsigned short*> node_nb_faces("node_nb_faces", this->numberOfNodes());
  size_t max_nb_face_per_node = 0;
  // Bind by const reference: iterating by value copied each face vector.
  for (const auto& [node_id, faces_of_node] : node_faces_map) {
    max_nb_face_per_node = std::max(faces_of_node.size(), max_nb_face_per_node);
    node_nb_faces[node_id] = faces_of_node.size();
  }
  m_node_nb_faces = node_nb_faces;

  // Dense (node, k) table, padded to the largest number of faces per node.
  Kokkos::View<unsigned int**> node_faces("node_faces", this->numberOfNodes(), max_nb_face_per_node);
  for (const auto& [node_id, faces_of_node] : node_faces_map) {
    for (size_t l=0; l < faces_of_node.size(); ++l) {
      node_faces(node_id, l) = faces_of_node[l];
    }
  }
  m_node_faces = node_faces;
}

/**
 * Builds the face-based connectivities of a 2D mesh from its cell-to-node
 * connectivity.
 *
 * In 2D a face is the edge joining two consecutive nodes of a cell, with the
 * last node joined back to the first. Each edge is keyed with its two node ids
 * in increasing order, so both cells sharing an edge build the same Face.
 * Faces are numbered in the iteration order of the ordered std::map.
 *
 * Fills:
 *  - m_face_number_map     : Face -> face number
 *  - m_face_to_node_matrix : the two node ids of each face
 *  - m_face_to_cell_matrix : cell numbers adjacent to each face
 */
template<>
void Connectivity<2>::_computeFaceCellConnectivities()
{
  const auto& cell_to_node_matrix
      = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];

  // In 2D, faces are simply defined as the edges between consecutive cell nodes.
  // (cell number, local face index in that cell)
  typedef std::pair<unsigned int, unsigned short> CellFaceId;
  std::map<Face, std::vector<CellFaceId>> face_cells_map;
  for (unsigned int j=0; j<this->numberOfCells(); ++j) {
    const auto& cell_nodes = cell_to_node_matrix.rowConst(j);
    for (unsigned short r=0; r<cell_nodes.length; ++r) {
      unsigned int node0_id = cell_nodes(r);
      unsigned int node1_id = cell_nodes((r+1)%cell_nodes.length);
      // Order-independent key: both orientations of an edge map to one Face.
      if (node1_id<node0_id) {
        std::swap(node0_id, node1_id);
      }
      face_cells_map[Face({node0_id, node1_id})].push_back(std::make_pair(j, r));
    }
  }

  // Face -> face number lookup.
  {
    int l=0;
    for (const auto& face_cells_vector : face_cells_map) {
      const Face& face = face_cells_vector.first;
      m_face_number_map[face] = l;
      ++l;
    }
  }

  // Face -> node, using the same face numbering.
  {
    std::vector<std::vector<unsigned int>> face_to_node_vector(face_cells_map.size());
    int l=0;
    for (const auto& face_info : face_cells_map) {
      const Face& face = face_info.first;
      face_to_node_vector[l] = {face.m_node0_id, face.m_node1_id};
      ++l;
    }
    m_face_to_node_matrix = face_to_node_vector;
  }

  // Face -> cell, using the same face numbering.
  {
    std::vector<std::vector<unsigned int>> face_to_cell_vector(face_cells_map.size());
    int l=0;
    for (const auto& face_cells_vector : face_cells_map) {
      const auto& cell_info_vector = face_cells_vector.second;
      for (const auto& cell_info : cell_info_vector) {
        // Store the cell number (.first). The local face index (.second) is
        // not a cell id and previously ended up in this matrix by mistake.
        face_to_cell_vector[l].push_back(cell_info.first);
      }
      ++l;
    }
    m_face_to_cell_matrix = face_to_cell_vector;
  }
}