#ifndef MESH_NODE_BOUNDARY_HPP
#define MESH_NODE_BOUNDARY_HPP

#include <utils/Array.hpp>

#include <algebra/TinyVector.hpp>

#include <mesh/ItemValue.hpp>
#include <mesh/RefItemList.hpp>

#include <mesh/ConnectivityMatrix.hpp>
#include <mesh/IConnectivity.hpp>

#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>

#include <Kokkos_Vector.hpp>

#include <cmath>
#include <iostream>

template <size_t Dimension>
class MeshNodeBoundary   // clazy:exclude=copyable-polymorphic
{
 protected:
  Array<const NodeId> m_node_list;

 public:
  MeshNodeBoundary& operator=(const MeshNodeBoundary&) = default;
  MeshNodeBoundary& operator=(MeshNodeBoundary&&) = default;

  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

  template <typename MeshType>
  MeshNodeBoundary(const MeshType& mesh, const RefFaceList& ref_face_list)
  {
    static_assert(Dimension == MeshType::Dimension);
    const auto& face_to_cell_matrix = mesh.connectivity().faceToCellMatrix();

    const Array<const FaceId>& face_list = ref_face_list.list();
    parallel_for(
      face_list.size(), PUGS_LAMBDA(int l) {
        const auto& face_cells = face_to_cell_matrix[face_list[l]];
        if (face_cells.size() > 1) {
          throw NormalError("internal faces cannot be used to define mesh boundaries");
        }
      });

    Kokkos::vector<unsigned int> node_ids;
    // not enough but should reduce significantly the number of resizing
    node_ids.reserve(Dimension * face_list.size());
    const auto& face_to_node_matrix = mesh.connectivity().faceToNodeMatrix();

    for (size_t l = 0; l < face_list.size(); ++l) {
      const FaceId face_number = face_list[l];
      const auto& face_nodes   = face_to_node_matrix[face_number];

      for (size_t r = 0; r < face_nodes.size(); ++r) {
        node_ids.push_back(face_nodes[r]);
      }
    }
    std::sort(node_ids.begin(), node_ids.end());
    auto last = std::unique(node_ids.begin(), node_ids.end());
    node_ids.resize(std::distance(node_ids.begin(), last));

    Array<NodeId> node_list(node_ids.size());
    parallel_for(
      node_ids.size(), PUGS_LAMBDA(int r) { node_list[r] = node_ids[r]; });
    m_node_list = node_list;
  }

  template <typename MeshType>
  MeshNodeBoundary(const MeshType&, const RefNodeList& ref_node_list) : m_node_list(ref_node_list.list())
  {
    static_assert(Dimension == MeshType::Dimension);
  }

  MeshNodeBoundary()          = default;
  virtual ~MeshNodeBoundary() = default;

 protected:
  MeshNodeBoundary(const MeshNodeBoundary&) = default;
  MeshNodeBoundary(MeshNodeBoundary&&)      = default;
};

// Node boundary that is additionally required to be flat: it stores the unit
// normal computed from the mesh at construction (see _getOutgoingNormal).
template <size_t Dimension>
class MeshFlatNodeBoundary : public MeshNodeBoundary<Dimension>   // clazy:exclude=copyable-polymorphic
{
 public:
  using Rd = TinyVector<Dimension, double>;

 private:
  // Deliberately not const-qualified: a const data member would make the
  // defaulted copy/move assignment operators declared below implicitly
  // deleted. The value is only exposed read-only through outgoingNormal().
  Rd m_outgoing_normal;

  // Computes a unit normal to the boundary (orientation not yet decided)
  template <typename MeshType>
  PUGS_INLINE Rd _getNormal(const MeshType& mesh);

  // Checks that the boundary nodes lie on the line defined by xmin, xmax and
  // normal (only a 2D signature is declared)
  template <typename MeshType>
  PUGS_INLINE void _checkBoundaryIsFlat(TinyVector<2, double> normal,
                                        TinyVector<2, double> xmin,
                                        TinyVector<2, double> xmax,
                                        const MeshType& mesh) const;

  // Computes the normal and chooses its orientation
  template <typename MeshType>
  PUGS_INLINE Rd _getOutgoingNormal(const MeshType& mesh);

 public:
  // Normal of the boundary as computed at construction
  const Rd&
  outgoingNormal() const
  {
    return m_outgoing_normal;
  }

  MeshFlatNodeBoundary& operator=(const MeshFlatNodeBoundary&) = default;
  MeshFlatNodeBoundary& operator=(MeshFlatNodeBoundary&&) = default;

  // Builds the boundary from a face list, then computes its normal. The base
  // class is initialized first, so the node list is available when
  // _getOutgoingNormal runs.
  template <typename MeshType>
  MeshFlatNodeBoundary(const MeshType& mesh, const RefFaceList& ref_face_list)
    : MeshNodeBoundary<Dimension>(mesh, ref_face_list), m_outgoing_normal(_getOutgoingNormal(mesh))
  {}

  // Builds the boundary from a node list, then computes its normal.
  template <typename MeshType>
  MeshFlatNodeBoundary(const MeshType& mesh, const RefNodeList& ref_node_list)
    : MeshNodeBoundary<Dimension>(mesh, ref_node_list), m_outgoing_normal(_getOutgoingNormal(mesh))
  {}

  MeshFlatNodeBoundary()                            = default;
  MeshFlatNodeBoundary(const MeshFlatNodeBoundary&) = default;
  MeshFlatNodeBoundary(MeshFlatNodeBoundary&&)      = default;
  virtual ~MeshFlatNodeBoundary()                   = default;
};

// Checks that every boundary node lies on the line going through the middle
// of [xmin, xmax] and orthogonal to `normal`.
//
// normal:     unit normal of the candidate line
// xmin, xmax: extreme points of the boundary; their distance sets the scale
//             of the tolerance
// mesh:       mesh providing the node coordinates (mesh.xr())
//
// Throws NormalError as soon as a node is further than 1E-13 * |xmax - xmin|
// from the line, on either side of it.
template <>
template <typename MeshType>
void
MeshFlatNodeBoundary<2>::_checkBoundaryIsFlat(TinyVector<2, double> normal,
                                              TinyVector<2, double> xmin,
                                              TinyVector<2, double> xmax,
                                              const MeshType& mesh) const
{
  static_assert(MeshType::Dimension == 2);
  using R2 = TinyVector<2, double>;

  const R2 origin        = 0.5 * (xmin + xmax);
  const double length    = l2Norm(xmax - xmin);
  const double tolerance = 1E-13 * length;

  const NodeValue<const R2>& xr = mesh.xr();

  // A plain loop is used so that the exception is thrown from ordinary
  // serial code and not from inside a parallel_for body, from which it
  // cannot be reliably propagated to the caller.
  for (size_t r = 0; r < m_node_list.size(); ++r) {
    const R2& x = xr[m_node_list[r]];

    // (a, b) uses TinyVector's overloaded comma operator (presumably the
    // scalar product -- confirm in TinyVector.hpp). The result is signed, so
    // its magnitude must be tested: a node may leave the line on either side.
    const double signed_distance = (x - origin, normal);
    if (std::abs(signed_distance) > tolerance) {
      throw NormalError("this FlatBoundary is not flat!");
    }
  }
}

// 1D normal: after checking the number of boundary nodes, the normal is
// always the vector (1).
//
// Throws NormalError if the reduced count of owned boundary nodes is not 1.
template <>
template <typename MeshType>
PUGS_INLINE TinyVector<1, double>
MeshFlatNodeBoundary<1>::_getNormal(const MeshType& mesh)
{
  static_assert(MeshType::Dimension == 1);
  using R = TinyVector<1, double>;

  // Count the boundary nodes flagged as owned in the connectivity
  size_t owned_node_count = 0;
  {
    auto is_owned = mesh.connectivity().nodeIsOwned();
    for (size_t i = 0; i < m_node_list.size(); ++i) {
      owned_node_count += (is_owned[m_node_list[i]]);
    }
  }

  // NOTE(review): the local counts are combined with a max-reduction, so the
  // test below looks at the largest local count -- confirm that a sum over
  // the processes was not intended.
  const size_t number_of_bc_nodes = parallel::allReduceMax(owned_node_count);

  if (number_of_bc_nodes != 1) {
    throw NormalError("Node boundaries in 1D require to have exactly 1 node");
  }

  return R{1};
}

// 2D normal: finds the two extreme nodes of the boundary (lexicographic
// order on the coordinates, local search then reduction of the gathered
// values), and returns the unit vector obtained by rotating the direction
// xmin -> xmax: (dx, dy) -> (-dy, dx).
//
// Throws NormalError if both extreme points coincide; the flatness check
// (_checkBoundaryIsFlat) may throw as well.
template <>
template <typename MeshType>
PUGS_INLINE TinyVector<2, double>
MeshFlatNodeBoundary<2>::_getNormal(const MeshType& mesh)
{
  static_assert(MeshType::Dimension == 2);
  using R2 = TinyVector<2, double>;

  // Strict lexicographic orderings: first coordinate, then second one
  const auto comes_before = [](const R2& a, const R2& b) {
    return (a[0] < b[0]) or ((a[0] == b[0]) and (a[1] < b[1]));
  };
  const auto comes_after = [](const R2& a, const R2& b) {
    return (a[0] > b[0]) or ((a[0] == b[0]) and (a[1] > b[1]));
  };

  // Sentinels: any actual node position replaces them
  R2 xmin(std::numeric_limits<double>::max(), std::numeric_limits<double>::max());
  R2 xmax(-std::numeric_limits<double>::max(), -std::numeric_limits<double>::max());

  const NodeValue<const R2>& xr = mesh.xr();

  // Extreme points among the nodes of this boundary
  for (size_t i_node = 0; i_node < m_node_list.size(); ++i_node) {
    const R2& position = xr[m_node_list[i_node]];
    if (comes_before(position, xmin)) {
      xmin = position;
    }
    if (comes_after(position, xmax)) {
      xmax = position;
    }
  }

  // Extreme points among the gathered values
  Array<R2> xmin_array = parallel::allGather(xmin);
  Array<R2> xmax_array = parallel::allGather(xmax);

  for (size_t i = 0; i < xmin_array.size(); ++i) {
    const R2& candidate = xmin_array[i];
    if (comes_before(candidate, xmin)) {
      xmin = candidate;
    }
  }
  for (size_t i = 0; i < xmax_array.size(); ++i) {
    const R2& candidate = xmax_array[i];
    if (comes_after(candidate, xmax)) {
      xmax = candidate;
    }
  }

  if (xmin == xmax) {
    std::stringstream os;
    os << "xmin==xmax (" << xmin << "==" << xmax << ") unable to compute normal";
    throw NormalError(os.str());
  }

  // Unit direction of the boundary line
  R2 dx = xmax - xmin;
  dx *= 1. / l2Norm(dx);

  // Rotation of the direction by a quarter turn
  R2 normal(-dx[1], dx[0]);

  this->_checkBoundaryIsFlat(normal, xmin, xmax, mesh);

  return normal;
}

// 3D normal: for each axis, finds the boundary nodes of smallest and largest
// coordinate (local search then reduction of the gathered values). The three
// "max - min" vectors span the boundary; the normal is the longest of their
// pairwise cross products, normalized.
//
// Throws NormalError if all three cross products vanish.
template <>
template <typename MeshType>
PUGS_INLINE TinyVector<3, double>
MeshFlatNodeBoundary<3>::_getNormal(const MeshType& mesh)
{
  static_assert(MeshType::Dimension == 3);
  using R3 = TinyVector<3, double>;

  // `extremum` becomes `candidate` when the latter is strictly smaller
  // (resp. larger) along `axis`
  const auto keep_lowest = [](const size_t axis, const R3& candidate, R3& extremum) {
    if (candidate[axis] < extremum[axis]) {
      extremum = candidate;
    }
  };
  const auto keep_highest = [](const size_t axis, const R3& candidate, R3& extremum) {
    if (candidate[axis] > extremum[axis]) {
      extremum = candidate;
    }
  };

  // Sentinels: any actual node position replaces them
  R3 xmin(std::numeric_limits<double>::max(), std::numeric_limits<double>::max(), std::numeric_limits<double>::max());
  R3 ymin = xmin;
  R3 zmin = xmin;

  R3 xmax = -xmin;
  R3 ymax = xmax;
  R3 zmax = xmax;

  const NodeValue<const R3>& xr = mesh.xr();

  // Extreme nodes among the nodes of this boundary
  for (size_t i_node = 0; i_node < m_node_list.size(); ++i_node) {
    const R3& position = xr[m_node_list[i_node]];

    keep_lowest(0, position, xmin);
    keep_lowest(1, position, ymin);
    keep_lowest(2, position, zmin);

    keep_highest(0, position, xmax);
    keep_highest(1, position, ymax);
    keep_highest(2, position, zmax);
  }

  Array<R3> xmin_array = parallel::allGather(xmin);
  Array<R3> xmax_array = parallel::allGather(xmax);
  Array<R3> ymin_array = parallel::allGather(ymin);
  Array<R3> ymax_array = parallel::allGather(ymax);
  Array<R3> zmin_array = parallel::allGather(zmin);
  Array<R3> zmax_array = parallel::allGather(zmax);

  // Extreme nodes among the gathered values
  const auto reduce_lowest = [&](const size_t axis, const Array<R3>& gathered, R3& extremum) {
    for (size_t i = 0; i < gathered.size(); ++i) {
      keep_lowest(axis, gathered[i], extremum);
    }
  };
  const auto reduce_highest = [&](const size_t axis, const Array<R3>& gathered, R3& extremum) {
    for (size_t i = 0; i < gathered.size(); ++i) {
      keep_highest(axis, gathered[i], extremum);
    }
  };

  reduce_lowest(0, xmin_array, xmin);
  reduce_lowest(1, ymin_array, ymin);
  reduce_lowest(2, zmin_array, zmin);

  reduce_highest(0, xmax_array, xmax);
  reduce_highest(1, ymax_array, ymax);
  reduce_highest(2, zmax_array, zmax);

  const R3 u = xmax - xmin;
  const R3 v = ymax - ymin;
  const R3 w = zmax - zmin;

  // Start from u x v, then switch to u x w or v x w when strictly longer
  R3 normal        = crossProduct(u, v);
  double normal_l2 = (normal, normal);

  const auto prefer_if_longer = [&](const R3& candidate) {
    const double candidate_l2 = (candidate, candidate);
    if (candidate_l2 > normal_l2) {
      normal    = candidate;
      normal_l2 = candidate_l2;
    }
  };

  prefer_if_longer(crossProduct(u, w));
  prefer_if_longer(crossProduct(v, w));

  if (normal_l2 == 0) {
    throw NormalError("cannot to compute normal!");
  }

  normal *= 1. / sqrt(normal_l2);

  // No flatness check is performed in 3D:
  // this->_checkBoundaryIsFlat(normal, xmin, xmax, mesh);

  return normal;
}

// Orients the 1D normal.
//
// The nodes of the first cell connected to the first boundary node are
// projected on the normal (relatively to that boundary node). The signed
// height of largest magnitude, among the local one and the gathered ones,
// decides: if it is positive the normal is flipped, else it is kept.
template <>
template <typename MeshType>
PUGS_INLINE TinyVector<1, double>
MeshFlatNodeBoundary<1>::_getOutgoingNormal(const MeshType& mesh)
{
  static_assert(MeshType::Dimension == 1);
  using R = TinyVector<1, double>;

  const R normal = this->_getNormal(mesh);

  // Signed height of largest magnitude met so far
  double dominant_height = 0;

  const auto retain_dominant = [&dominant_height](const double candidate) {
    if (std::abs(candidate) > std::abs(dominant_height)) {
      dominant_height = candidate;
    }
  };

  if (m_node_list.size() > 0) {
    const NodeValue<const R>& xr    = mesh.xr();
    const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();

    const NodeId probe_node      = m_node_list[0];
    const CellId probe_cell      = node_to_cell_matrix[probe_node][0];
    const auto& probe_cell_nodes = cell_to_node_matrix[probe_cell];

    for (size_t i = 0; i < probe_cell_nodes.size(); ++i) {
      const double height = (xr[probe_cell_nodes[i]] - xr[probe_node], normal);
      retain_dominant(height);
    }
  }

  Array<double> gathered_heights = parallel::allGather(dominant_height);
  for (size_t i = 0; i < gathered_heights.size(); ++i) {
    const double height = gathered_heights[i];
    retain_dominant(height);
  }

  if (dominant_height > 0) {
    return -normal;
  }
  return normal;
}

// Orients the 2D normal.
//
// The nodes of the first cell connected to the first boundary node are
// projected on the normal (relatively to that boundary node). The signed
// height of largest magnitude, among the local one and the gathered ones,
// decides: if it is positive the normal is flipped, else it is kept.
template <>
template <typename MeshType>
PUGS_INLINE TinyVector<2, double>
MeshFlatNodeBoundary<2>::_getOutgoingNormal(const MeshType& mesh)
{
  static_assert(MeshType::Dimension == 2);
  using R2 = TinyVector<2, double>;

  const R2 normal = this->_getNormal(mesh);

  // Signed height of largest magnitude met so far
  double dominant_height = 0;

  const auto retain_dominant = [&dominant_height](const double candidate) {
    if (std::abs(candidate) > std::abs(dominant_height)) {
      dominant_height = candidate;
    }
  };

  if (m_node_list.size() > 0) {
    const NodeValue<const R2>& xr   = mesh.xr();
    const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();

    const NodeId probe_node      = m_node_list[0];
    const CellId probe_cell      = node_to_cell_matrix[probe_node][0];
    const auto& probe_cell_nodes = cell_to_node_matrix[probe_cell];

    for (size_t i = 0; i < probe_cell_nodes.size(); ++i) {
      const double height = (xr[probe_cell_nodes[i]] - xr[probe_node], normal);
      retain_dominant(height);
    }
  }

  Array<double> gathered_heights = parallel::allGather(dominant_height);
  for (size_t i = 0; i < gathered_heights.size(); ++i) {
    const double height = gathered_heights[i];
    retain_dominant(height);
  }

  if (dominant_height > 0) {
    return -normal;
  }
  return normal;
}

// Orients the 3D normal.
//
// The nodes of the first cell connected to the first boundary node are
// projected on the normal (relatively to that boundary node). The signed
// height of largest magnitude, among the local one and the gathered ones,
// decides: if it is positive the normal is flipped, else it is kept.
template <>
template <typename MeshType>
PUGS_INLINE TinyVector<3, double>
MeshFlatNodeBoundary<3>::_getOutgoingNormal(const MeshType& mesh)
{
  static_assert(MeshType::Dimension == 3);
  using R3 = TinyVector<3, double>;

  const R3 normal = this->_getNormal(mesh);

  // Signed height of largest magnitude met so far
  double dominant_height = 0;

  const auto retain_dominant = [&dominant_height](const double candidate) {
    if (std::abs(candidate) > std::abs(dominant_height)) {
      dominant_height = candidate;
    }
  };

  if (m_node_list.size() > 0) {
    const NodeValue<const R3>& xr   = mesh.xr();
    const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();

    const NodeId probe_node      = m_node_list[0];
    const CellId probe_cell      = node_to_cell_matrix[probe_node][0];
    const auto& probe_cell_nodes = cell_to_node_matrix[probe_cell];

    for (size_t i = 0; i < probe_cell_nodes.size(); ++i) {
      const double height = (xr[probe_cell_nodes[i]] - xr[probe_node], normal);
      retain_dominant(height);
    }
  }

  Array<double> gathered_heights = parallel::allGather(dominant_height);
  for (size_t i = 0; i < gathered_heights.size(); ++i) {
    const double height = gathered_heights[i];
    retain_dominant(height);
  }

  if (dominant_height > 0) {
    return -normal;
  }
  return normal;
}

#endif   // MESH_NODE_BOUNDARY_HPP
