#include <mesh/MeshNodeBoundary.hpp>

#include <Kokkos_Vector.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/Mesh.hpp>
#include <utils/Messenger.hpp>

template <>
std::array<TinyVector<2>, 2>
MeshNodeBoundary<2>::_getBounds(const Mesh<Connectivity<2>>& mesh) const
{
  using R2 = TinyVector<2, double>;

  // Strict lexicographic order on (x, y): the first coordinate decides and
  // the second one is only looked at when the first ones are exactly equal.
  auto precedes = [](const R2& a, const R2& b) {
    if (a[0] != b[0]) {
      return a[0] < b[0];
    }
    return a[1] < b[1];
  };

  const double largest = std::numeric_limits<double>::max();

  // bounds[0] receives the lexicographically smallest boundary node and
  // bounds[1] the lexicographically largest one. They start from opposite
  // sentinels (+/- the largest double) so that real coordinates replace them.
  std::array<R2, 2> bounds;
  R2& lowest  = bounds[0];
  R2& highest = bounds[1];
  lowest      = R2{largest, largest};
  highest     = R2{-largest, -largest};

  const NodeValue<const R2>& xr = mesh.xr();

  auto node_list = m_ref_node_list.list();
  for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
    const R2& x = xr[node_list[i_node]];
    if (precedes(x, lowest)) {
      lowest = x;
    }
    if (precedes(highest, x)) {
      highest = x;
    }
  }

  if (parallel::size() > 1) {
    // Only local nodes were scanned above: gather the candidates of all
    // processes and keep the extreme ones using the same ordering.
    Array<R2> gathered_lowest  = parallel::allGather(lowest);
    Array<R2> gathered_highest = parallel::allGather(highest);
    for (size_t i_rank = 0; i_rank < gathered_lowest.size(); ++i_rank) {
      if (precedes(gathered_lowest[i_rank], lowest)) {
        lowest = gathered_lowest[i_rank];
      }
    }
    for (size_t i_rank = 0; i_rank < gathered_highest.size(); ++i_rank) {
      if (precedes(highest, gathered_highest[i_rank])) {
        highest = gathered_highest[i_rank];
      }
    }
  }

  return bounds;
}

template <>
std::array<TinyVector<3>, 6>
MeshNodeBoundary<3>::_getBounds(const Mesh<Connectivity<3>>& mesh) const
{
  using R3 = TinyVector<3, double>;

  // One level of a lexicographic order: the compared coordinate and whether
  // the smaller (true) or the larger (false) value wins on it.
  struct Criterion
  {
    size_t axis;
    bool prefer_smaller;
  };
  using Order = std::array<Criterion, 3>;

  // Ordering used to select each extreme node. Entry i describes bounds[i],
  // so the returned array is laid out as {xmin, ymin, zmin, xmax, ymax, zmax}.
  const std::array<Order, 6> orders = {
    Order{{{0, true}, {1, false}, {2, false}}},    // XMIN: X.xmin X.ymax X.zmax
    Order{{{1, true}, {2, false}, {0, true}}},     // YMIN: X.ymin X.zmax X.xmin
    Order{{{2, true}, {0, true}, {1, true}}},      // ZMIN: X.zmin X.xmin X.ymin
    Order{{{0, false}, {1, true}, {2, true}}},     // XMAX: X.xmax X.ymin X.zmin
    Order{{{1, false}, {2, true}, {0, false}}},    // YMAX: X.ymax X.zmin X.xmax
    Order{{{2, false}, {0, false}, {1, false}}},   // ZMAX: X.zmax X.xmax X.ymax
  };

  // True when `candidate` strictly beats `current` for `order`: the first
  // criterion decides, the next ones are only consulted on exact ties.
  auto beats = [](const R3& candidate, const R3& current, const Order& order) {
    for (const auto& [axis, prefer_smaller] : order) {
      const double cand = candidate[axis];
      const double curr = current[axis];
      if (prefer_smaller ? (cand < curr) : (cand > curr)) {
        return true;
      }
      if (cand != curr) {
        return false;
      }
    }
    return false;
  };

  const double largest = std::numeric_limits<double>::max();

  // Minima (first three entries) start at +largest, maxima (last three) at
  // -largest, so that real coordinates replace these sentinels.
  std::array<R3, 6> bounds;
  for (size_t i_bound = 0; i_bound < bounds.size(); ++i_bound) {
    const double start = (i_bound < 3) ? largest : -largest;
    bounds[i_bound]    = R3{start, start, start};
  }

  const NodeValue<const R3>& xr = mesh.xr();

  auto node_list = m_ref_node_list.list();
  for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
    const R3& x = xr[node_list[i_node]];
    for (size_t i_bound = 0; i_bound < bounds.size(); ++i_bound) {
      if (beats(x, bounds[i_bound], orders[i_bound])) {
        bounds[i_bound] = x;
      }
    }
  }

  if (parallel::size() > 1) {
    // Only local nodes were scanned above: gather each bound from all
    // processes (collectives are issued in the same order on every rank)
    // and reduce the candidates with the matching ordering.
    for (size_t i_bound = 0; i_bound < bounds.size(); ++i_bound) {
      Array<const R3> gathered = parallel::allGather(bounds[i_bound]);
      for (size_t i_rank = 0; i_rank < gathered.size(); ++i_rank) {
        if (beats(gathered[i_rank], bounds[i_bound], orders[i_bound])) {
          bounds[i_bound] = gathered[i_rank];
        }
      }
    }
  }

  return bounds;
}

template <size_t Dimension>
MeshNodeBoundary<Dimension>::MeshNodeBoundary(const Mesh<Connectivity<Dimension>>& mesh,
                                              const RefFaceList& ref_face_list)
{
  // Only boundary face lists are allowed to describe a mesh boundary.
  if (ref_face_list.type() != RefItemListBase::Type::boundary) {
    std::ostringstream error_message;
    error_message << "invalid boundary \"" << rang::fgB::yellow << ref_face_list.refId() << rang::style::reset
                  << "\": inner faces cannot be used to define mesh boundaries";
    throw NormalError(error_message.str());
  }

  const Array<const FaceId>& face_list = ref_face_list.list();

  if constexpr (Dimension <= 1) {
    // In 1D each face id is reused directly as a node id.
    Array<NodeId> node_list(face_list.size());
    parallel_for(
      face_list.size(),
      PUGS_LAMBDA(int i_face) { node_list[i_face] = static_cast<FaceId::base_type>(face_list[i_face]); });
    m_ref_node_list = RefNodeList{ref_face_list.refId(), node_list, ref_face_list.type()};
  } else {
    const auto& face_to_node_matrix = mesh.connectivity().faceToNodeMatrix();

    // Gather the nodes of every face of the list; nodes shared by several
    // faces are collected several times and deduplicated below. The
    // reservation is only a rough estimate meant to limit reallocations.
    Kokkos::vector<unsigned int> node_ids;
    node_ids.reserve(Dimension * face_list.size());
    for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
      const auto& face_nodes = face_to_node_matrix[face_list[i_face]];
      for (size_t i_face_node = 0; i_face_node < face_nodes.size(); ++i_face_node) {
        node_ids.push_back(face_nodes[i_face_node]);
      }
    }

    // sort + unique + shrink: keep each node id exactly once, in increasing order
    std::sort(node_ids.begin(), node_ids.end());
    const auto unique_end = std::unique(node_ids.begin(), node_ids.end());
    node_ids.resize(std::distance(node_ids.begin(), unique_end));

    Array<NodeId> node_list(node_ids.size());
    parallel_for(
      node_ids.size(), PUGS_LAMBDA(int i_node) { node_list[i_node] = node_ids[i_node]; });
    m_ref_node_list = RefNodeList{ref_face_list.refId(), node_list, ref_face_list.type()};
  }

  // Deliberate shortcut: the freshly built node list is stored on the
  // (logically const) connectivity, so that a later node-list lookup with the
  // same reference can reuse it instead of rebuilding it from the faces.
  const_cast<Connectivity<Dimension>&>(mesh.connectivity()).addRefItemList(m_ref_node_list);
}

template <size_t Dimension>
MeshNodeBoundary<Dimension>::MeshNodeBoundary(const Mesh<Connectivity<Dimension>>& mesh,
                                              const RefEdgeList& ref_edge_list)
{
  // Only boundary edge lists are allowed to describe a mesh boundary.
  if (ref_edge_list.type() != RefItemListBase::Type::boundary) {
    std::ostringstream error_message;
    error_message << "invalid boundary \"" << rang::fgB::yellow << ref_edge_list.refId() << rang::style::reset
                  << "\": inner edges cannot be used to define mesh boundaries";
    throw NormalError(error_message.str());
  }

  const Array<const EdgeId>& edge_list = ref_edge_list.list();

  if constexpr (Dimension <= 1) {
    // In 1D each edge id is reused directly as a node id.
    Array<NodeId> node_list(edge_list.size());
    parallel_for(
      edge_list.size(),
      PUGS_LAMBDA(int i_edge) { node_list[i_edge] = static_cast<EdgeId::base_type>(edge_list[i_edge]); });
    m_ref_node_list = RefNodeList{ref_edge_list.refId(), node_list, ref_edge_list.type()};
  } else {
    const auto& edge_to_node_matrix = mesh.connectivity().edgeToNodeMatrix();

    // Gather the nodes of every edge of the list; nodes shared by several
    // edges are collected several times and deduplicated below.
    Kokkos::vector<unsigned int> node_ids;
    node_ids.reserve(2 * edge_list.size());
    for (size_t i_edge = 0; i_edge < edge_list.size(); ++i_edge) {
      const auto& edge_nodes = edge_to_node_matrix[edge_list[i_edge]];
      for (size_t i_edge_node = 0; i_edge_node < edge_nodes.size(); ++i_edge_node) {
        node_ids.push_back(edge_nodes[i_edge_node]);
      }
    }

    // sort + unique + shrink: keep each node id exactly once, in increasing order
    std::sort(node_ids.begin(), node_ids.end());
    const auto unique_end = std::unique(node_ids.begin(), node_ids.end());
    node_ids.resize(std::distance(node_ids.begin(), unique_end));

    Array<NodeId> node_list(node_ids.size());
    parallel_for(
      node_ids.size(), PUGS_LAMBDA(int i_node) { node_list[i_node] = node_ids[i_node]; });
    m_ref_node_list = RefNodeList{ref_edge_list.refId(), node_list, ref_edge_list.type()};
  }

  // Deliberate shortcut: the freshly built node list is stored on the
  // (logically const) connectivity, so that a later node-list lookup with the
  // same reference can reuse it instead of rebuilding it from the edges.
  const_cast<Connectivity<Dimension>&>(mesh.connectivity()).addRefItemList(m_ref_node_list);
}

template <size_t Dimension>
MeshNodeBoundary<Dimension>::MeshNodeBoundary(const Mesh<Connectivity<Dimension>>&, const RefNodeList& ref_node_list)
  : m_ref_node_list(ref_node_list)
{
  // A node list is used as is, provided it is flagged as a boundary list.
  if (ref_node_list.type() == RefItemListBase::Type::boundary) {
    return;
  }

  std::ostringstream error_message;
  error_message << "invalid boundary \"" << rang::fgB::yellow << ref_node_list.refId() << rang::style::reset
                << "\": inner nodes cannot be used to define mesh boundaries";
  throw NormalError(error_message.str());
}

// Explicit instantiations of every constructor, grouped by dimension.
template MeshNodeBoundary<1>::MeshNodeBoundary(const Mesh<Connectivity<1>>&, const RefFaceList&);
template MeshNodeBoundary<1>::MeshNodeBoundary(const Mesh<Connectivity<1>>&, const RefEdgeList&);
template MeshNodeBoundary<1>::MeshNodeBoundary(const Mesh<Connectivity<1>>&, const RefNodeList&);

template MeshNodeBoundary<2>::MeshNodeBoundary(const Mesh<Connectivity<2>>&, const RefFaceList&);
template MeshNodeBoundary<2>::MeshNodeBoundary(const Mesh<Connectivity<2>>&, const RefEdgeList&);
template MeshNodeBoundary<2>::MeshNodeBoundary(const Mesh<Connectivity<2>>&, const RefNodeList&);

template MeshNodeBoundary<3>::MeshNodeBoundary(const Mesh<Connectivity<3>>&, const RefFaceList&);
template MeshNodeBoundary<3>::MeshNodeBoundary(const Mesh<Connectivity<3>>&, const RefEdgeList&);
template MeshNodeBoundary<3>::MeshNodeBoundary(const Mesh<Connectivity<3>>&, const RefNodeList&);

template <size_t Dimension>
MeshNodeBoundary<Dimension>
getMeshNodeBoundary(const Mesh<Connectivity<Dimension>>& mesh, const IBoundaryDescriptor& boundary_descriptor)
{
  const auto& connectivity = mesh.connectivity();

  // Lists are searched by item type: node lists first, then edge lists, then
  // face lists. The first list whose reference matches the descriptor is used.
  const size_t number_of_node_lists = connectivity.template numberOfRefItemList<ItemType::node>();
  for (size_t i_list = 0; i_list < number_of_node_lists; ++i_list) {
    const auto& ref_node_list = connectivity.template refItemList<ItemType::node>(i_list);
    if (ref_node_list.refId() == boundary_descriptor) {
      return MeshNodeBoundary<Dimension>{mesh, ref_node_list};
    }
  }

  const size_t number_of_edge_lists = connectivity.template numberOfRefItemList<ItemType::edge>();
  for (size_t i_list = 0; i_list < number_of_edge_lists; ++i_list) {
    const auto& ref_edge_list = connectivity.template refItemList<ItemType::edge>(i_list);
    if (ref_edge_list.refId() == boundary_descriptor) {
      return MeshNodeBoundary<Dimension>{mesh, ref_edge_list};
    }
  }

  const size_t number_of_face_lists = connectivity.template numberOfRefItemList<ItemType::face>();
  for (size_t i_list = 0; i_list < number_of_face_lists; ++i_list) {
    const auto& ref_face_list = connectivity.template refItemList<ItemType::face>(i_list);
    if (ref_face_list.refId() == boundary_descriptor) {
      return MeshNodeBoundary<Dimension>{mesh, ref_face_list};
    }
  }

  // No list of any item type matches the descriptor.
  std::ostringstream error_message;
  error_message << "cannot find node list with name " << rang::fgB::red << boundary_descriptor << rang::style::reset;
  throw NormalError(error_message.str());
}

// Explicit instantiations for the supported dimensions.
template MeshNodeBoundary<1> getMeshNodeBoundary(const Mesh<Connectivity<1>>&, const IBoundaryDescriptor&);
template MeshNodeBoundary<2> getMeshNodeBoundary(const Mesh<Connectivity<2>>&, const IBoundaryDescriptor&);
template MeshNodeBoundary<3> getMeshNodeBoundary(const Mesh<Connectivity<3>>&, const IBoundaryDescriptor&);
