#include <mesh/MedianDualConnectivityBuilder.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ConnectivityDescriptor.hpp>
#include <mesh/ConnectivityDispatcher.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/PrimalToMedianDualConnectivityDataMapper.hpp>
#include <mesh/RefId.hpp>
#include <utils/Array.hpp>
#include <utils/Messenger.hpp>
#include <utils/Stringify.hpp>

#include <vector>

// Fills dual_descriptor with the 2D median dual of primal_connectivity (node
// numbers, cell numbers, cell types and cell-to-node lists) and records the
// primal-to-dual item correspondences in the m_primal_*_to_dual_* members.
//
// Dual item layout produced here:
//  - dual cells: one per primal node; dual cell id == primal node id;
//  - dual nodes: [primal cells | primal faces | primal boundary nodes], i.e. one
//    dual node per primal cell, one per primal face, then one per node lying on
//    an owned boundary face (a face with exactly one adjacent cell).
//
// Dual node numbers are made distinct by shifting face numbers past the largest
// cell number, and boundary-node numbers past the largest shifted face number.
template <>
void
MedianDualConnectivityBuilder::_buildConnectivityDescriptor<2>(const Connectivity<2>& primal_connectivity,
                                                               ConnectivityDescriptor& dual_descriptor)
{
  const size_t primal_number_of_nodes = primal_connectivity.numberOfNodes();
  const size_t primal_number_of_faces = primal_connectivity.numberOfFaces();
  const size_t primal_number_of_cells = primal_connectivity.numberOfCells();

  const auto& primal_node_number = primal_connectivity.nodeNumber();
  const auto& primal_face_number = primal_connectivity.faceNumber();
  const auto& primal_cell_number = primal_connectivity.cellNumber();

  // Boundary faces are the owned faces that have a single adjacent cell
  const size_t primal_number_of_boundary_faces = [&] {
    size_t number_of_boundary_faces        = 0;
    const auto& primal_face_is_owned       = primal_connectivity.faceIsOwned();
    const auto& primal_face_to_cell_matrix = primal_connectivity.faceToCellMatrix();

    parallel_reduce(
      primal_number_of_faces,
      PUGS_LAMBDA(const FaceId face_id, size_t& number_of_boundary_faces) {
        number_of_boundary_faces +=
          (primal_face_is_owned[face_id] and (primal_face_to_cell_matrix[face_id].size() == 1));
      },
      number_of_boundary_faces);
    return number_of_boundary_faces;
  }();

  // Assumes the 2D boundary is made of closed polylines, so it has as many nodes
  // as faces (checked by the Asserts in the boundary-node loop below)
  const size_t primal_number_of_boundary_nodes = primal_number_of_boundary_faces;

  const size_t dual_number_of_nodes = primal_number_of_cells + primal_number_of_faces + primal_number_of_boundary_nodes;
  const size_t dual_number_of_cells = primal_number_of_nodes;

  // One dual cell per primal node, with identical ids
  {
    m_primal_node_to_dual_cell_map = NodeIdToCellIdMap{primal_number_of_nodes};
    CellId dual_cell_id            = 0;
    for (NodeId primal_node_id = 0; primal_node_id < primal_number_of_nodes; ++primal_node_id) {
      m_primal_node_to_dual_cell_map[primal_node_id] = std::make_pair(primal_node_id, dual_cell_id++);
    }
  }

  // Primal node id -> dual node id; max() marks nodes that have no dual node
  // (i.e. nodes that are not on an owned boundary face)
  NodeValue<NodeId> node_to_dual_node_correpondance{primal_connectivity};
  node_to_dual_node_correpondance.fill(std::numeric_limits<NodeId>::max());

  {
    NodeId dual_node_id = 0;

    // Dual nodes [0, #cells): primal cells
    m_primal_cell_to_dual_node_map = CellIdToNodeIdMap{primal_number_of_cells};
    for (CellId primal_cell_id = 0; primal_cell_id < primal_number_of_cells; ++primal_cell_id) {
      m_primal_cell_to_dual_node_map[primal_cell_id] = std::make_pair(primal_cell_id, dual_node_id++);
    }

    // Dual nodes [#cells, #cells + #faces): primal faces
    m_primal_face_to_dual_node_map = FaceIdToNodeIdMap{primal_number_of_faces};
    for (FaceId primal_face_id = 0; primal_face_id < primal_number_of_faces; ++primal_face_id) {
      m_primal_face_to_dual_node_map[primal_face_id] = std::make_pair(primal_face_id, dual_node_id++);
    }

    const auto& primal_face_is_owned       = primal_connectivity.faceIsOwned();
    const auto& primal_face_to_cell_matrix = primal_connectivity.faceToCellMatrix();
    const auto& primal_face_to_node_matrix = primal_connectivity.faceToNodeMatrix();

    // Remaining dual nodes: nodes of owned boundary faces, each counted once
    m_primal_boundary_node_to_dual_node_map = NodeIdToNodeIdMap{primal_number_of_boundary_nodes};
    m_primal_boundary_node_to_dual_node_map.fill(
      std::make_pair(std::numeric_limits<NodeId>::max(), std::numeric_limits<NodeId>::max()));
    size_t i_boundary_node = 0;
    for (FaceId primal_face_id = 0; primal_face_id < primal_number_of_faces; ++primal_face_id) {
      if (primal_face_is_owned[primal_face_id] and (primal_face_to_cell_matrix[primal_face_id].size() == 1)) {
        const auto& primal_face_to_node_list = primal_face_to_node_matrix[primal_face_id];
        for (size_t i_face_node = 0; i_face_node < primal_face_to_node_list.size(); ++i_face_node) {
          const NodeId node_id = primal_face_to_node_list[i_face_node];
          if (node_to_dual_node_correpondance[node_id] == std::numeric_limits<NodeId>::max()) {
            // check before writing: the array was sized on the closed-boundary assumption
            Assert(i_boundary_node < primal_number_of_boundary_nodes);
            node_to_dual_node_correpondance[node_id]                   = dual_node_id;
            m_primal_boundary_node_to_dual_node_map[i_boundary_node++] = std::make_pair(node_id, dual_node_id++);
          }
        }
      }
    }
    Assert(i_boundary_node == primal_number_of_boundary_nodes);
  }

  // Dual node numbers: cell numbers, then shifted face numbers, then shifted node numbers
  dual_descriptor.node_number_vector.resize(dual_number_of_nodes);
  {
    parallel_for(m_primal_cell_to_dual_node_map.size(), [&](size_t i) {
      const auto [primal_cell_id, dual_node_id]        = m_primal_cell_to_dual_node_map[i];
      dual_descriptor.node_number_vector[dual_node_id] = primal_cell_number[primal_cell_id];
    });

    const size_t face_number_shift = max(primal_cell_number) + 1;
    parallel_for(primal_number_of_faces, [&](size_t i) {
      const auto [primal_face_id, dual_node_id]        = m_primal_face_to_dual_node_map[i];
      dual_descriptor.node_number_vector[dual_node_id] = primal_face_number[primal_face_id] + face_number_shift;
    });

    const size_t node_number_shift = face_number_shift + max(primal_face_number) + 1;
    parallel_for(m_primal_boundary_node_to_dual_node_map.size(), [&](size_t i) {
      const auto [primal_node_id, dual_node_id]        = m_primal_boundary_node_to_dual_node_map[i];
      dual_descriptor.node_number_vector[dual_node_id] = primal_node_number[primal_node_id] + node_number_shift;
    });
  }

  // Dual cells inherit the numbers of their primal nodes
  dual_descriptor.cell_number_vector.resize(dual_number_of_cells);
  parallel_for(dual_number_of_cells, [&](size_t i) {
    const auto [primal_node_id, dual_cell_id]        = m_primal_node_to_dual_cell_map[i];
    dual_descriptor.cell_number_vector[dual_cell_id] = primal_node_number[primal_node_id];
  });

  dual_descriptor.cell_type_vector.resize(dual_number_of_cells);

  const auto& primal_node_to_cell_matrix = primal_connectivity.nodeToCellMatrix();

  // A node belonging to a single cell yields a 4-node dual cell
  // (2 face midpoints + 1 cell center + the node itself)
  parallel_for(primal_number_of_nodes, [&](NodeId node_id) {
    const size_t i_dual_cell          = node_id;
    const auto& primal_node_cell_list = primal_node_to_cell_matrix[node_id];

    if (primal_node_cell_list.size() == 1) {
      dual_descriptor.cell_type_vector[i_dual_cell] = CellType::Quadrangle;
    } else {
      dual_descriptor.cell_type_vector[i_dual_cell] = CellType::Polygon;
    }
  });

  dual_descriptor.cell_to_node_vector.resize(dual_number_of_cells);
  const auto& primal_cell_to_face_matrix               = primal_connectivity.cellToFaceMatrix();
  const auto& primal_node_to_face_matrix               = primal_connectivity.nodeToFaceMatrix();
  const auto& primal_face_to_cell_matrix               = primal_connectivity.faceToCellMatrix();
  const auto& primal_face_to_node_matrix               = primal_connectivity.faceToNodeMatrix();
  const auto& primal_cell_face_is_reversed             = primal_connectivity.cellFaceIsReversed();
  const auto& primal_face_local_numbers_in_their_cells = primal_connectivity.faceLocalNumbersInTheirCells();

  // Returns the first face of cell_id, other than face_id, that contains node_id
  auto next_face = [&](const CellId cell_id, const FaceId face_id, const NodeId node_id) -> FaceId {
    const auto& primal_cell_to_face_list = primal_cell_to_face_matrix[cell_id];
    for (size_t i_face = 0; i_face < primal_cell_to_face_list.size(); ++i_face) {
      const FaceId cell_face_id = primal_cell_to_face_list[i_face];
      if (cell_face_id != face_id) {
        const auto& face_node_list = primal_face_to_node_matrix[cell_face_id];
        if ((face_node_list[0] == node_id) or (face_node_list[1] == node_id)) {
          return cell_face_id;
        }
      }
    }
    // LCOV_EXCL_START
    throw UnexpectedError("could not find next face");
    // LCOV_EXCL_STOP
  };

  // Returns the cell on the other side of face_id with respect to cell_id
  auto next_cell = [&](const CellId cell_id, const FaceId face_id) -> CellId {
    const auto& primal_face_to_cell_list = primal_face_to_cell_matrix[face_id];
    for (size_t i_cell = 0; i_cell < primal_face_to_cell_list.size(); ++i_cell) {
      const CellId face_cell_id = primal_face_to_cell_list[i_cell];
      if (face_cell_id != cell_id) {
        return face_cell_id;
      }
    }
    // LCOV_EXCL_START
    throw UnexpectedError("could not find next cell");
    // LCOV_EXCL_STOP
  };

  // Each dual cell's node list walks around its primal node, alternating face
  // and cell dual nodes
  parallel_for(primal_number_of_nodes, [&](NodeId node_id) {
    const size_t i_dual_cell             = node_id;
    const auto& primal_node_to_cell_list = primal_node_to_cell_matrix[node_id];
    const auto& primal_node_to_face_list = primal_node_to_face_matrix[node_id];

    auto& dual_cell_node_list = dual_descriptor.cell_to_node_vector[i_dual_cell];

    if (primal_node_to_cell_list.size() != primal_node_to_face_list.size()) {
      // boundary cell: open fan of faces/cells closed by the node itself
      dual_cell_node_list.reserve(1 + primal_node_to_cell_list.size() + primal_node_to_face_list.size());

      // Start on the boundary face whose first node, taken in its cell's
      // orientation, is node_id, so the walk direction is consistent
      auto [face_id, cell_id] = [&]() -> std::pair<FaceId, CellId> {
        for (size_t i_face = 0; i_face < primal_node_to_face_list.size(); ++i_face) {
          const FaceId face_id = primal_node_to_face_list[i_face];
          if (primal_face_to_cell_matrix[face_id].size() > 1) {
            continue;
          }

          const CellId cell_id        = primal_face_to_cell_matrix[face_id][0];
          const size_t i_face_in_cell = primal_face_local_numbers_in_their_cells(face_id, 0);

          if (primal_face_to_node_matrix[face_id][primal_cell_face_is_reversed(cell_id, i_face_in_cell)] == node_id) {
            return std::make_pair(face_id, cell_id);
          }
        }
        // LCOV_EXCL_START
        throw UnexpectedError("cannot find first face");
        // LCOV_EXCL_STOP
      }();

      dual_cell_node_list.push_back(m_primal_face_to_dual_node_map[face_id].second);
      dual_cell_node_list.push_back(m_primal_cell_to_dual_node_map[cell_id].second);

      face_id = next_face(cell_id, face_id, node_id);

      // Cross interior faces until the other boundary face is reached
      while (primal_face_to_cell_matrix[face_id].size() > 1) {
        dual_cell_node_list.push_back(m_primal_face_to_dual_node_map[face_id].second);
        cell_id = next_cell(cell_id, face_id);
        dual_cell_node_list.push_back(m_primal_cell_to_dual_node_map[cell_id].second);
        face_id = next_face(cell_id, face_id, node_id);
      }
      dual_cell_node_list.push_back(m_primal_face_to_dual_node_map[face_id].second);
      dual_cell_node_list.push_back(node_to_dual_node_correpondance[node_id]);

      Assert(dual_cell_node_list.size() == 1 + primal_node_to_cell_list.size() + primal_node_to_face_list.size());
    } else {
      // inner cell: closed fan of faces/cells around the node
      dual_cell_node_list.reserve(primal_node_to_cell_list.size() + primal_node_to_face_list.size());

      // Start on the node's first face, in the adjacent cell for which node_id is
      // the face's first node in that cell's orientation
      auto [face_id, cell_id] = [&]() -> std::pair<FaceId, CellId> {
        const FaceId face_id = primal_node_to_face_list[0];

        for (size_t i_face_cell = 0; i_face_cell < primal_face_to_cell_matrix[face_id].size(); ++i_face_cell) {
          const CellId cell_id        = primal_face_to_cell_matrix[face_id][i_face_cell];
          const size_t i_face_in_cell = primal_face_local_numbers_in_their_cells(face_id, i_face_cell);

          if (primal_face_to_node_matrix[face_id][primal_cell_face_is_reversed(cell_id, i_face_in_cell)] == node_id) {
            return std::make_pair(face_id, cell_id);
          }
        }
        // LCOV_EXCL_START
        throw UnexpectedError("could not find first face/cell couple");
        // LCOV_EXCL_STOP
      }();

      const FaceId first_face_id = face_id;
      do {
        dual_cell_node_list.push_back(m_primal_face_to_dual_node_map[face_id].second);
        dual_cell_node_list.push_back(m_primal_cell_to_dual_node_map[cell_id].second);

        face_id = next_face(cell_id, face_id, node_id);
        cell_id = next_cell(cell_id, face_id);
      } while (face_id != first_face_id);
    }
  });
}

// Builds the 2D median dual connectivity of i_primal_connectivity (which must be
// a Connectivity<2>), stores it in m_connectivity and creates the associated
// primal-to-dual data mapper in m_mapper.
//
// Dual items are laid out by _buildConnectivityDescriptor<2> as follows:
//  - dual cells: one per primal node;
//  - dual nodes: one per primal cell, one per primal face, one per primal
//    boundary node (positions given by the m_primal_*_to_dual_node_map members).
// Reference node lists keep only the primal boundary nodes, since only those have
// a dual node. Reference face lists are mapped to the dual faces that join a
// referenced face's dual node to a boundary-node dual node.
template <>
void
MedianDualConnectivityBuilder::_buildConnectivityFrom<2>(const IConnectivity& i_primal_connectivity)
{
  using ConnectivityType = Connectivity<2>;

  const ConnectivityType& primal_connectivity = dynamic_cast<const ConnectivityType&>(i_primal_connectivity);

  ConnectivityDescriptor dual_descriptor;

  this->_buildConnectivityDescriptor(primal_connectivity, dual_descriptor);

  ConnectivityBuilderBase::_computeCellFaceAndFaceNodeConnectivities<2>(dual_descriptor);

  // Reference node lists
  {
    // primal node id -> dual node id, for primal boundary nodes only
    const std::unordered_map<unsigned int, NodeId> primal_boundary_node_id_to_dual_node_id_map = [&] {
      std::unordered_map<unsigned int, NodeId> node_to_id_map;
      for (size_t i_node = 0; i_node < m_primal_boundary_node_to_dual_node_map.size(); ++i_node) {
        auto [primal_node_id, dual_node_id] = m_primal_boundary_node_to_dual_node_map[i_node];
        node_to_id_map[primal_node_id]      = dual_node_id;
      }
      return node_to_id_map;
    }();

    for (size_t i_node_list = 0; i_node_list < primal_connectivity.template numberOfRefItemList<ItemType::node>();
         ++i_node_list) {
      const auto& primal_ref_node_list = primal_connectivity.template refItemList<ItemType::node>(i_node_list);
      const auto& primal_node_list     = primal_ref_node_list.list();

      const std::vector<NodeId> dual_node_list = [&]() {
        std::vector<NodeId> dual_node_list;

        for (size_t i_primal_node = 0; i_primal_node < primal_node_list.size(); ++i_primal_node) {
          auto primal_node_id = primal_node_list[i_primal_node];
          // The map is keyed by primal node *id* (not node number)
          const auto i_dual_node = primal_boundary_node_id_to_dual_node_id_map.find(primal_node_id);
          if (i_dual_node != primal_boundary_node_id_to_dual_node_id_map.end()) {
            dual_node_list.push_back(i_dual_node->second);
          }
        }

        return dual_node_list;
      }();

      if (parallel::allReduceOr(dual_node_list.size() > 0)) {
        dual_descriptor.addRefItemList(RefNodeList{primal_ref_node_list.refId(), convert_to_array(dual_node_list),
                                                   primal_ref_node_list.isBoundary()});
      }
    }
  }

  // Reference face lists
  {
    // Flags the dual nodes stemming from primal boundary nodes; this does not
    // depend on the reference list, so it is computed once
    const std::vector<bool> is_dual_node_from_boundary_node = [&] {
      std::vector<bool> is_from_boundary_node(dual_descriptor.node_number_vector.size(), false);
      for (size_t i_node = 0; i_node < m_primal_boundary_node_to_dual_node_map.size(); ++i_node) {
        is_from_boundary_node[m_primal_boundary_node_to_dual_node_map[i_node].second] = true;
      }
      return is_from_boundary_node;
    }();

    for (size_t i_face_list = 0; i_face_list < primal_connectivity.template numberOfRefItemList<ItemType::face>();
         ++i_face_list) {
      const auto& primal_ref_face_list = primal_connectivity.template refItemList<ItemType::face>(i_face_list);
      const auto& primal_face_list     = primal_ref_face_list.list();

      const std::vector<FaceId> boundary_dual_face_id_list = [&]() {
        // Flags the dual nodes of the faces of this reference list
        std::vector<bool> is_dual_node_from_boundary_face(dual_descriptor.node_number_vector.size(), false);
        for (size_t i_face = 0; i_face < primal_face_list.size(); ++i_face) {
          is_dual_node_from_boundary_face[m_primal_face_to_dual_node_map[primal_face_list[i_face]].second] = true;
        }

        // Keep the dual faces joining a referenced face node to a boundary node
        std::vector<FaceId> dual_face_list;
        dual_face_list.reserve(2 * primal_face_list.size());
        for (size_t i_dual_face = 0; i_dual_face < dual_descriptor.face_to_node_vector.size(); ++i_dual_face) {
          const NodeId dual_node_0 = dual_descriptor.face_to_node_vector[i_dual_face][0];
          const NodeId dual_node_1 = dual_descriptor.face_to_node_vector[i_dual_face][1];

          if ((is_dual_node_from_boundary_face[dual_node_0] and is_dual_node_from_boundary_node[dual_node_1]) or
              (is_dual_node_from_boundary_node[dual_node_0] and is_dual_node_from_boundary_face[dual_node_1])) {
            dual_face_list.push_back(i_dual_face);
          }
        }
        return dual_face_list;
      }();

      if (parallel::allReduceOr(boundary_dual_face_id_list.size() > 0)) {
        dual_descriptor.addRefItemList(RefFaceList{primal_ref_face_list.refId(),
                                                   convert_to_array(boundary_dual_face_id_list),
                                                   primal_ref_face_list.isBoundary()});
      }
    }
  }

  const auto& primal_node_owner = primal_connectivity.nodeOwner();

  // Dual node owners: each dual node inherits the owner of the primal cell, face
  // or boundary node it was created from (positions given by the maps, which
  // follow the [cells | faces | boundary nodes] layout)
  dual_descriptor.node_owner_vector.resize(dual_descriptor.node_number_vector.size());
  {
    const auto& primal_cell_owner = primal_connectivity.cellOwner();
    for (size_t i = 0; i < m_primal_cell_to_dual_node_map.size(); ++i) {
      const auto [primal_cell_id, dual_node_id]       = m_primal_cell_to_dual_node_map[i];
      dual_descriptor.node_owner_vector[dual_node_id] = primal_cell_owner[primal_cell_id];
    }

    const auto& primal_face_owner = primal_connectivity.faceOwner();
    for (size_t i = 0; i < m_primal_face_to_dual_node_map.size(); ++i) {
      const auto [primal_face_id, dual_node_id]       = m_primal_face_to_dual_node_map[i];
      dual_descriptor.node_owner_vector[dual_node_id] = primal_face_owner[primal_face_id];
    }

    for (size_t i = 0; i < m_primal_boundary_node_to_dual_node_map.size(); ++i) {
      const auto [primal_node_id, dual_node_id]       = m_primal_boundary_node_to_dual_node_map[i];
      dual_descriptor.node_owner_vector[dual_node_id] = primal_node_owner[primal_node_id];
    }
  }

  // Dual cell owners: each dual cell inherits the owner of its primal node
  dual_descriptor.cell_owner_vector.resize(dual_descriptor.cell_number_vector.size());
  for (size_t i = 0; i < m_primal_node_to_dual_cell_map.size(); ++i) {
    const auto [primal_node_id, dual_cell_id]       = m_primal_node_to_dual_cell_map[i];
    dual_descriptor.cell_owner_vector[dual_cell_id] = primal_node_owner[primal_node_id];
  }

  // Dual face owners: a face is owned by the owner of its adjacent dual cell
  // with the smallest cell number
  {
    const size_t number_of_dual_faces = dual_descriptor.face_number_vector.size();
    const size_t no_cell              = std::numeric_limits<size_t>::max();

    // index (not number) of the adjacent cell having the smallest number
    std::vector<size_t> face_min_number_cell(number_of_dual_faces, no_cell);

    for (size_t i_cell = 0; i_cell < dual_descriptor.cell_to_face_vector.size(); ++i_cell) {
      const auto& cell_face_list = dual_descriptor.cell_to_face_vector[i_cell];
      for (size_t i_face = 0; i_face < cell_face_list.size(); ++i_face) {
        const size_t face_id = cell_face_list[i_face];
        size_t& face_cell    = face_min_number_cell[face_id];
        if ((face_cell == no_cell) or
            (dual_descriptor.cell_number_vector[i_cell] < dual_descriptor.cell_number_vector[face_cell])) {
          face_cell = i_cell;
        }
      }
    }

    dual_descriptor.face_owner_vector.resize(number_of_dual_faces);
    for (size_t i_face = 0; i_face < number_of_dual_faces; ++i_face) {
      Assert(face_min_number_cell[i_face] != no_cell);
      dual_descriptor.face_owner_vector[i_face] = dual_descriptor.cell_owner_vector[face_min_number_cell[i_face]];
    }
  }

  m_connectivity = ConnectivityType::build(dual_descriptor);

  const ConnectivityType& dual_connectivity = dynamic_cast<const ConnectivityType&>(*m_connectivity);

  m_mapper = std::make_shared<PrimalToMedianDualConnectivityDataMapper<2>>(primal_connectivity, dual_connectivity,
                                                                           m_primal_boundary_node_to_dual_node_map,
                                                                           m_primal_face_to_dual_node_map,
                                                                           m_primal_cell_to_dual_node_map,
                                                                           m_primal_node_to_dual_cell_map);
}

// Builds the median dual of a connectivity.
// Only sequential runs and 2D connectivities are supported: a parallel run or a
// 3D connectivity raises NotImplementedError, and any other dimension raises
// UnexpectedError.
MedianDualConnectivityBuilder::MedianDualConnectivityBuilder(const IConnectivity& connectivity)
{
  // LCOV_EXCL_START
  if (parallel::size() > 1) {
    throw NotImplementedError("Construction of median dual mesh is not implemented in parallel");
  }
  // LCOV_EXCL_STOP

  const auto dimension = connectivity.dimension();
  if (dimension == 2) {
    this->_buildConnectivityFrom<2>(connectivity);
  } else if (dimension == 3) {
    throw NotImplementedError("median dual connectivity");
  } else {
    // LCOV_EXCL_START
    throw UnexpectedError("invalid connectivity dimension: " + stringify(dimension));
    // LCOV_EXCL_STOP
  }
}
