#include <mesh/MeshFlatNodeBoundary.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/Mesh.hpp>
#include <utils/Messenger.hpp>

#include <algorithm>

/// Checks that every node of the boundary lies, up to a relative tolerance,
/// on the hyperplane through `origin` orthogonal to `normal`.
///
/// @param normal  normal of the candidate hyperplane (the callers in this file
///                pass a unit vector)
/// @param origin  a point of the candidate hyperplane
/// @param length  characteristic size of the boundary; the accepted distance
///                to the hyperplane is 1E-13 * length
/// @param mesh    mesh providing the node coordinates
/// @throws NormalError if a boundary node on any process is off the hyperplane
template <size_t Dimension>
void
MeshFlatNodeBoundary<Dimension>::_checkBoundaryIsFlat(const TinyVector<Dimension, double>& normal,
                                                      const TinyVector<Dimension, double>& origin,
                                                      const double length,
                                                      const Mesh<Connectivity<Dimension>>& mesh) const
{
  const NodeValue<const Rd>& xr = mesh.xr();

  bool is_bad = false;

  // NOTE(review): several iterations may store `true` concurrently. All writers
  // store the same value, but a reduction (or an atomic flag) would be cleaner.
  parallel_for(this->m_node_list.size(), [=, &is_bad](int r) {
    const Rd& x = xr[this->m_node_list[r]];
    // The signed distance to the hyperplane may be negative: nodes on either
    // side of the plane must be rejected, hence the absolute value.
    if (std::abs(dot(x - origin, normal)) > 1E-13 * length) {
      is_bad = true;
    }
  });

  // Flatness is a global property: if any process found an offending node,
  // every process reports the error.
  if (parallel::allReduceOr(is_bad)) {
    std::ostringstream ost;
    ost << "invalid boundary " << rang::fgB::yellow << this->m_boundary_name << rang::style::reset
        << ": boundary is not flat!";
    throw NormalError(ost.str());
  }
}

/// 1D specialization: a flat node boundary is a single point, whose normal is
/// the unit vector (1). Its orientation is fixed afterwards by
/// _getOutgoingNormal.
///
/// @throws NormalError unless the boundary contains exactly one node globally
template <>
TinyVector<1, double>
MeshFlatNodeBoundary<1>::_getNormal(const Mesh<Connectivity<1>>& mesh)
{
  using R = TinyVector<1, double>;

  // Count the boundary nodes owned by this process, then sum over all
  // processes. Assuming each node is owned by exactly one process, the sum is
  // the global number of boundary nodes. A max-reduction would only give the
  // largest per-process count and would accept boundaries whose nodes are
  // spread over several processes.
  const size_t number_of_bc_nodes = [&]() {
    size_t local_number_of_bc_nodes = 0;
    auto node_is_owned              = mesh.connectivity().nodeIsOwned();
    for (size_t i_node = 0; i_node < m_node_list.size(); ++i_node) {
      local_number_of_bc_nodes += (node_is_owned[m_node_list[i_node]]);
    }
    return parallel::allReduceSum(local_number_of_bc_nodes);
  }();

  if (number_of_bc_nodes != 1) {
    std::ostringstream ost;
    ost << "invalid boundary " << rang::fgB::yellow << m_boundary_name << rang::style::reset
        << ": node boundaries in 1D require to have exactly 1 node";
    throw NormalError(ost.str());
  }

  return R{1};
}

/// 2D specialization: builds a unit normal from the two points returned by
/// _getBounds. The unit vector joining them is rotated by +90 degrees. The
/// function then checks that every boundary node lies on the line through
/// those points.
///
/// @throws NormalError if the two points coincide or the boundary is not flat
template <>
TinyVector<2, double>
MeshFlatNodeBoundary<2>::_getNormal(const Mesh<Connectivity<2>>& mesh)
{
  using R2 = TinyVector<2, double>;

  const std::array<R2, 2> end_points = this->_getBounds(mesh);
  const R2& first_point              = end_points[0];
  const R2& last_point               = end_points[1];

  if (first_point == last_point) {
    std::ostringstream error_message;
    error_message << "invalid boundary " << rang::fgB::yellow << this->m_boundary_name << rang::style::reset
                  << ": unable to compute normal";
    throw NormalError(error_message.str());
  }

  // Distance between the two points; it also serves as the tolerance scale.
  const double extent = l2Norm(last_point - first_point);

  R2 tangent = last_point - first_point;
  tangent *= 1. / extent;

  const R2 normal(-tangent[1], tangent[0]);

  const R2 midpoint = 0.5 * (first_point + last_point);
  this->_checkBoundaryIsFlat(normal, midpoint, extent, mesh);

  return normal;
}

/// 3D specialization: computes a unit normal of the boundary and checks that
/// the boundary is planar.
///
/// The six points returned by _getBounds (presumably the boundary nodes of
/// minimal/maximal x, y and z coordinates; verify in _getBounds) define three
/// spans u, v and w. Among the cross products u^v, u^w and v^w, the one with
/// the largest norm is used, because it is the least affected by nearly
/// aligned spans.
///
/// @throws NormalError if all three cross products vanish or if the boundary
///         is not flat
template <>
TinyVector<3, double>
MeshFlatNodeBoundary<3>::_getNormal(const Mesh<Connectivity<3>>& mesh)
{
  using R3 = TinyVector<3, double>;

  std::array<R3, 6> bounds = this->_getBounds(mesh);

  const R3& xmin = bounds[0];
  const R3& ymin = bounds[1];
  const R3& zmin = bounds[2];
  const R3& xmax = bounds[3];
  const R3& ymax = bounds[4];
  const R3& zmax = bounds[5];

  const R3 u = xmax - xmin;
  const R3 v = ymax - ymin;
  const R3 w = zmax - zmin;

  const R3 uv        = crossProduct(u, v);
  const double uv_l2 = dot(uv, uv);

  R3 normal        = uv;
  double normal_l2 = uv_l2;

  const R3 uw        = crossProduct(u, w);
  const double uw_l2 = dot(uw, uw);

  if (uw_l2 > normal_l2) {
    normal    = uw;
    normal_l2 = uw_l2;
  }

  const R3 vw        = crossProduct(v, w);
  const double vw_l2 = dot(vw, vw);

  if (vw_l2 > normal_l2) {
    normal    = vw;
    normal_l2 = vw_l2;
  }

  if (normal_l2 == 0) {
    std::ostringstream ost;
    ost << "invalid boundary " << rang::fgB::yellow << this->m_boundary_name << rang::style::reset
        << ": unable to compute normal";
    throw NormalError(ost.str());
  }

  normal *= 1. / sqrt(normal_l2);

  // The flatness tolerance must scale like a length, as in the 2D case, which
  // uses the distance between extreme points. The norm of a cross product is
  // an area, so it is not used here; the largest span is used instead. It is
  // positive because at least one span is non-zero (normal_l2 != 0).
  const double length = std::max({l2Norm(u), l2Norm(v), l2Norm(w)});

  this->_checkBoundaryIsFlat(normal, 1. / 6. * (xmin + xmax + ymin + ymax + zmin + zmax), length, mesh);

  return normal;
}

/// 1D specialization: returns the unit normal of the boundary oriented away
/// from the mesh.
///
/// Each process inspects one cell adjacent to its first boundary node. The
/// node of that cell with the largest |signed height| along the normal shows
/// on which side the mesh lies. The per-process estimates are then combined
/// identically on every process.
///
/// @throws NormalError (from _getNormal) if the normal cannot be computed
template <>
TinyVector<1, double>
MeshFlatNodeBoundary<1>::_getOutgoingNormal(const Mesh<Connectivity<1>>& mesh)
{
  using R = TinyVector<1, double>;

  const R normal = this->_getNormal(mesh);

  double local_max_height = 0;

  if (m_node_list.size() > 0) {
    const NodeValue<const R>& xr    = mesh.xr();
    const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();

    const auto& node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();

    const NodeId r0      = m_node_list[0];
    const CellId j0      = node_to_cell_matrix[r0][0];
    const auto& j0_nodes = cell_to_node_matrix[j0];

    for (size_t r = 0; r < j0_nodes.size(); ++r) {
      const double height = dot(xr[j0_nodes[r]] - xr[r0], normal);
      if (std::abs(height) > std::abs(local_max_height)) {
        local_max_height = height;
      }
    }
  }

  // The gathered array holds every process's local estimate, including this
  // one. Restarting the search from 0 makes all processes scan identical data
  // in the same order, so they all pick the same value. Starting from the
  // local estimate could make processes disagree when two heights have equal
  // magnitude but opposite signs.
  Array<double> max_height_array = parallel::allGather(local_max_height);
  double max_height              = 0;
  for (size_t i = 0; i < max_height_array.size(); ++i) {
    const double height = max_height_array[i];
    if (std::abs(height) > std::abs(max_height)) {
      max_height = height;
    }
  }

  // The mesh lies on the positive side of the normal: flip it to point outwards.
  if (max_height > 0) {
    return -normal;
  } else {
    return normal;
  }
}

/// 2D specialization: returns the unit normal of the boundary oriented away
/// from the mesh.
///
/// Each process inspects one cell adjacent to its first boundary node. The
/// node of that cell with the largest |signed height| along the normal shows
/// on which side the mesh lies. The per-process estimates are then combined
/// identically on every process.
///
/// @throws NormalError (from _getNormal) if the normal cannot be computed or
///         the boundary is not flat
template <>
TinyVector<2, double>
MeshFlatNodeBoundary<2>::_getOutgoingNormal(const Mesh<Connectivity<2>>& mesh)
{
  using R2 = TinyVector<2, double>;

  const R2 normal = this->_getNormal(mesh);

  double local_max_height = 0;

  if (m_node_list.size() > 0) {
    const NodeValue<const R2>& xr   = mesh.xr();
    const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();

    const auto& node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();

    const NodeId r0      = m_node_list[0];
    const CellId j0      = node_to_cell_matrix[r0][0];
    const auto& j0_nodes = cell_to_node_matrix[j0];
    for (size_t r = 0; r < j0_nodes.size(); ++r) {
      const double height = dot(xr[j0_nodes[r]] - xr[r0], normal);
      if (std::abs(height) > std::abs(local_max_height)) {
        local_max_height = height;
      }
    }
  }

  // The gathered array holds every process's local estimate, including this
  // one. Restarting the search from 0 makes all processes scan identical data
  // in the same order, so they all pick the same value. Starting from the
  // local estimate could make processes disagree when two heights have equal
  // magnitude but opposite signs.
  Array<double> max_height_array = parallel::allGather(local_max_height);
  double max_height              = 0;
  for (size_t i = 0; i < max_height_array.size(); ++i) {
    const double height = max_height_array[i];
    if (std::abs(height) > std::abs(max_height)) {
      max_height = height;
    }
  }

  // The mesh lies on the positive side of the normal: flip it to point outwards.
  if (max_height > 0) {
    return -normal;
  } else {
    return normal;
  }
}

/// 3D specialization: returns the unit normal of the boundary oriented away
/// from the mesh.
///
/// Each process inspects one cell adjacent to its first boundary node. The
/// node of that cell with the largest |signed height| along the normal shows
/// on which side the mesh lies. The per-process estimates are then combined
/// identically on every process.
///
/// @throws NormalError (from _getNormal) if the normal cannot be computed or
///         the boundary is not flat
template <>
TinyVector<3, double>
MeshFlatNodeBoundary<3>::_getOutgoingNormal(const Mesh<Connectivity<3>>& mesh)
{
  using R3 = TinyVector<3, double>;

  const R3 normal = this->_getNormal(mesh);

  double local_max_height = 0;

  if (m_node_list.size() > 0) {
    const NodeValue<const R3>& xr   = mesh.xr();
    const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();

    const auto& node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();

    const NodeId r0      = m_node_list[0];
    const CellId j0      = node_to_cell_matrix[r0][0];
    const auto& j0_nodes = cell_to_node_matrix[j0];

    for (size_t r = 0; r < j0_nodes.size(); ++r) {
      const double height = dot(xr[j0_nodes[r]] - xr[r0], normal);
      if (std::abs(height) > std::abs(local_max_height)) {
        local_max_height = height;
      }
    }
  }

  // The gathered array holds every process's local estimate, including this
  // one. Restarting the search from 0 makes all processes scan identical data
  // in the same order, so they all pick the same value. Starting from the
  // local estimate could make processes disagree when two heights have equal
  // magnitude but opposite signs.
  Array<double> max_height_array = parallel::allGather(local_max_height);
  double max_height              = 0;
  for (size_t i = 0; i < max_height_array.size(); ++i) {
    const double height = max_height_array[i];
    if (std::abs(height) > std::abs(max_height)) {
      max_height = height;
    }
  }

  // The mesh lies on the positive side of the normal: flip it to point outwards.
  if (max_height > 0) {
    return -normal;
  } else {
    return normal;
  }
}

/// Builds the flat node boundary of `mesh` designated by `boundary_descriptor`.
///
/// Reference node lists of the connectivity are searched first. If none
/// matches, reference face lists are searched next. The first list whose
/// RefId matches the descriptor is used.
///
/// @throws NormalError if no reference node or face list matches
template <size_t Dimension>
MeshFlatNodeBoundary<Dimension>
getMeshFlatNodeBoundary(const Mesh<Connectivity<Dimension>>& mesh, const IBoundaryDescriptor& boundary_descriptor)
{
  const auto& connectivity = mesh.connectivity();

  const size_t node_list_count = connectivity.template numberOfRefItemList<ItemType::node>();
  for (size_t i_list = 0; i_list < node_list_count; ++i_list) {
    const auto& node_list = connectivity.template refItemList<ItemType::node>(i_list);
    if (node_list.refId() == boundary_descriptor) {
      return MeshFlatNodeBoundary<Dimension>{mesh, node_list};
    }
  }

  const size_t face_list_count = connectivity.template numberOfRefItemList<ItemType::face>();
  for (size_t i_list = 0; i_list < face_list_count; ++i_list) {
    const auto& face_list = connectivity.template refItemList<ItemType::face>(i_list);
    if (face_list.refId() == boundary_descriptor) {
      return MeshFlatNodeBoundary<Dimension>{mesh, face_list};
    }
  }

  std::ostringstream error_message;
  error_message << "cannot find surface with name " << rang::fgB::red << boundary_descriptor << rang::style::reset;

  throw NormalError(error_message.str());
}

// Explicit instantiations of getMeshFlatNodeBoundary for 1D, 2D and 3D meshes.
template MeshFlatNodeBoundary<1> getMeshFlatNodeBoundary<1>(const Mesh<Connectivity<1>>&, const IBoundaryDescriptor&);
template MeshFlatNodeBoundary<2> getMeshFlatNodeBoundary<2>(const Mesh<Connectivity<2>>&, const IBoundaryDescriptor&);
template MeshFlatNodeBoundary<3> getMeshFlatNodeBoundary<3>(const Mesh<Connectivity<3>>&, const IBoundaryDescriptor&);