#include <mesh/DiamondDualMeshBuilder.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ConnectivityDescriptor.hpp>
#include <mesh/ConnectivityDispatcher.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/RefId.hpp>
#include <utils/Array.hpp>
#include <utils/Messenger.hpp>

/// Fill @p diamond_descriptor with the node/cell numbering, the cell types and the
/// cell-to-node connectivity of the diamond dual of @p primal_connectivity.
///
/// Diamond nodes are the primal nodes (same ids, same numbers) followed by one node
/// per primal cell (diamond node id = primal_number_of_nodes + cell_id).
/// Diamond cells are in one-to-one correspondence with primal faces (diamond cell id
/// = primal face id):
///  - a face with a single neighbouring cell gives a Triangle (2D) or Pyramid (3D)
///    made of the face nodes plus that cell's node;
///  - a face with two neighbouring cells gives a Quadrangle (2D) or Diamond (3D)
///    made of the face nodes plus both cell nodes.
///
/// @throws NotImplementedError in dimension 1, and in 3D when run in parallel
///         (edge numbering is not defined in parallel).
template <size_t Dimension>
void
DiamondDualMeshBuilder::_buildDiamondConnectivityDescriptor(const Connectivity<Dimension>& primal_connectivity,
                                                            ConnectivityDescriptor& diamond_descriptor)
{
  const size_t primal_number_of_nodes = primal_connectivity.numberOfNodes();
  const size_t primal_number_of_faces = primal_connectivity.numberOfFaces();
  const size_t primal_number_of_cells = primal_connectivity.numberOfCells();

  const size_t diamond_number_of_nodes = primal_number_of_cells + primal_number_of_nodes;

  diamond_descriptor.node_number_vector.resize(diamond_number_of_nodes);

  const auto& primal_node_number = primal_connectivity.nodeNumber();

  for (NodeId primal_node_id = 0; primal_node_id < primal_number_of_nodes; ++primal_node_id) {
    diamond_descriptor.node_number_vector[primal_node_id] = primal_node_number[primal_node_id];
  }

  const auto& primal_cell_number = primal_connectivity.cellNumber();

  // Cell nodes are numbered after the largest primal node number so that node numbers
  // stay unique (they are later used as lookup keys). The "+ 1" matters: shifting by
  // the maximum alone would give the cell numbered 0 (assuming non-negative cell
  // numbers) the same number as the primal node holding that maximum.
  const size_t cell_number_shift = max(primal_node_number) + 1;

  for (CellId primal_cell_id = 0; primal_cell_id < primal_number_of_cells; ++primal_cell_id) {
    diamond_descriptor.node_number_vector[primal_number_of_nodes + primal_cell_id] =
      primal_cell_number[primal_cell_id] + cell_number_shift;
  }

  // One diamond cell per primal face, numbered like that face.
  const size_t diamond_number_of_cells = primal_number_of_faces;
  diamond_descriptor.cell_number_vector.resize(diamond_number_of_cells);

  const auto& primal_face_number = primal_connectivity.faceNumber();

  for (FaceId i_primal_face = 0; i_primal_face < primal_number_of_faces; ++i_primal_face) {
    diamond_descriptor.cell_number_vector[i_primal_face] = primal_face_number[i_primal_face];
  }

  if constexpr (Dimension == 3) {
    // NOTE(review): the visible caller passes a freshly constructed descriptor and only
    // computes edges afterwards, so edge_to_node_vector is presumably still empty here
    // and this numbering is a no-op — confirm. The parallel guard still applies.
    const size_t number_of_edges = diamond_descriptor.edge_to_node_vector.size();
    diamond_descriptor.edge_number_vector.resize(number_of_edges);
    for (size_t i_edge = 0; i_edge < number_of_edges; ++i_edge) {
      diamond_descriptor.edge_number_vector[i_edge] = i_edge;
    }
    if (parallel::size() > 1) {
      throw NotImplementedError("parallel edge numbering is undefined");
    }
  }

  diamond_descriptor.cell_type_vector.resize(diamond_number_of_cells);

  const auto& primal_face_to_cell_matrix = primal_connectivity.faceToCellMatrix();

  for (FaceId i_face = 0; i_face < primal_number_of_faces; ++i_face) {
    const size_t i_cell               = i_face;
    const auto& primal_face_cell_list = primal_face_to_cell_matrix[i_face];

    if constexpr (Dimension == 1) {
      throw NotImplementedError("dimension 1");
    } else if constexpr (Dimension == 2) {
      if (primal_face_cell_list.size() == 1) {
        diamond_descriptor.cell_type_vector[i_cell] = CellType::Triangle;
      } else {
        Assert(primal_face_cell_list.size() == 2);
        diamond_descriptor.cell_type_vector[i_cell] = CellType::Quadrangle;
      }
    } else {
      static_assert(Dimension == 3, "unexpected dimension");

      if (primal_face_cell_list.size() == 1) {
        diamond_descriptor.cell_type_vector[i_cell] = CellType::Pyramid;
      } else {
        Assert(primal_face_cell_list.size() == 2);
        diamond_descriptor.cell_type_vector[i_cell] = CellType::Diamond;
      }
    }
  }

  diamond_descriptor.cell_to_node_vector.resize(diamond_number_of_cells);

  const auto& primal_face_to_node_matrix              = primal_connectivity.faceToNodeMatrix();
  const auto& primal_face_local_number_in_their_cells = primal_connectivity.faceLocalNumbersInTheirCells();
  const auto& cell_face_is_reversed                   = primal_connectivity.cellFaceIsReversed();
  for (FaceId i_face = 0; i_face < primal_number_of_faces; ++i_face) {
    const size_t i_diamond_cell       = i_face;
    const auto& primal_face_cell_list = primal_face_to_cell_matrix[i_face];
    const auto& primal_face_node_list = primal_face_to_node_matrix[i_face];
    const size_t face_node_count      = primal_face_node_list.size();

    auto& diamond_cell_node_list = diamond_descriptor.cell_to_node_vector[i_diamond_cell];

    if (primal_face_cell_list.size() == 1) {
      // Single neighbouring cell: face nodes followed by the cell node.
      diamond_cell_node_list.resize(face_node_count + 1);

      const CellId cell_id      = primal_face_cell_list[0];
      const auto i_face_in_cell = primal_face_local_number_in_their_cells(i_face, 0);

      for (size_t i_node = 0; i_node < face_node_count; ++i_node) {
        diamond_cell_node_list[i_node] = primal_face_node_list[i_node];
      }
      diamond_cell_node_list[face_node_count] = primal_number_of_nodes + cell_id;

      // When the face is reversed in its cell, reverse the face nodes (presumably so
      // the diamond cell gets a consistent orientation).
      if (cell_face_is_reversed(cell_id, i_face_in_cell)) {
        if constexpr (Dimension == 2) {
          std::swap(diamond_cell_node_list[0], diamond_cell_node_list[1]);
        } else {
          for (size_t i_node = 0; i_node < face_node_count / 2; ++i_node) {
            std::swap(diamond_cell_node_list[i_node], diamond_cell_node_list[face_node_count - 1 - i_node]);
          }
        }
      }
    } else {
      Assert(primal_face_cell_list.size() == 2);
      diamond_cell_node_list.resize(face_node_count + 2);

      const CellId cell0_id     = primal_face_cell_list[0];
      const CellId cell1_id     = primal_face_cell_list[1];
      const auto i_face_in_cell = primal_face_local_number_in_their_cells(i_face, 0);

      if constexpr (Dimension == 2) {
        // Quadrangle: cell0 node, face node 0, cell1 node, face node 1.
        Assert(face_node_count == 2);
        diamond_cell_node_list[0] = primal_number_of_nodes + cell0_id;
        diamond_cell_node_list[1] = primal_face_node_list[0];
        diamond_cell_node_list[2] = primal_number_of_nodes + cell1_id;
        diamond_cell_node_list[3] = primal_face_node_list[1];

        if (cell_face_is_reversed(cell0_id, i_face_in_cell)) {
          std::swap(diamond_cell_node_list[1], diamond_cell_node_list[3]);
        }
      } else {
        // Diamond: cell0 node, the face nodes, then cell1 node.
        diamond_cell_node_list[0] = primal_number_of_nodes + cell0_id;
        for (size_t i_node = 0; i_node < face_node_count; ++i_node) {
          diamond_cell_node_list[i_node + 1] = primal_face_node_list[i_node];
        }
        diamond_cell_node_list[face_node_count + 1] = primal_number_of_nodes + cell1_id;

        if (cell_face_is_reversed(cell0_id, i_face_in_cell)) {
          std::swap(diamond_cell_node_list[0], diamond_cell_node_list[face_node_count + 1]);
        }
      }
    }
  }
}

/// Build the diamond dual mesh of @p p_mesh and store it in m_mesh.
///
///  1. Describe diamond nodes/cells (_buildDiamondConnectivityDescriptor), then derive
///     faces (and, in 3D, edges) with MeshBuilderBase.
///  2. Transfer primal node/face/edge reference lists: nodes are matched by number,
///     faces/edges by their node lists; items with no diamond counterpart are dropped.
///  3. Set owners: primal nodes keep their owner, cell nodes take their primal cell's
///     owner, diamond cell i takes primal face i's owner, and every diamond face (and
///     edge in 3D) is owned by the smallest owner rank among the cells containing it.
///  4. Place primal nodes at their primal position and cell nodes at the primal cell
///     centers given by MeshData::xj().
///
/// @throws std::bad_cast if @p p_mesh is not a Mesh<Connectivity<Dimension>>.
template <size_t Dimension>
void
DiamondDualMeshBuilder::_buildDiamondMeshFrom(const std::shared_ptr<const IMesh>& p_mesh)
{
  static_assert(Dimension <= 3, "invalid mesh dimension");
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;

  const MeshType& primal_mesh = dynamic_cast<const MeshType&>(*p_mesh);

  const ConnectivityType& primal_connectivity = primal_mesh.connectivity();

  ConnectivityDescriptor diamond_descriptor;

  this->_buildDiamondConnectivityDescriptor(primal_connectivity, diamond_descriptor);

  MeshBuilderBase::_computeCellFaceAndFaceNodeConnectivities<Dimension>(diamond_descriptor);
  if constexpr (Dimension == 3) {
    MeshBuilderBase::_computeFaceEdgeAndEdgeNodeAndCellEdgeConnectivities<Dimension>(diamond_descriptor);
  }

  // Node reference lists: diamond nodes are looked up by their number.
  {
    const std::unordered_map<unsigned int, NodeId> node_to_id_map = [&] {
      std::unordered_map<unsigned int, NodeId> id_by_number;
      for (size_t i_node = 0; i_node < diamond_descriptor.node_number_vector.size(); ++i_node) {
        id_by_number[diamond_descriptor.node_number_vector[i_node]] = i_node;
      }
      return id_by_number;
    }();

    const auto& primal_node_number = primal_connectivity.nodeNumber();

    for (size_t i_node_list = 0; i_node_list < primal_connectivity.template numberOfRefItemList<ItemType::node>();
         ++i_node_list) {
      const auto& primal_ref_node_list = primal_connectivity.template refItemList<ItemType::node>(i_node_list);
      std::cout << "treating " << primal_ref_node_list.refId() << '\n';
      const auto& primal_node_list = primal_ref_node_list.list();

      const std::vector<NodeId> diamond_node_list = [&]() {
        std::vector<NodeId> diamond_node_list;
        diamond_node_list.reserve(primal_node_list.size());
        for (size_t i_primal_node = 0; i_primal_node < primal_node_list.size(); ++i_primal_node) {
          const auto primal_node_id = primal_node_list[i_primal_node];
          const auto i_diamond_node = node_to_id_map.find(primal_node_number[primal_node_id]);
          if (i_diamond_node != node_to_id_map.end()) {
            diamond_node_list.push_back(i_diamond_node->second);
          }
        }
        return diamond_node_list;
      }();

      if (parallel::allReduceOr(diamond_node_list.size() > 0)) {
        Array<NodeId> node_array(diamond_node_list.size());
        for (size_t i = 0; i < diamond_node_list.size(); ++i) {
          node_array[i] = diamond_node_list[i];
        }
        diamond_descriptor.addRefItemList(RefNodeList{primal_ref_node_list.refId(), node_array});
      }
    }
  }

  // Face reference lists: a primal face is kept if its nodes also form a diamond face.
  if constexpr (Dimension > 1) {
    const auto& primal_face_to_node_matrix = primal_connectivity.faceToNodeMatrix();

    using Face = ConnectivityFace<Dimension>;

    const std::unordered_map<Face, FaceId, typename Face::Hash> face_to_id_map = [&] {
      std::unordered_map<Face, FaceId, typename Face::Hash> id_by_face;
      for (FaceId l = 0; l < diamond_descriptor.face_to_node_vector.size(); ++l) {
        const auto& node_vector = diamond_descriptor.face_to_node_vector[l];

        id_by_face[Face(node_vector, diamond_descriptor.node_number_vector)] = l;
      }
      return id_by_face;
    }();

    for (size_t i_face_list = 0; i_face_list < primal_connectivity.template numberOfRefItemList<ItemType::face>();
         ++i_face_list) {
      const auto& primal_ref_face_list = primal_connectivity.template refItemList<ItemType::face>(i_face_list);
      std::cout << "treating " << primal_ref_face_list.refId() << '\n';
      const auto& primal_face_list = primal_ref_face_list.list();

      const std::vector<FaceId> diamond_face_list = [&]() {
        std::vector<FaceId> diamond_face_list;
        diamond_face_list.reserve(primal_face_list.size());
        for (size_t i = 0; i < primal_face_list.size(); ++i) {
          FaceId primal_face_id = primal_face_list[i];

          const auto& primal_face_node_list = primal_face_to_node_matrix[primal_face_id];

          const auto i_diamond_face = [&]() {
            std::vector<unsigned int> node_list(primal_face_node_list.size());
            for (size_t i_node = 0; i_node < primal_face_node_list.size(); ++i_node) {
              node_list[i_node] = primal_face_node_list[i_node];
            }
            return face_to_id_map.find(Face(node_list, diamond_descriptor.node_number_vector));
          }();

          if (i_diamond_face != face_to_id_map.end()) {
            diamond_face_list.push_back(i_diamond_face->second);
          }
        }
        return diamond_face_list;
      }();

      if (parallel::allReduceOr(diamond_face_list.size() > 0)) {
        Array<FaceId> face_array(diamond_face_list.size());
        for (size_t i = 0; i < diamond_face_list.size(); ++i) {
          face_array[i] = diamond_face_list[i];
        }
        diamond_descriptor.addRefItemList(RefFaceList{primal_ref_face_list.refId(), face_array});
      }
    }
  }

  // Edge reference lists (3D): a primal edge is kept if its nodes also form a diamond edge.
  if constexpr (Dimension > 2) {
    const auto& primal_edge_to_node_matrix = primal_connectivity.edgeToNodeMatrix();
    using Edge                             = ConnectivityFace<2>;

    const std::unordered_map<Edge, EdgeId, typename Edge::Hash> edge_to_id_map = [&] {
      std::unordered_map<Edge, EdgeId, typename Edge::Hash> id_by_edge;
      for (EdgeId l = 0; l < diamond_descriptor.edge_to_node_vector.size(); ++l) {
        const auto& node_vector = diamond_descriptor.edge_to_node_vector[l];
        id_by_edge[Edge(node_vector, diamond_descriptor.node_number_vector)] = l;
      }
      return id_by_edge;
    }();

    for (size_t i_edge_list = 0; i_edge_list < primal_connectivity.template numberOfRefItemList<ItemType::edge>();
         ++i_edge_list) {
      const auto& primal_ref_edge_list = primal_connectivity.template refItemList<ItemType::edge>(i_edge_list);
      std::cout << "treating " << primal_ref_edge_list.refId() << '\n';
      const auto& primal_edge_list = primal_ref_edge_list.list();

      const std::vector<EdgeId> diamond_edge_list = [&]() {
        std::vector<EdgeId> diamond_edge_list;
        diamond_edge_list.reserve(primal_edge_list.size());
        for (size_t i = 0; i < primal_edge_list.size(); ++i) {
          EdgeId primal_edge_id = primal_edge_list[i];

          const auto& primal_edge_node_list = primal_edge_to_node_matrix[primal_edge_id];

          const auto i_diamond_edge = [&]() {
            std::vector<unsigned int> node_list(primal_edge_node_list.size());
            for (size_t i_node = 0; i_node < primal_edge_node_list.size(); ++i_node) {
              node_list[i_node] = primal_edge_node_list[i_node];
            }
            return edge_to_id_map.find(Edge(node_list, diamond_descriptor.node_number_vector));
          }();

          if (i_diamond_edge != edge_to_id_map.end()) {
            diamond_edge_list.push_back(i_diamond_edge->second);
          }
        }
        return diamond_edge_list;
      }();

      if (parallel::allReduceOr(diamond_edge_list.size() > 0)) {
        Array<EdgeId> edge_array(diamond_edge_list.size());
        for (size_t i = 0; i < diamond_edge_list.size(); ++i) {
          edge_array[i] = diamond_edge_list[i];
        }
        diamond_descriptor.addRefItemList(RefEdgeList{primal_ref_edge_list.refId(), edge_array});
      }
    }
  }

  const size_t primal_number_of_nodes = primal_connectivity.numberOfNodes();
  const size_t primal_number_of_faces = primal_connectivity.numberOfFaces();
  const size_t primal_number_of_cells = primal_connectivity.numberOfCells();

  // Node owners: primal nodes keep theirs, cell nodes take their primal cell's owner.
  diamond_descriptor.node_owner_vector.resize(diamond_descriptor.node_number_vector.size());

  const auto& primal_node_owner = primal_connectivity.nodeOwner();
  for (NodeId primal_node_id = 0; primal_node_id < primal_number_of_nodes; ++primal_node_id) {
    diamond_descriptor.node_owner_vector[primal_node_id] = primal_node_owner[primal_node_id];
  }
  const auto& primal_cell_owner = primal_connectivity.cellOwner();
  for (CellId primal_cell_id = 0; primal_cell_id < primal_number_of_cells; ++primal_cell_id) {
    diamond_descriptor.node_owner_vector[primal_number_of_nodes + primal_cell_id] = primal_cell_owner[primal_cell_id];
  }

  // Cell owners: diamond cell i is built on primal face i, so every primal face
  // (not only the first numberOfCells() of them) hands its owner over.
  diamond_descriptor.cell_owner_vector.resize(diamond_descriptor.cell_number_vector.size());
  const auto& primal_face_owner = primal_connectivity.faceOwner();
  for (FaceId primal_face_id = 0; primal_face_id < primal_number_of_faces; ++primal_face_id) {
    diamond_descriptor.cell_owner_vector[primal_face_id] = primal_face_owner[primal_face_id];
  }

  // Owner ranks lie in [0, parallel::size()), so parallel::size() is an "unset" value
  // larger than any real owner; min-reduction over adjacent cells then picks the
  // smallest owning rank.
  const int unset_owner = static_cast<int>(parallel::size());

  // Face owners: smallest owner rank among the diamond cells containing the face.
  {
    std::vector<int> face_cell_owner(diamond_descriptor.face_number_vector.size(), unset_owner);

    for (size_t i_cell = 0; i_cell < diamond_descriptor.cell_to_face_vector.size(); ++i_cell) {
      const int cell_owner       = diamond_descriptor.cell_owner_vector[i_cell];
      const auto& cell_face_list = diamond_descriptor.cell_to_face_vector[i_cell];
      for (size_t i_face = 0; i_face < cell_face_list.size(); ++i_face) {
        const size_t face_id     = cell_face_list[i_face];
        face_cell_owner[face_id] = std::min(face_cell_owner[face_id], cell_owner);
      }
    }

    diamond_descriptor.face_owner_vector.resize(face_cell_owner.size());
    for (size_t i_face = 0; i_face < face_cell_owner.size(); ++i_face) {
      diamond_descriptor.face_owner_vector[i_face] = face_cell_owner[i_face];
    }
  }

  // Edge owners (3D): smallest owner rank among the diamond cells containing the edge.
  if constexpr (Dimension == 3) {
    std::vector<int> edge_cell_owner(diamond_descriptor.edge_number_vector.size(), unset_owner);

    for (size_t i_cell = 0; i_cell < diamond_descriptor.cell_to_edge_vector.size(); ++i_cell) {
      const int cell_owner       = diamond_descriptor.cell_owner_vector[i_cell];
      const auto& cell_edge_list = diamond_descriptor.cell_to_edge_vector[i_cell];
      for (size_t i_edge = 0; i_edge < cell_edge_list.size(); ++i_edge) {
        const size_t edge_id     = cell_edge_list[i_edge];
        edge_cell_owner[edge_id] = std::min(edge_cell_owner[edge_id], cell_owner);
      }
    }

    diamond_descriptor.edge_owner_vector.resize(edge_cell_owner.size());
    for (size_t i_edge = 0; i_edge < edge_cell_owner.size(); ++i_edge) {
      diamond_descriptor.edge_owner_vector[i_edge] = edge_cell_owner[i_edge];
    }
  }

  std::shared_ptr p_diamond_connectivity = ConnectivityType::build(diamond_descriptor);
  ConnectivityType& diamond_connectivity = *p_diamond_connectivity;

  NodeValue<TinyVector<Dimension>> diamond_xr{diamond_connectivity};

  const auto primal_xr = primal_mesh.xr();
  MeshData<MeshType> primal_mesh_data{primal_mesh};
  const auto primal_xj = primal_mesh_data.xj();

  {
#warning define transfer functions
    // Primal nodes first, then one node per primal cell placed at its center.
    NodeId i_node = 0;
    for (; i_node < primal_number_of_nodes; ++i_node) {
      diamond_xr[i_node] = primal_xr[i_node];
    }

    for (CellId i_cell = 0; i_cell < primal_number_of_cells; ++i_cell) {
      diamond_xr[i_node++] = primal_xj[i_cell];
    }
  }

  std::shared_ptr p_diamond_mesh = std::make_shared<MeshType>(p_diamond_connectivity, diamond_xr);

#warning USELESS TEST
  // -->>
  MeshData<MeshType> dual_mesh_data{*p_diamond_mesh};

  double sum = 0;
  for (CellId cell_id = 0; cell_id < p_diamond_mesh->numberOfCells(); ++cell_id) {
    sum += dual_mesh_data.Vj()[cell_id];
  }

  std::cout << "volume = " << sum << '\n';
  // <<--

  // Reuse the mesh built above instead of constructing an identical second one.
  m_mesh = p_diamond_mesh;
}

/// Dimension-1 specialization: no diamond mesh is built; m_mesh is reset to null.
/// It is marked [[deprecated]], so its call from the constructor emits a
/// deprecation warning at compile time.
template <>
[[deprecated]] void
DiamondDualMeshBuilder::_buildDiamondMeshFrom<1>(const std::shared_ptr<const IMesh>&)
{
  m_mesh = nullptr;
}

/// Build the diamond dual of @p p_mesh, dispatching the run-time mesh dimension to
/// the matching compile-time _buildDiamondMeshFrom<Dimension> instantiation.
///
/// @throws UnexpectedError if the mesh dimension is not 1, 2 or 3.
DiamondDualMeshBuilder::DiamondDualMeshBuilder(const std::shared_ptr<const IMesh>& p_mesh)
{
  const auto mesh_dimension = p_mesh->dimension();

  if (mesh_dimension == 1) {
    this->_buildDiamondMeshFrom<1>(p_mesh);
  } else if (mesh_dimension == 2) {
    this->_buildDiamondMeshFrom<2>(p_mesh);
  } else if (mesh_dimension == 3) {
    this->_buildDiamondMeshFrom<3>(p_mesh);
  } else {
    throw UnexpectedError("invalid mesh dimension: " + std::to_string(mesh_dimension));
  }
}
