#include <mesh/ConnectivityBuilderBase.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ConnectivityDescriptor.hpp>
#include <mesh/ItemId.hpp>
#include <mesh/Mesh.hpp>
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>

#include <algorithm>
#include <iterator>
#include <map>
#include <sstream>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

// Builds the cell->face and face->node connectivities from the cell->node
// connectivity stored in `descriptor`.
//
// Each cell is split into faces according to its type (triangle and
// quadrangle in dimension 2; hexahedron, tetrahedron, pyramid and diamond in
// dimension 3). Faces are used as keys of an ordered map, so cells producing
// the same key share a single face. The id of a face is its rank in that map.
//
// Members of `descriptor` that are read:
//   node_number_vector, cell_to_node_vector, cell_type_vector
// Members of `descriptor` that are (re)built:
//   cell_to_face_vector          : face id for each (cell, local face)
//   cell_face_is_reversed_vector : value of Face::reversed() for the face as
//                                  built from that cell (presumably whether
//                                  the cell sees the face with the opposite
//                                  orientation -- confirm in ConnectivityFace)
//   face_to_node_vector          : Face::nodeIdList() of each face
//   face_number_vector           : set to 0, 1, ..., number of faces - 1
//
// Throws UnexpectedError if a cell type is not handled for this dimension.
// In that case `descriptor` has not been modified yet.
template <size_t Dimension>
void
ConnectivityBuilderBase::_computeCellFaceAndFaceNodeConnectivities(ConnectivityDescriptor& descriptor)
{
  static_assert((Dimension == 2) or (Dimension == 3), "Invalid dimension to compute cell-face connectivities");
  using CellFaceInfo = std::tuple<CellId, unsigned short, bool>;
  using Face         = ConnectivityFace<Dimension>;

  const auto& node_number_vector = descriptor.node_number_vector;
  const size_t number_of_cells   = descriptor.cell_to_node_vector.size();

  Array<unsigned short> cell_nb_faces(number_of_cells);
  std::map<Face, std::vector<CellFaceInfo>> face_cells_map;

  // Registers `face` as the local face number `local_face_id` of cell `cell_id`
  auto add_cell_face = [&face_cells_map](Face face, CellId cell_id, unsigned short local_face_id) {
    const bool is_reversed = face.reversed();
    face_cells_map[face].emplace_back(cell_id, local_face_id, is_reversed);
  };

  for (CellId j = 0; j < number_of_cells; ++j) {
    const auto& cell_nodes = descriptor.cell_to_node_vector[j];

    if constexpr (Dimension == 2) {
      switch (descriptor.cell_type_vector[j]) {
      case CellType::Triangle: {
        cell_nb_faces[j] = 3;
        // local face i is the one that does not contain local node i
        add_cell_face(Face({cell_nodes[1], cell_nodes[2]}, node_number_vector), j, 0);
        add_cell_face(Face({cell_nodes[2], cell_nodes[0]}, node_number_vector), j, 1);
        add_cell_face(Face({cell_nodes[0], cell_nodes[1]}, node_number_vector), j, 2);
        break;
      }
      case CellType::Quadrangle: {
        cell_nb_faces[j] = 4;
        // local face i joins local node i to local node (i+1)%4
        add_cell_face(Face({cell_nodes[0], cell_nodes[1]}, node_number_vector), j, 0);
        add_cell_face(Face({cell_nodes[1], cell_nodes[2]}, node_number_vector), j, 1);
        add_cell_face(Face({cell_nodes[2], cell_nodes[3]}, node_number_vector), j, 2);
        add_cell_face(Face({cell_nodes[3], cell_nodes[0]}, node_number_vector), j, 3);
        break;
      }
      default: {
        std::ostringstream error_msg;
        error_msg << name(descriptor.cell_type_vector[j]) << ": unexpected cell type in dimension 2";
        throw UnexpectedError(error_msg.str());
      }
      }
    } else if constexpr (Dimension == 3) {
      switch (descriptor.cell_type_vector[j]) {
      case CellType::Hexahedron: {
        cell_nb_faces[j] = 6;
        add_cell_face(Face({cell_nodes[3], cell_nodes[2], cell_nodes[1], cell_nodes[0]}, node_number_vector), j, 0);
        add_cell_face(Face({cell_nodes[4], cell_nodes[5], cell_nodes[6], cell_nodes[7]}, node_number_vector), j, 1);
        add_cell_face(Face({cell_nodes[0], cell_nodes[4], cell_nodes[7], cell_nodes[3]}, node_number_vector), j, 2);
        add_cell_face(Face({cell_nodes[1], cell_nodes[2], cell_nodes[6], cell_nodes[5]}, node_number_vector), j, 3);
        add_cell_face(Face({cell_nodes[0], cell_nodes[1], cell_nodes[5], cell_nodes[4]}, node_number_vector), j, 4);
        add_cell_face(Face({cell_nodes[3], cell_nodes[7], cell_nodes[6], cell_nodes[2]}, node_number_vector), j, 5);
        break;
      }
      case CellType::Tetrahedron: {
        cell_nb_faces[j] = 4;
        // local face i is the one that does not contain local node i
        add_cell_face(Face({cell_nodes[1], cell_nodes[2], cell_nodes[3]}, node_number_vector), j, 0);
        add_cell_face(Face({cell_nodes[0], cell_nodes[3], cell_nodes[2]}, node_number_vector), j, 1);
        add_cell_face(Face({cell_nodes[0], cell_nodes[1], cell_nodes[3]}, node_number_vector), j, 2);
        add_cell_face(Face({cell_nodes[0], cell_nodes[2], cell_nodes[1]}, node_number_vector), j, 3);
        break;
      }
      case CellType::Pyramid: {
        // node layout: base nodes first, apex last
        // => one base face plus one side face per base node
        cell_nb_faces[j] = cell_nodes.size();
        std::vector<unsigned int> base_nodes;
        std::copy_n(cell_nodes.begin(), cell_nodes.size() - 1, std::back_inserter(base_nodes));
        const size_t nb_base_nodes = base_nodes.size();

        // base face
        add_cell_face(Face(base_nodes, node_number_vector), j, 0);

        // side faces
        const auto pyramid_vertex = cell_nodes[cell_nodes.size() - 1];
        for (size_t i_node = 0; i_node < nb_base_nodes; ++i_node) {
          add_cell_face(Face({base_nodes[(i_node + 1) % nb_base_nodes], base_nodes[i_node], pyramid_vertex},
                             node_number_vector),
                        j, static_cast<unsigned short>(i_node + 1));
        }
        break;
      }
      case CellType::Diamond: {
        // node layout: bottom vertex first, then base nodes, top vertex last
        // => one top face and one bottom face per base node
        cell_nb_faces[j] = 2 * (cell_nodes.size() - 2);
        std::vector<unsigned int> base_nodes;
        std::copy_n(cell_nodes.begin() + 1, cell_nodes.size() - 2, std::back_inserter(base_nodes));
        const size_t nb_base_nodes = base_nodes.size();

        // top faces: local ids [0, nb_base_nodes)
        const auto top_vertex = cell_nodes[cell_nodes.size() - 1];
        for (size_t i_node = 0; i_node < nb_base_nodes; ++i_node) {
          add_cell_face(Face({base_nodes[i_node], base_nodes[(i_node + 1) % nb_base_nodes], top_vertex},
                             node_number_vector),
                        j, static_cast<unsigned short>(i_node));
        }

        // bottom faces: local ids [nb_base_nodes, 2*nb_base_nodes)
        const auto bottom_vertex = cell_nodes[0];
        for (size_t i_node = 0; i_node < nb_base_nodes; ++i_node) {
          add_cell_face(Face({base_nodes[(i_node + 1) % nb_base_nodes], base_nodes[i_node], bottom_vertex},
                             node_number_vector),
                        j, static_cast<unsigned short>(i_node + nb_base_nodes));
        }
        break;
      }
      default: {
        std::ostringstream error_msg;
        error_msg << name(descriptor.cell_type_vector[j]) << ": unexpected cell type in dimension 3";
        throw UnexpectedError(error_msg.str());
      }
      }
    }
  }

  // Size the per-cell containers before filling them
  descriptor.cell_to_face_vector.resize(number_of_cells);
  descriptor.cell_face_is_reversed_vector.resize(number_of_cells);
  for (CellId j = 0; j < number_of_cells; ++j) {
    descriptor.cell_to_face_vector[j].resize(cell_nb_faces[j]);
    descriptor.cell_face_is_reversed_vector[j] = Array<bool>(cell_nb_faces[j]);
  }

  // Single pass over the faces: the face id is the rank of the face in the map
  descriptor.face_to_node_vector.resize(face_cells_map.size());
  FaceId l = 0;
  for (const auto& [face, cell_info_list] : face_cells_map) {
    for (const auto& [cell_id, local_face_id, is_reversed] : cell_info_list) {
      descriptor.cell_to_face_vector[cell_id][local_face_id]          = l;
      descriptor.cell_face_is_reversed_vector[cell_id][local_face_id] = is_reversed;
    }
    descriptor.face_to_node_vector[l] = face.nodeIdList();
    ++l;
  }

  // Face numbers may change if numbers are provided in the file
  descriptor.face_number_vector.resize(face_cells_map.size());
  for (size_t i_face = 0; i_face < face_cells_map.size(); ++i_face) {
    descriptor.face_number_vector[i_face] = i_face;
  }
}

// Builds the face->edge, edge->node and cell->edge connectivities of a 3D mesh
// from the face->node and cell->node connectivities stored in `descriptor`.
//
// The edges of a face join each of its nodes to the next one, cyclically.
// Edges are used as keys of an ordered map, so faces producing the same key
// share a single edge. The id of an edge is its rank in that map.
//
// Members of `descriptor` that are read:
//   node_number_vector, face_to_node_vector, cell_to_node_vector,
//   cell_type_vector
// Members of `descriptor` that are (re)built:
//   face_to_edge_vector          : edge id for each (face, local edge)
//   face_edge_is_reversed_vector : value of Edge::reversed() for the edge as
//                                  built from that face
//   edge_to_node_vector          : Edge::nodeIdList() of each edge
//   edge_number_vector           : set to 0, 1, ..., number of edges - 1
// Member of `descriptor` that is appended to (one entry per cell, it is not
// cleared first):
//   cell_to_edge_vector          : edge ids of each cell
//
// Throws NormalError if an edge of a cell is not an edge of any face, and
// NotImplementedError if a cell type is not handled.
template <size_t Dimension>
void
ConnectivityBuilderBase::_computeFaceEdgeAndEdgeNodeAndCellEdgeConnectivities(ConnectivityDescriptor& descriptor)
{
  static_assert(Dimension == 3, "Invalid dimension to compute face-edge connectivities");
  using FaceEdgeInfo = std::tuple<FaceId, unsigned short, bool>;
  using Edge         = ConnectivityFace<2>;

  const auto& node_number_vector = descriptor.node_number_vector;
  const size_t number_of_faces   = descriptor.face_to_node_vector.size();

  Array<unsigned short> face_nb_edges(number_of_faces);
  std::map<Edge, std::vector<FaceEdgeInfo>> edge_faces_map;
  for (FaceId l = 0; l < number_of_faces; ++l) {
    const auto& face_nodes     = descriptor.face_to_node_vector[l];
    const size_t nb_face_nodes = face_nodes.size();

    face_nb_edges[l] = nb_face_nodes;
    // local edge r joins node r to node (r+1) modulo the number of nodes; the
    // cyclic form also makes a face without nodes contribute no edge instead
    // of underflowing `size() - 1`
    for (size_t r = 0; r < nb_face_nodes; ++r) {
      Edge e({face_nodes[r], face_nodes[(r + 1) % nb_face_nodes]}, node_number_vector);
      const bool is_reversed = e.reversed();
      edge_faces_map[e].emplace_back(l, static_cast<unsigned short>(r), is_reversed);
    }
  }

  // Size the per-face containers before filling them
  descriptor.face_to_edge_vector.resize(number_of_faces);
  descriptor.face_edge_is_reversed_vector.resize(number_of_faces);
  for (FaceId l = 0; l < number_of_faces; ++l) {
    descriptor.face_to_edge_vector[l].resize(face_nb_edges[l]);
    descriptor.face_edge_is_reversed_vector[l] = Array<bool>(face_nb_edges[l]);
  }

  // Single pass over the edges: the edge id is the rank of the edge in the map
  std::unordered_map<Edge, EdgeId, typename Edge::Hash> edge_id_map;
  edge_id_map.reserve(edge_faces_map.size());
  descriptor.edge_to_node_vector.resize(edge_faces_map.size());
  {
    EdgeId edge_id = 0;
    for (const auto& [edge, face_info_list] : edge_faces_map) {
      for (const auto& [face_id, local_edge_id, is_reversed] : face_info_list) {
        descriptor.face_to_edge_vector[face_id][local_edge_id]          = edge_id;
        descriptor.face_edge_is_reversed_vector[face_id][local_edge_id] = is_reversed;
      }
      descriptor.edge_to_node_vector[edge_id] = edge.nodeIdList();
      edge_id_map[edge]                       = edge_id;
      ++edge_id;
    }
  }

  // Edge numbers may change if numbers are provided in the file
  descriptor.edge_number_vector.resize(edge_faces_map.size());
  for (size_t i_edge = 0; i_edge < edge_faces_map.size(); ++i_edge) {
    descriptor.edge_number_vector[i_edge] = i_edge;
  }

  // Returns the id of the edge joining the two given nodes
  const auto find_edge_id = [&edge_id_map, &node_number_vector](unsigned int node_0, unsigned int node_1) {
    Edge edge{{node_0, node_1}, node_number_vector};
    const auto i_edge = edge_id_map.find(edge);
    if (i_edge == edge_id_map.end()) {
      throw NormalError("could not find this edge");
    }
    return i_edge->second;
  };

  // Appends the edges joining consecutive nodes of a closed ring of nodes
  const auto append_ring_edges = [&find_edge_id](const std::vector<unsigned int>& ring_nodes,
                                                 std::vector<unsigned int>& edge_list) {
    const size_t nb_ring_nodes = ring_nodes.size();
    for (size_t i_node = 0; i_node < nb_ring_nodes; ++i_node) {
      edge_list.push_back(find_edge_id(ring_nodes[i_node], ring_nodes[(i_node + 1) % nb_ring_nodes]));
    }
  };

  // Appends the edges joining each node of a ring to a given apex node
  const auto append_apex_edges = [&find_edge_id](const std::vector<unsigned int>& ring_nodes, unsigned int apex,
                                                 std::vector<unsigned int>& edge_list) {
    for (const unsigned int ring_node : ring_nodes) {
      edge_list.push_back(find_edge_id(ring_node, apex));
    }
  };

  descriptor.cell_to_edge_vector.reserve(descriptor.cell_to_node_vector.size());
  for (CellId j = 0; j < descriptor.cell_to_node_vector.size(); ++j) {
    const auto& cell_nodes = descriptor.cell_to_node_vector[j];

    // filled by the case matching the cell type, then moved into the descriptor
    std::vector<unsigned int> cell_edge_vector;

    switch (descriptor.cell_type_vector[j]) {
    case CellType::Tetrahedron: {
      constexpr int local_edge[6][2] = {{0, 1}, {0, 2}, {0, 3}, {1, 2}, {2, 3}, {3, 1}};
      cell_edge_vector.reserve(6);
      for (const auto& edge_nodes : local_edge) {
        cell_edge_vector.push_back(find_edge_id(cell_nodes[edge_nodes[0]], cell_nodes[edge_nodes[1]]));
      }
      break;
    }
    case CellType::Hexahedron: {
      constexpr int local_edge[12][2] = {{0, 1}, {1, 2}, {2, 3}, {3, 0}, {4, 5}, {5, 6},
                                         {6, 7}, {7, 4}, {0, 4}, {1, 5}, {2, 6}, {3, 7}};
      cell_edge_vector.reserve(12);
      for (const auto& edge_nodes : local_edge) {
        cell_edge_vector.push_back(find_edge_id(cell_nodes[edge_nodes[0]], cell_nodes[edge_nodes[1]]));
      }
      break;
    }
    case CellType::Pyramid: {
      // node layout: base nodes first, apex last
      std::vector<unsigned int> base_nodes;
      std::copy_n(cell_nodes.begin(), cell_nodes.size() - 1, std::back_inserter(base_nodes));

      // base ring edges followed by base-to-apex edges
      cell_edge_vector.reserve(2 * base_nodes.size());
      append_ring_edges(base_nodes, cell_edge_vector);

      const unsigned int top_vertex = cell_nodes[cell_nodes.size() - 1];
      append_apex_edges(base_nodes, top_vertex, cell_edge_vector);
      break;
    }
    case CellType::Diamond: {
      // node layout: bottom vertex first, then base nodes, top vertex last
      std::vector<unsigned int> base_nodes;
      std::copy_n(cell_nodes.begin() + 1, cell_nodes.size() - 2, std::back_inserter(base_nodes));

      // base ring edges, then base-to-top edges, then base-to-bottom edges
      cell_edge_vector.reserve(3 * base_nodes.size());
      append_ring_edges(base_nodes, cell_edge_vector);

      const unsigned int top_vertex = cell_nodes[cell_nodes.size() - 1];
      append_apex_edges(base_nodes, top_vertex, cell_edge_vector);

      const unsigned int bottom_vertex = cell_nodes[0];
      append_apex_edges(base_nodes, bottom_vertex, cell_edge_vector);
      break;
    }
    default: {
      std::ostringstream error_msg;
      error_msg << name(descriptor.cell_type_vector[j]) << ": unexpected cell type in dimension 3";
      throw NotImplementedError(error_msg.str());
    }
    }

    descriptor.cell_to_edge_vector.emplace_back(std::move(cell_edge_vector));
  }
}

// Explicit instantiations of the member templates defined above.
// Cell-face connectivities are instantiated for dimensions 2 and 3 (the only
// values accepted by the static_assert of that function).
template void ConnectivityBuilderBase::_computeCellFaceAndFaceNodeConnectivities<2>(ConnectivityDescriptor& descriptor);
template void ConnectivityBuilderBase::_computeCellFaceAndFaceNodeConnectivities<3>(ConnectivityDescriptor& descriptor);

// Face-edge connectivities are instantiated for dimension 3 only (the only
// value accepted by the static_assert of that function).
template void ConnectivityBuilderBase::_computeFaceEdgeAndEdgeNodeAndCellEdgeConnectivities<3>(
  ConnectivityDescriptor& descriptor);