#include <mesh/MeshBuilderBase.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ConnectivityDescriptor.hpp>
#include <mesh/ConnectivityDispatcher.hpp>
#include <mesh/ItemId.hpp>
#include <mesh/Mesh.hpp>
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>

#include <vector>

template <int Dimension>
void
MeshBuilderBase::_dispatch()
{
  if (parallel::size() == 1) {
    return;
  }

  using ConnectivityType = Connectivity<Dimension>;
  using Rd               = TinyVector<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;

  if (not m_mesh) {
    ConnectivityDescriptor descriptor;
    std::shared_ptr connectivity = ConnectivityType::build(descriptor);
    NodeValue<Rd> xr;
    m_mesh = std::make_shared<MeshType>(connectivity, xr);
  }
  const MeshType& mesh = static_cast<const MeshType&>(*m_mesh);

  ConnectivityDispatcher<Dimension> dispatcher(mesh.connectivity());

  std::shared_ptr dispatched_connectivity = dispatcher.dispatchedConnectivity();
  NodeValue<Rd> dispatched_xr             = dispatcher.dispatch(mesh.xr());

  m_mesh = std::make_shared<MeshType>(dispatched_connectivity, dispatched_xr);
}

// Derives the cell->face, face->node and cell-face orientation connectivities
// of a 2D or 3D mesh from descriptor.cell_to_node_vector and
// descriptor.cell_type_vector.
//
// Every local face of every cell is given by an ordered list of cell-local
// node indices. A ConnectivityFace key is built from those nodes (together
// with descriptor.node_number_vector), and all (cell, local face, reversed)
// occurrences are gathered per key in an ordered std::map. Faces are then
// numbered 0, 1, ... following the map's key order. Presumably the key is
// orientation-independent, so that a face shared by two cells yields one
// entry; reversed() records the orientation. NOTE(review): confirm in
// ConnectivityFace.
//
// Writes: descriptor.cell_to_face_vector, cell_face_is_reversed_vector,
//         face_to_node_vector and face_number_vector (face i gets number i).
// Throws: UnexpectedError for a cell type that is not handled in this
//         dimension (2D: triangle, quadrangle; 3D: tetrahedron, hexahedron).
template <size_t Dimension>
void
MeshBuilderBase::_computeCellFaceAndFaceNodeConnectivities(ConnectivityDescriptor& descriptor)
{
  static_assert((Dimension == 2) or (Dimension == 3), "Invalid dimension to compute cell-face connectivities");
  using CellFaceInfo = std::tuple<CellId, unsigned short, bool>;
  using Face         = ConnectivityFace<Dimension>;

  const auto& node_number_vector = descriptor.node_number_vector;
  const size_t number_of_cells   = descriptor.cell_to_node_vector.size();

  Array<unsigned short> cell_nb_faces(number_of_cells);
  std::map<Face, std::vector<CellFaceInfo>> face_cells_map;

  for (CellId cell_id = 0; cell_id < number_of_cells; ++cell_id) {
    const auto& cell_nodes = descriptor.cell_to_node_vector[cell_id];
    const auto cell_type   = descriptor.cell_type_vector[cell_id];

    // add_face(local_face, n0, n1, ...) records that local face `local_face`
    // of this cell is made of the cell-local nodes n0, n1, ... in that order.
    auto add_face = [&](unsigned short local_face, auto... local_nodes) {
      Face face({cell_nodes[local_nodes]...}, node_number_vector);
      face_cells_map[face].emplace_back(std::make_tuple(cell_id, local_face, face.reversed()));
    };

    if constexpr (Dimension == 2) {
      switch (cell_type) {
      case CellType::Triangle: {
        cell_nb_faces[cell_id] = 3;
        add_face(0, 1, 2);
        add_face(1, 2, 0);
        add_face(2, 0, 1);
        break;
      }
      case CellType::Quadrangle: {
        cell_nb_faces[cell_id] = 4;
        add_face(0, 0, 1);
        add_face(1, 1, 2);
        add_face(2, 2, 3);
        add_face(3, 3, 0);
        break;
      }
      default: {
        std::ostringstream error_msg;
        error_msg << name(cell_type) << ": unexpected cell type in dimension 2";
        throw UnexpectedError(error_msg.str());
      }
      }
    } else if constexpr (Dimension == 3) {
      switch (cell_type) {
      case CellType::Tetrahedron: {
        cell_nb_faces[cell_id] = 4;
        add_face(0, 1, 2, 3);
        add_face(1, 0, 3, 2);
        add_face(2, 0, 1, 3);
        add_face(3, 0, 2, 1);
        break;
      }
      case CellType::Hexahedron: {
        cell_nb_faces[cell_id] = 6;
        add_face(0, 3, 2, 1, 0);
        add_face(1, 4, 5, 6, 7);
        add_face(2, 0, 4, 7, 3);
        add_face(3, 1, 2, 6, 5);
        add_face(4, 0, 1, 5, 4);
        add_face(5, 3, 7, 6, 2);
        break;
      }
      default: {
        std::ostringstream error_msg;
        error_msg << name(cell_type) << ": unexpected cell type in dimension 3";
        throw UnexpectedError(error_msg.str());
      }
      }
    }
  }

  // Size the per-cell outputs from the face counts gathered above.
  descriptor.cell_to_face_vector.resize(number_of_cells);
  descriptor.cell_face_is_reversed_vector.resize(number_of_cells);
  for (CellId cell_id = 0; cell_id < number_of_cells; ++cell_id) {
    descriptor.cell_to_face_vector[cell_id].resize(cell_nb_faces[cell_id]);
    descriptor.cell_face_is_reversed_vector[cell_id] = Array<bool>(cell_nb_faces[cell_id]);
  }

  // One pass over the faces in map order: a face's rank in the map is its FaceId.
  descriptor.face_to_node_vector.resize(face_cells_map.size());
  FaceId face_id = 0;
  for (const auto& [face, cell_infos] : face_cells_map) {
    for (const auto& [cell_number, cell_local_face, reversed] : cell_infos) {
      descriptor.cell_to_face_vector[cell_number][cell_local_face]          = face_id;
      descriptor.cell_face_is_reversed_vector[cell_number][cell_local_face] = reversed;
    }
    descriptor.face_to_node_vector[face_id] = face.nodeIdList();
    ++face_id;
  }

  // Face numbers may change if numbers are provided in the file
  descriptor.face_number_vector.resize(face_cells_map.size());
  for (size_t face_index = 0; face_index < face_cells_map.size(); ++face_index) {
    descriptor.face_number_vector[face_index] = face_index;
  }
}

// Derives the face->edge, face-edge orientation, edge->node and cell->edge
// connectivities of a 3D mesh.
//
// Reads descriptor.face_to_node_vector (which
// _computeCellFaceAndFaceNodeConnectivities<3> fills), cell_to_node_vector,
// cell_type_vector and node_number_vector.
//
// Local edge r of a face joins its nodes r and r+1, and the last edge closes
// the loop back to node 0. Edge keys are gathered in an ordered std::map, and
// edges are numbered 0, 1, ... following the map's key order.
//
// Writes: descriptor.face_to_edge_vector, face_edge_is_reversed_vector,
//         edge_to_node_vector, edge_number_vector (edge i gets number i) and
//         cell_to_edge_vector (6 edges per tetrahedron, 12 per hexahedron,
//         one entry per cell in CellId order).
// Throws: NormalError if an edge of a cell does not appear on any face;
//         UnexpectedError for a cell type other than tetrahedron/hexahedron.
template <size_t Dimension>
void
MeshBuilderBase::_computeFaceEdgeAndEdgeNodeAndCellEdgeConnectivities(ConnectivityDescriptor& descriptor)
{
  static_assert(Dimension == 3, "Invalid dimension to compute face-edge connectivities");
  using FaceEdgeInfo = std::tuple<FaceId, unsigned short, bool>;
  using Edge         = ConnectivityFace<2>;

  const auto& node_number_vector = descriptor.node_number_vector;
  Array<unsigned short> face_nb_edges(descriptor.face_to_node_vector.size());
  std::map<Edge, std::vector<FaceEdgeInfo>> edge_faces_map;
  for (FaceId l = 0; l < descriptor.face_to_node_vector.size(); ++l) {
    const auto& face_nodes     = descriptor.face_to_node_vector[l];
    const size_t nb_face_nodes = face_nodes.size();

    face_nb_edges[l] = nb_face_nodes;
    // Edge r joins nodes r and r+1, wrapping to node 0 for the last one.
    // Written without `nb_face_nodes - 1` so an empty node list cannot underflow.
    for (size_t r = 0; r < nb_face_nodes; ++r) {
      const size_t r_next = (r + 1 < nb_face_nodes) ? r + 1 : 0;
      Edge e({face_nodes[r], face_nodes[r_next]}, node_number_vector);
      edge_faces_map[e].emplace_back(std::make_tuple(l, static_cast<unsigned short>(r), e.reversed()));
    }
  }

  std::unordered_map<Edge, EdgeId, typename Edge::Hash> edge_id_map;
  {
    descriptor.face_to_edge_vector.resize(descriptor.face_to_node_vector.size());
    for (FaceId l = 0; l < descriptor.face_to_node_vector.size(); ++l) {
      descriptor.face_to_edge_vector[l].resize(face_nb_edges[l]);
    }
    // Edges are numbered following the ordered map's key order.
    EdgeId e = 0;
    for (const auto& edge_faces_vector : edge_faces_map) {
      const auto& faces_vector = edge_faces_vector.second;
      for (unsigned short l = 0; l < faces_vector.size(); ++l) {
        const auto& [face_number, face_local_edge, reversed]         = faces_vector[l];
        descriptor.face_to_edge_vector[face_number][face_local_edge] = e;
      }
      edge_id_map[edge_faces_vector.first] = e;
      ++e;
    }
  }

  {
    descriptor.face_edge_is_reversed_vector.resize(descriptor.face_to_node_vector.size());
    for (FaceId j = 0; j < descriptor.face_edge_is_reversed_vector.size(); ++j) {
      descriptor.face_edge_is_reversed_vector[j] = Array<bool>(face_nb_edges[j]);
    }
    for (const auto& edge_faces_vector : edge_faces_map) {
      const auto& faces_vector = edge_faces_vector.second;
      for (unsigned short lj = 0; lj < faces_vector.size(); ++lj) {
        const auto& [face_number, face_local_edge, reversed]                  = faces_vector[lj];
        descriptor.face_edge_is_reversed_vector[face_number][face_local_edge] = reversed;
      }
    }
  }

  {
    descriptor.edge_to_node_vector.resize(edge_faces_map.size());
    int e = 0;
    for (const auto& edge_info : edge_faces_map) {
      const Edge& edge                  = edge_info.first;
      descriptor.edge_to_node_vector[e] = edge.nodeIdList();
      ++e;
    }
  }

  {
    // Edge numbers may change if numbers are provided in the file
    descriptor.edge_number_vector.resize(edge_faces_map.size());
    for (size_t e = 0; e < edge_faces_map.size(); ++e) {
      descriptor.edge_number_vector[e] = e;
    }
  }

  {
    // cell_to_edge_vector is filled by emplace_back in CellId order, so it has
    // to start empty for its indices to match cell ids. Reserve on the vector
    // that is actually being filled.
    descriptor.cell_to_edge_vector.clear();
    descriptor.cell_to_edge_vector.reserve(descriptor.cell_to_node_vector.size());

    // Returns the EdgeId of the edge joining two nodes. Every cell edge is
    // expected to lie on one of the faces processed above.
    auto find_edge_id = [&](const auto node_0, const auto node_1) {
      Edge edge{{node_0, node_1}, node_number_vector};
      auto i_edge = edge_id_map.find(edge);
      if (i_edge == edge_id_map.end()) {
        throw NormalError("could not find this edge");
      }
      return i_edge->second;
    };

    for (CellId j = 0; j < descriptor.cell_to_node_vector.size(); ++j) {
      const auto& cell_nodes = descriptor.cell_to_node_vector[j];

      switch (descriptor.cell_type_vector[j]) {
      case CellType::Tetrahedron: {
        constexpr int local_edge[6][2] = {{0, 1}, {0, 2}, {0, 3}, {1, 2}, {2, 3}, {3, 1}};
        std::vector<unsigned int> cell_edge_vector;
        cell_edge_vector.reserve(6);
        for (const auto& e : local_edge) {
          cell_edge_vector.push_back(find_edge_id(cell_nodes[e[0]], cell_nodes[e[1]]));
        }
        descriptor.cell_to_edge_vector.emplace_back(cell_edge_vector);
        break;
      }
      case CellType::Hexahedron: {
        constexpr int local_edge[12][2] = {{0, 1}, {1, 2}, {2, 3}, {3, 0}, {4, 5}, {5, 6},
                                           {6, 7}, {7, 4}, {0, 4}, {1, 5}, {2, 6}, {3, 7}};
        std::vector<unsigned int> cell_edge_vector;
        cell_edge_vector.reserve(12);
        for (const auto& e : local_edge) {
          cell_edge_vector.push_back(find_edge_id(cell_nodes[e[0]], cell_nodes[e[1]]));
        }
        descriptor.cell_to_edge_vector.emplace_back(cell_edge_vector);
        break;
      }
      default: {
        std::ostringstream error_msg;
        error_msg << name(descriptor.cell_type_vector[j]) << ": unexpected cell type in dimension 3";
        throw UnexpectedError(error_msg.str());
      }
      }
    }
  }
}

// Explicit instantiations: the member templates above are defined in this
// translation unit only, so every dimension used elsewhere is emitted here.
template void MeshBuilderBase::_dispatch<1>();
template void MeshBuilderBase::_dispatch<2>();
template void MeshBuilderBase::_dispatch<3>();

template void MeshBuilderBase::_computeCellFaceAndFaceNodeConnectivities<2>(ConnectivityDescriptor&);
template void MeshBuilderBase::_computeCellFaceAndFaceNodeConnectivities<3>(ConnectivityDescriptor&);

template void MeshBuilderBase::_computeFaceEdgeAndEdgeNodeAndCellEdgeConnectivities<3>(ConnectivityDescriptor&);
