#ifndef MESH_NODE_BOUNDARY_HPP
#define MESH_NODE_BOUNDARY_HPP

#include <Kokkos_Core.hpp>
#include <Kokkos_Vector.hpp>
#include <TinyVector.hpp>

#include <RefNodeList.hpp>
#include <RefFaceList.hpp>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

template <size_t dimension>
class MeshNodeBoundary
{
 protected:
  Kokkos::View<const unsigned int*> m_node_list;
 public:
  MeshNodeBoundary& operator=(const MeshNodeBoundary&) = default;
  MeshNodeBoundary& operator=(MeshNodeBoundary&&) = default;

  const Kokkos::View<const unsigned int*>& nodeList() const
  {
    return m_node_list;
  }

  template <typename MeshType>
  MeshNodeBoundary(const MeshType& mesh,
                   const RefFaceList& ref_face_list)
  {
    static_assert(dimension == MeshType::dimension);
    const Kokkos::View<const unsigned int*>& face_list = ref_face_list.faceList();
    Kokkos::parallel_for(face_list.extent(0), KOKKOS_LAMBDA(const int& l){
        const auto& face_cells = mesh.connectivity().m_face_to_cell_matrix.rowConst(face_list[l]);
        if (face_cells.length>1) {
          std::cerr << "internal faces cannot be used to define mesh boundaries\n";
          std::exit(1);
        }
      });

    Kokkos::vector<unsigned int> node_ids;
    // not enough but should reduce significantly the number of resizing
    node_ids.reserve(dimension*face_list.extent(0));
    for (size_t l=0; l<face_list.extent(0); ++l) {
      const size_t face_number = face_list[l];
      const auto& face_nodes = mesh.connectivity().m_face_to_node_matrix.rowConst(face_number);

      for (size_t r=0; r<face_nodes.length; ++r) {
        node_ids.push_back(face_nodes(r));
      }
    }
    std::sort(node_ids.begin(), node_ids.end());
    auto last = std::unique(node_ids.begin(), node_ids.end());
    node_ids.resize(std::distance(node_ids.begin(), last));

    Kokkos::View<unsigned int*> node_list("node_list", node_ids.size());
    Kokkos::parallel_for(node_ids.size(), KOKKOS_LAMBDA(const int& r){
        node_list[r] = node_ids[r];
      });
    m_node_list = node_list;
  }

  template <typename MeshType>
  MeshNodeBoundary(const MeshType& mesh,
                   const RefNodeList& ref_node_list)
      : m_node_list(ref_node_list.nodeList())
  {
    static_assert(dimension == MeshType::dimension);
  }

  MeshNodeBoundary() = default;
  MeshNodeBoundary(const MeshNodeBoundary&) = default;
  MeshNodeBoundary(MeshNodeBoundary&&) = default;
  virtual ~MeshNodeBoundary() = default;
};


/**
 * Node boundary known to be flat (a point in 1D, a straight line in 2D, a
 * plane in 3D), carrying its unit outgoing normal.
 *
 * The normal is computed at construction by _getOutgoingNormal(), whose
 * per-dimension specializations are defined below the class.
 *
 * @tparam dimension space dimension; must match MeshType::dimension
 */
template <size_t dimension>
class MeshFlatNodeBoundary
    : public MeshNodeBoundary<dimension>
{
  typedef TinyVector<dimension, double> Rd;
 private:
  // Not const: a const member would make the defaulted copy/move
  // assignment operators below deleted (and possibly the default
  // constructor, depending on TinyVector's default constructor).
  Rd m_outgoing_normal;

  //! computes a unit normal to the boundary, orientation not yet fixed
  template <typename MeshType>
  inline Rd _getNormal(const MeshType& mesh);

  //! exits if the boundary nodes are not aligned (2D only)
  template <typename MeshType>
  inline void _checkBoundaryIsFlat(const TinyVector<2,double>& normal,
                                   const TinyVector<2,double>& xmin,
                                   const TinyVector<2,double>& xmax,
                                   const MeshType& mesh) const;

  //! computes the unit normal and orients it outwards
  template <typename MeshType>
  inline Rd _getOutgoingNormal(const MeshType& mesh);
 public:
  //! @return the unit outgoing normal of the boundary
  const Rd& outgoingNormal() const
  {
    return m_outgoing_normal;
  }

  MeshFlatNodeBoundary& operator=(const MeshFlatNodeBoundary&) = default;
  MeshFlatNodeBoundary& operator=(MeshFlatNodeBoundary&&) = default;

  template <typename MeshType>
  MeshFlatNodeBoundary(const MeshType& mesh,
                       const RefFaceList& ref_face_list)
      : MeshNodeBoundary<dimension>(mesh, ref_face_list),
        m_outgoing_normal(_getOutgoingNormal(mesh))
  {
    ;
  }

  template <typename MeshType>
  MeshFlatNodeBoundary(const MeshType& mesh,
                       const RefNodeList& ref_node_list)
      : MeshNodeBoundary<dimension>(mesh, ref_node_list),
        m_outgoing_normal(_getOutgoingNormal(mesh))
  {
    ;
  }

  MeshFlatNodeBoundary() = default;
  MeshFlatNodeBoundary(const MeshFlatNodeBoundary&) = default;
  MeshFlatNodeBoundary(MeshFlatNodeBoundary&&) = default;
  virtual ~MeshFlatNodeBoundary() = default;
};

/**
 * 2D specialization: checks that every boundary node lies on the line
 * through the middle of [xmin, xmax] with the given normal.
 *
 * A node is rejected when the magnitude of its projection on the normal
 * exceeds 1E-13 times |xmax - xmin|. Nodes on either side of the line are
 * rejected. If any node is rejected the program exits with an error.
 */
template<>
template <typename MeshType>
void MeshFlatNodeBoundary<2>::
_checkBoundaryIsFlat(const TinyVector<2,double>& normal,
                     const TinyVector<2,double>& xmin,
                     const TinyVector<2,double>& xmax,
                     const MeshType& mesh) const
{
  static_assert(MeshType::dimension == 2);
  typedef TinyVector<2,double> R2;
  const R2 origin = 0.5*(xmin+xmax);
  const double length = l2Norm(xmax-xmin);
  const double tolerance = 1E-13*length;

  const Kokkos::View<const R2*> xr = mesh.xr();
  // Local copy so the kernel captures the view itself rather than `this`.
  const Kokkos::View<const unsigned int*> node_list = m_node_list;

  // Count the nodes off the line inside the kernel; report on the host so
  // that no thread prints or exits from within the parallel region.
  size_t nb_off_line_nodes = 0;
  Kokkos::parallel_reduce(node_list.extent(0), KOKKOS_LAMBDA(const size_t& r, size_t& nb_off_line) {
      const R2& x = xr[node_list[r]];
      const double distance = (x-origin,normal);
      // two-sided test: a node may deviate on either side of the line
      if ((distance > tolerance) or (distance < -tolerance)) {
        ++nb_off_line;
      }
    }, nb_off_line_nodes);

  if (nb_off_line_nodes > 0) {
    std::cerr << "this FlatBoundary is not flat!\n";
    std::exit(1);
  }
}

/**
 * 1D specialization: a 1D boundary is a single node, and its normal is
 * the one-component vector R(1).
 *
 * The orientation is not meaningful here; _getOutgoingNormal decides
 * whether to flip it. Exits if the boundary does not contain exactly one
 * node.
 */
template <>
template <typename MeshType>
inline TinyVector<1,double>
MeshFlatNodeBoundary<1>::
_getNormal(const MeshType& mesh)
{
  static_assert(MeshType::dimension == 1);
  using R = TinyVector<1,double>;

  const bool is_single_node = (m_node_list.extent(0) == 1);
  if (not is_single_node) {
    std::cerr << "Node boundaries in 1D require to have exactly 1 node\n";
    std::exit(1);
  }

  const R unit_normal(1);
  return unit_normal;
}

/**
 * 2D specialization: computes a unit normal to the boundary line.
 *
 * The two extreme boundary nodes in lexicographic order (x first, then y)
 * give the line direction dx. The normal is dx rotated by +90 degrees,
 * i.e. (-dx[1], dx[0]). Its orientation is fixed later by
 * _getOutgoingNormal.
 *
 * Exits if the node list is empty or if all nodes coincide. The alignment
 * check itself is delegated to _checkBoundaryIsFlat.
 */
template <>
template <typename MeshType>
inline TinyVector<2,double>
MeshFlatNodeBoundary<2>::
_getNormal(const MeshType& mesh)
{
  static_assert(MeshType::dimension == 2);
  typedef TinyVector<2,double> R2;

  // Without nodes the extreme points would stay at +/-max, so xmin != xmax
  // and a non-finite normal would be returned instead of an error.
  if (m_node_list.extent(0) == 0) {
    std::cerr << "empty node boundary: unable to compute normal\n";
    std::exit(1);
  }

  const Kokkos::View<const R2*> xr = mesh.xr();

  R2 xmin(std::numeric_limits<double>::max(),
          std::numeric_limits<double>::max());

  R2 xmax(-std::numeric_limits<double>::max(),
          -std::numeric_limits<double>::max());

  // lexicographic min/max: ties on x are broken with y
  for (size_t r=0; r<m_node_list.extent(0); ++r) {
    const R2& x = xr[m_node_list[r]];
    if ((x[0]<xmin[0]) or ((x[0] == xmin[0]) and (x[1] < xmin[1]))) {
      xmin = x;
    }
    if ((x[0]>xmax[0]) or ((x[0] == xmax[0]) and (x[1] > xmax[1]))) {
      xmax = x;
    }
  }

  if (xmin == xmax) {
    std::cerr << "xmin==xmax (" << xmin << "==" << xmax << ") unable to compute normal";
    std::exit(1);
  }

  R2 dx = xmax-xmin;
  dx *= 1./l2Norm(dx);

  R2 normal(-dx[1], dx[0]);

  this->_checkBoundaryIsFlat(normal, xmin, xmax, mesh);

  return normal;
}

/**
 * 3D specialization: computes a unit normal to the boundary plane.
 *
 * For each axis, the boundary nodes with minimal and maximal coordinate are
 * found. The vectors u = xmax-xmin, v = ymax-ymin and w = zmax-zmin lie in
 * the plane. The normal is the pairwise cross product of largest norm,
 * normalized. Its orientation is fixed later by _getOutgoingNormal.
 *
 * Exits if the node list is empty or if all cross products vanish (nodes
 * coincident or aligned). Flatness of the boundary is NOT checked yet.
 */
template <>
template <typename MeshType>
inline TinyVector<3,double>
MeshFlatNodeBoundary<3>::
_getNormal(const MeshType& mesh)
{
  static_assert(MeshType::dimension == 3);
  typedef TinyVector<3,double> R3;

  // Without nodes the extreme points would stay at +/-max: u, v and w would
  // be infinite, their cross products NaN, and the "normal_l2 == 0" test
  // below would not fire, silently returning a NaN normal.
  if (m_node_list.extent(0) == 0) {
    std::cerr << "empty node boundary: unable to compute normal\n";
    std::exit(1);
  }

  R3 xmin(std::numeric_limits<double>::max(),
          std::numeric_limits<double>::max(),
          std::numeric_limits<double>::max());
  R3 ymin = xmin;
  R3 zmin = xmin;

  R3 xmax =-xmin;
  R3 ymax = xmax;
  R3 zmax = xmax;

  const Kokkos::View<const R3*> xr = mesh.xr();

  // extreme nodes along each axis (first one found in case of ties)
  for (size_t r=0; r<m_node_list.extent(0); ++r) {
    const R3& x = xr[m_node_list[r]];
    if (x[0]<xmin[0]) {
      xmin = x;
    }
    if (x[1]<ymin[1]) {
      ymin = x;
    }
    if (x[2]<zmin[2]) {
      zmin = x;
    }
    if (x[0]>xmax[0]) {
      xmax = x;
    }
    if (x[1]>ymax[1]) {
      ymax = x;
    }
    if (x[2]>zmax[2]) {
      zmax = x;
    }
  }

  const R3 u = xmax-xmin;
  const R3 v = ymax-ymin;
  const R3 w = zmax-zmin;

  // keep the best conditioned cross product (largest squared norm)
  const R3 uv = crossProduct(u,v);
  const double uv_l2 = (uv, uv);

  R3 normal = uv;
  double normal_l2 = uv_l2;

  const R3 uw = crossProduct(u,w);
  const double uw_l2 = (uw,uw);

  if (uw_l2 > uv_l2) {
    normal = uw;
    normal_l2 = uw_l2;
  }

  const R3 vw = crossProduct(v,w);
  const double vw_l2 = (vw,vw);

  if (vw_l2 > normal_l2) {
    normal = vw;
    normal_l2 = vw_l2;
  }


  if (normal_l2 ==0) {
    std::cerr << "not able to compute normal!\n";
    std::exit(1);
  }

  normal *= 1./sqrt(normal_l2);

#warning Add flatness test
  // this->_checkBoundaryIsFlat(normal, xmin, xmax, mesh);

  return normal;
}

/**
 * 1D specialization: returns the normal of _getNormal oriented away from
 * the domain.
 *
 * The first cell attached to the first boundary node serves as the interior
 * reference. Among that cell's nodes, the one whose projection on the
 * normal is largest in magnitude is selected (the first one on ties). If
 * that projection is positive, the normal points into the cell and is
 * flipped.
 */
template <>
template <typename MeshType>
inline TinyVector<1,double>
MeshFlatNodeBoundary<1>::
_getOutgoingNormal(const MeshType& mesh)
{
  static_assert(MeshType::dimension == 1);
  typedef TinyVector<1,double> R;

  const R normal = this->_getNormal(mesh);
  const Kokkos::View<const R*>& xr = mesh.xr();

  // interior reference: first cell adjacent to the first boundary node
  const size_t boundary_node = m_node_list[0];
  const size_t inner_cell = mesh.connectivity().m_node_to_cell_matrix.rowConst(boundary_node)(0);
  const auto& inner_cell_nodes = mesh.connectivity().m_cell_to_node_matrix.rowConst(inner_cell);

  double farthest_height = 0;
  for (size_t i=0; i<inner_cell_nodes.length; ++i) {
    const double node_height = (xr(inner_cell_nodes(i))-xr(boundary_node), normal);
    if (std::abs(farthest_height) < std::abs(node_height)) {
      farthest_height = node_height;
    }
  }

  // a normal pointing towards the interior cell must be reversed
  if (farthest_height > 0) {
    return -normal;
  }
  return normal;
}

/**
 * 2D specialization: returns the unit normal of _getNormal oriented away
 * from the domain.
 *
 * The first cell attached to the first boundary node serves as the interior
 * reference. Among that cell's nodes, the one whose projection on the
 * normal (relative to the boundary node) is largest in magnitude is
 * selected (the first one on ties). If that projection is positive, the
 * normal points into the cell and is flipped.
 */
template <>
template <typename MeshType>
inline TinyVector<2,double>
MeshFlatNodeBoundary<2>::
_getOutgoingNormal(const MeshType& mesh)
{
  static_assert(MeshType::dimension == 2);
  typedef TinyVector<2,double> R2;

  const R2 normal = this->_getNormal(mesh);
  const Kokkos::View<const R2*>& xr = mesh.xr();

  // interior reference: first cell adjacent to the first boundary node
  const size_t boundary_node = m_node_list[0];
  const size_t inner_cell = mesh.connectivity().m_node_to_cell_matrix.rowConst(boundary_node)(0);
  const auto& inner_cell_nodes = mesh.connectivity().m_cell_to_node_matrix.rowConst(inner_cell);

  double farthest_height = 0;
  for (size_t i=0; i<inner_cell_nodes.length; ++i) {
    const double node_height = (xr(inner_cell_nodes(i))-xr(boundary_node), normal);
    if (std::abs(farthest_height) < std::abs(node_height)) {
      farthest_height = node_height;
    }
  }

  // a normal pointing towards the interior cell must be reversed
  if (farthest_height > 0) {
    return -normal;
  }
  return normal;
}

/**
 * 3D specialization: returns the unit normal of _getNormal oriented away
 * from the domain.
 *
 * The first cell attached to the first boundary node serves as the interior
 * reference. Among that cell's nodes, the one whose projection on the
 * normal (relative to the boundary node) is largest in magnitude is
 * selected (the first one on ties). If that projection is positive, the
 * normal points into the cell and is flipped.
 */
template <>
template <typename MeshType>
inline TinyVector<3,double>
MeshFlatNodeBoundary<3>::
_getOutgoingNormal(const MeshType& mesh)
{
  static_assert(MeshType::dimension == 3);
  typedef TinyVector<3,double> R3;

  const R3 normal = this->_getNormal(mesh);
  const Kokkos::View<const R3*>& xr = mesh.xr();

  // interior reference: first cell adjacent to the first boundary node
  const size_t boundary_node = m_node_list[0];
  const size_t inner_cell = mesh.connectivity().m_node_to_cell_matrix.rowConst(boundary_node)(0);
  const auto& inner_cell_nodes = mesh.connectivity().m_cell_to_node_matrix.rowConst(inner_cell);

  double farthest_height = 0;
  for (size_t i=0; i<inner_cell_nodes.length; ++i) {
    const double node_height = (xr(inner_cell_nodes(i))-xr(boundary_node), normal);
    if (std::abs(farthest_height) < std::abs(node_height)) {
      farthest_height = node_height;
    }
  }

  // a normal pointing towards the interior cell must be reversed
  if (farthest_height > 0) {
    return -normal;
  }
  return normal;
}


#endif // MESH_NODE_BOUNDARY_HPP