#include <Connectivity.hpp>

#include <cstdlib>
#include <iostream>
#include <map>
#include <tuple>
#include <utility>
#include <vector>

/// Builds the face-related connectivity of a 3D mesh from its cell->node
/// connectivity.
///
/// Every cell is decomposed into its faces (4 triangles for a tetrahedron,
/// 6 quadrangles for a hexahedron).  Faces are stored as keys of an ordered
/// map, so a face built by several cells is registered once; its global
/// number is its rank in the map ordering.
///
/// The function fills:
///  - the cell->face item matrix,
///  - m_cell_face_is_reversed (value of Face::reversed() for each local
///    face of each cell),
///  - the face->node item matrix,
///  - m_face_number_map (Face -> global face number).
///
/// Any other cell type is a fatal error: a message is written to std::cerr
/// and the program exits with a failure status.
template<>
void Connectivity<3>::_computeCellFaceAndFaceNodeConnectivities()
{
  // (cell number, local number of the face in the cell, Face::reversed())
  using CellFaceInfo = std::tuple<unsigned int, unsigned short, bool>;

  const auto& cell_to_node_matrix
      = this->getMatrix(ItemType::cell, ItemType::node);

  CellValue<unsigned short> cell_nb_faces(*this);
  std::map<Face, std::vector<CellFaceInfo>> face_cells_map;

  // Registers `face` as local face number `local_face` of cell `cell_id`.
  // The face is taken by value so that the helper makes no assumption on
  // the const-qualification of Face::reversed().
  auto add_cell_face = [&face_cells_map](Face face,
                                         unsigned int cell_id,
                                         unsigned short local_face) {
    const bool is_reversed = face.reversed();
    face_cells_map[face].emplace_back(cell_id, local_face, is_reversed);
  };

  for (unsigned int j=0; j<this->numberOfCells(); ++j) {
    const auto& cell_nodes = cell_to_node_matrix.rowConst(j);

    // The node orderings below define the local face numbering of each
    // supported cell type; they must not be permuted.
    switch (m_cell_type[j]) {
      case CellType::Tetrahedron: {
        cell_nb_faces[j] = 4;
        add_cell_face(Face({cell_nodes(1), cell_nodes(2), cell_nodes(3)}), j, 0);
        add_cell_face(Face({cell_nodes(0), cell_nodes(3), cell_nodes(2)}), j, 1);
        add_cell_face(Face({cell_nodes(0), cell_nodes(1), cell_nodes(3)}), j, 2);
        add_cell_face(Face({cell_nodes(0), cell_nodes(2), cell_nodes(1)}), j, 3);
        break;
      }
      case CellType::Hexahedron: {
        cell_nb_faces[j] = 6;
        add_cell_face(Face({cell_nodes(3), cell_nodes(2), cell_nodes(1), cell_nodes(0)}), j, 0);
        add_cell_face(Face({cell_nodes(4), cell_nodes(5), cell_nodes(6), cell_nodes(7)}), j, 1);
        add_cell_face(Face({cell_nodes(0), cell_nodes(4), cell_nodes(7), cell_nodes(3)}), j, 2);
        add_cell_face(Face({cell_nodes(1), cell_nodes(2), cell_nodes(6), cell_nodes(5)}), j, 3);
        add_cell_face(Face({cell_nodes(0), cell_nodes(1), cell_nodes(5), cell_nodes(4)}), j, 4);
        add_cell_face(Face({cell_nodes(3), cell_nodes(7), cell_nodes(6), cell_nodes(2)}), j, 5);
        break;
      }
      default: {
        std::cerr << "unexpected cell type!\n";
        // This is a fatal error: report failure to the calling process.
        std::exit(EXIT_FAILURE);
      }
    }
  }

  // Cell -> face connectivity.  It is stored before cell_face_is_reversed
  // is constructed below, as in the original statement order.
  {
    auto& cell_to_face_matrix
        = m_item_to_item_matrix[itemTId(ItemType::cell)][itemTId(ItemType::face)];
    std::vector<std::vector<unsigned int>> cell_to_face_vector(this->numberOfCells());
    for (size_t j=0; j<cell_to_face_vector.size(); ++j) {
      cell_to_face_vector[j].resize(cell_nb_faces[j]);
    }

    unsigned int face_id = 0;
    for (const auto& face_cells : face_cells_map) {
      for (const auto& cell_face_info : face_cells.second) {
        const unsigned int cell_number = std::get<0>(cell_face_info);
        const unsigned short cell_local_face = std::get<1>(cell_face_info);
        cell_to_face_vector[cell_number][cell_local_face] = face_id;
      }
      ++face_id;
    }
    cell_to_face_matrix = cell_to_face_vector;
  }

  // Reversed flag of each face as seen from each of its cells.
  FaceValuePerCell<bool> cell_face_is_reversed(*this);
  {
    for (const auto& face_cells : face_cells_map) {
      for (const auto& [cell_number, cell_local_face, reversed] : face_cells.second) {
        cell_face_is_reversed(cell_number, cell_local_face) = reversed;
      }
    }

    m_cell_face_is_reversed = cell_face_is_reversed;
  }

  // Face -> node connectivity and Face -> face number lookup, using the
  // same numbering (rank in face_cells_map) as the cell -> face matrix.
  {
    auto& face_to_node_matrix
        = m_item_to_item_matrix[itemTId(ItemType::face)][itemTId(ItemType::node)];

    std::vector<std::vector<unsigned int>> face_to_node_vector(face_cells_map.size());
    unsigned int face_id = 0;
    for (const auto& face_cells : face_cells_map) {
      const Face& face = face_cells.first;
      face_to_node_vector[face_id] = face.nodeIdList();
      m_face_number_map[face] = face_id;
      ++face_id;
    }
    face_to_node_matrix = face_to_node_vector;
  }

#warning check that the number of cell per faces is <=2
}

/// Builds the face-related connectivity of a 2D mesh from its cell->node
/// connectivity.
///
/// Each pair of consecutive nodes of a cell (cyclically) defines one face.
/// Faces are stored as keys of an ordered map, so a face built by several
/// cells is registered once; its global number is its rank in the map
/// ordering.
///
/// The function fills:
///  - m_face_number_map (Face -> global face number),
///  - the face->node item matrix,
///  - the face->cell item matrix.
template<>
void Connectivity<2>::_computeCellFaceAndFaceNodeConnectivities()
{
  const auto& cell_to_node_matrix
      = this->getMatrix(ItemType::cell, ItemType::node);

  // In 2D a face is simply defined by its two end nodes, stored with the
  // smallest node id first so that both neighbouring cells build the same key.
  // CellFaceId is (cell number, local number of the face in the cell).
  using CellFaceId = std::pair<unsigned int, unsigned short>;
  std::map<Face, std::vector<CellFaceId>> face_cells_map;
  for (unsigned int j=0; j<this->numberOfCells(); ++j) {
    const auto& cell_nodes = cell_to_node_matrix.rowConst(j);
    for (unsigned short r=0; r<cell_nodes.length; ++r) {
      unsigned int node0_id = cell_nodes(r);
      unsigned int node1_id = cell_nodes((r+1)%cell_nodes.length);
      if (node1_id<node0_id) {
        std::swap(node0_id, node1_id);
      }
      face_cells_map[Face({node0_id, node1_id})].emplace_back(j, r);
    }
  }

  // Single pass over the faces: the rank of a face in the map is its number.
  std::vector<std::vector<unsigned int>> face_to_node_vector(face_cells_map.size());
  std::vector<std::vector<unsigned int>> face_to_cell_vector(face_cells_map.size());
  {
    unsigned int face_id = 0;
    for (const auto& [face, cell_info_vector] : face_cells_map) {
      m_face_number_map[face] = face_id;

      face_to_node_vector[face_id] = {face.m_node0_id, face.m_node1_id};

      face_to_cell_vector[face_id].reserve(cell_info_vector.size());
      for (const auto& cell_info : cell_info_vector) {
        // first is the cell number; second is the local face index inside
        // that cell and must not be stored in the face -> cell matrix.
        face_to_cell_vector[face_id].push_back(cell_info.first);
      }
      ++face_id;
    }
  }

  {
    auto& face_to_node_matrix
        = m_item_to_item_matrix[itemTId(ItemType::face)][itemTId(ItemType::node)];
    face_to_node_matrix = face_to_node_vector;
  }

  {
    auto& face_to_cell_matrix
        = m_item_to_item_matrix[itemTId(ItemType::face)][itemTId(ItemType::cell)];
    face_to_cell_matrix = face_to_cell_vector;
  }
}


/// Constructs the connectivity from a per-cell description of the mesh.
///
/// @param cell_by_node_vector  for each cell, the list of its node numbers
/// @param cell_type_vector     for each cell, its type; must have the same
///                             size as @p cell_by_node_vector
///
/// Stores the cell->node item matrix, the cell types and the inverse of
/// the number of nodes of each cell.  For Dimension>1 the face
/// connectivities are then computed through
/// _computeCellFaceAndFaceNodeConnectivities().
template<size_t Dimension>
Connectivity<Dimension>::
Connectivity(const std::vector<std::vector<unsigned int>>& cell_by_node_vector,
             const std::vector<CellType>& cell_type_vector)
{
  Assert(cell_by_node_vector.size() == cell_type_vector.size());

  // The cell->node matrix comes first: the cell count used below is read
  // after this assignment.
  auto& node_matrix_of_cells
      = m_item_to_item_matrix[itemTId(ItemType::cell)][itemTId(ItemType::node)];
  node_matrix_of_cells = cell_by_node_vector;

  Assert(this->numberOfCells()>0);

  // Type of each cell.
  {
    CellValue<CellType> type_of_cell(*this);
    Kokkos::parallel_for(this->numberOfCells(), KOKKOS_LAMBDA(const int& cell_id){
        type_of_cell[cell_id] = cell_type_vector[cell_id];
      });
    m_cell_type = type_of_cell;
  }

  // Inverse of the number of nodes of each cell.
  {
    CellValue<double> inverse_nb_nodes(*this);
    Kokkos::parallel_for(this->numberOfCells(), KOKKOS_LAMBDA(const int& cell_id){
        const auto& nodes_of_cell = node_matrix_of_cells.rowConst(cell_id);
        inverse_nb_nodes[cell_id] = 1./nodes_of_cell.length;
      });
    m_inv_cell_nb_nodes = inverse_nb_nodes;
  }

  // The face connectivities are only built here for Dimension>1.
  if constexpr (Dimension>1) {
    this->_computeCellFaceAndFaceNodeConnectivities();
  }
}


// Explicit instantiations of the constructor for the 1D, 2D and 3D
// connectivity types.  The constructor template is defined in this source
// file, so these instantiations are presumably what makes it available to
// other translation units -- TODO confirm against the header.
// NOTE(review): Connectivity1D/2D/3D are assumed to be aliases of
// Connectivity<1>/<2>/<3> declared in Connectivity.hpp.
template Connectivity1D::
Connectivity(const std::vector<std::vector<unsigned int>>& cell_by_node_vector,
             const std::vector<CellType>& cell_type_vector);

template Connectivity2D::
Connectivity(const std::vector<std::vector<unsigned int>>& cell_by_node_vector,
             const std::vector<CellType>& cell_type_vector);

template Connectivity3D::
Connectivity(const std::vector<std::vector<unsigned int>>& cell_by_node_vector,
             const std::vector<CellType>& cell_type_vector);