#ifndef MESH_BOUNDARY_HPP
#define MESH_BOUNDARY_HPP

#include <Kokkos_Core.hpp>
#include <TinyVector.hpp>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

#warning Must change rewrite that to compute associated node lists
/**
 * Set of mesh nodes lying on a boundary described by a list of faces.
 *
 * The node list is built once at construction: the nodes of every listed
 * face are gathered, sorted and de-duplicated.
 *
 * @tparam dimension space dimension; must equal MeshType::dimension
 *                   (checked at compile time in the constructor).
 */
template <size_t dimension>
class MeshBoundary
{
 protected:
  //! sorted, duplicate-free ids of the nodes of the boundary faces
  Kokkos::View<const unsigned int*> m_node_list;
 public:
  MeshBoundary& operator=(const MeshBoundary&) = default;
  MeshBoundary& operator=(MeshBoundary&&) = default;

  /**
   * Builds the boundary node list from a list of faces.
   *
   * @param mesh      mesh whose connectivity() provides faceNbCells(),
   *                  faceNbNodes() and faceNodes()
   * @param face_list ids of the faces forming the boundary
   *
   * Terminates the program (std::exit(1)) if a listed face has more than
   * one adjacent cell, i.e. is an internal face.
   */
  template <typename MeshType>
  MeshBoundary(const MeshType& mesh,
               const Kokkos::View<const unsigned int*>& face_list)
  {
    static_assert(dimension == MeshType::dimension);
    const Kokkos::View<const unsigned short*> face_nb_cells = mesh.connectivity().faceNbCells();

    // Count the offending faces in parallel and report once from the
    // calling thread. Printing/exiting inside a kernel is not allowed on
    // device backends, and several threads calling std::exit concurrently
    // is undefined behavior.
    int nb_internal_faces = 0;
    Kokkos::parallel_reduce(face_list.extent(0), KOKKOS_LAMBDA(const int& l, int& nb_internal){
        if (face_nb_cells[face_list[l]]>1) {
          ++nb_internal;
        }
      }, nb_internal_faces);
    if (nb_internal_faces > 0) {
      std::cerr << "internal faces cannot be used to define mesh boundaries\n";
      std::exit(1);
    }

    const Kokkos::View<const unsigned short*> face_nb_nodes = mesh.connectivity().faceNbNodes();
    const Kokkos::View<const unsigned int**> face_nodes = mesh.connectivity().faceNodes();

    // NOTE(review): the views below are read from host code; this assumes
    // they live in host-accessible memory — confirm for device builds.
    std::vector<unsigned int> node_ids;
    // not enough but should reduce significantly the number of resizing
    node_ids.reserve(dimension*face_list.extent(0));
    for (size_t l=0; l<face_list.extent(0); ++l) {
      const size_t face_number = face_list[l];
      for (size_t r=0; r<face_nb_nodes[face_number]; ++r) {
        node_ids.push_back(face_nodes(face_number,r));
      }
    }
    // a node shared by several boundary faces must appear only once
    std::sort(node_ids.begin(), node_ids.end());
    auto last = std::unique(node_ids.begin(), node_ids.end());
    node_ids.resize(std::distance(node_ids.begin(), last));

    Kokkos::View<unsigned int*> node_list("node_list", node_ids.size());
    // Fill through a host mirror: node_ids is a host std::vector and must
    // not be captured by a (possibly device) kernel.
    auto node_list_host = Kokkos::create_mirror_view(node_list);
    for (size_t r=0; r<node_ids.size(); ++r) {
      node_list_host[r] = node_ids[r];
    }
    Kokkos::deep_copy(node_list, node_list_host);
    m_node_list = node_list;
  }

  MeshBoundary() = default;
  MeshBoundary(const MeshBoundary&) = default;
  MeshBoundary(MeshBoundary&&) = default;
  virtual ~MeshBoundary() = default;
};


/**
 * Mesh boundary assumed to be flat, storing a unit normal computed from its
 * nodes at construction.
 *
 * Only MeshFlatBoundary<2>::_getOutgoingNormal is defined in this header;
 * instantiating the constructor for another dimension requires providing
 * that specialization as well.
 */
template <size_t dimension>
class MeshFlatBoundary
    : public MeshBoundary<dimension>
{
  using Rd = TinyVector<dimension, double>;
 private:
  // Intentionally not const: a const data member would make the defaulted
  // copy/move assignment operators below implicitly deleted.
  Rd m_outgoing_normal;

  // Computes the boundary normal from the node list built by the base class
  // (which is fully constructed before this is called).
  template <typename MeshType>
  inline Rd _getOutgoingNormal(const MeshType& mesh);
 public:
  MeshFlatBoundary& operator=(const MeshFlatBoundary&) = default;
  MeshFlatBoundary& operator=(MeshFlatBoundary&&) = default;

  /**
   * @param mesh      mesh the boundary belongs to
   * @param face_list ids of the (boundary) faces forming this flat boundary
   */
  template <typename MeshType>
  MeshFlatBoundary(const MeshType& mesh,
                   const Kokkos::View<const unsigned int*>& face_list)
      : MeshBoundary<dimension>(mesh, face_list),
        m_outgoing_normal(_getOutgoingNormal(mesh))
  {
  }

  MeshFlatBoundary() = default;
  MeshFlatBoundary(const MeshFlatBoundary&) = default;
  MeshFlatBoundary(MeshFlatBoundary&&) = default;
  virtual ~MeshFlatBoundary() = default;

};

/**
 * Computes a unit normal to a flat 2D boundary.
 *
 * The two extreme boundary nodes in lexicographic (x first, then y) order
 * are taken as the end points of the boundary line; the normal is their
 * normalized difference rotated by +90 degrees.
 *
 * NOTE(review): the orientation only follows from that lexicographic
 * order, not from the position of the adjacent cells, so despite the name
 * the returned vector may point into the domain. Flatness of the node set
 * is not checked either — confirm callers rely on neither.
 *
 * Terminates the program (std::exit(1)) if the boundary has no node or if
 * all its nodes coincide, since no direction can be computed then.
 */
template <>
template <typename MeshType>
inline TinyVector<2,double>
MeshFlatBoundary<2>::
_getOutgoingNormal(const MeshType& mesh)
{
  static_assert(MeshType::dimension == 2);
  typedef TinyVector<2,double> R2;

  // Without any node, xmin/xmax would keep their +/-max seeds below, dx
  // would overflow and the normalization would silently yield NaNs.
  if (m_node_list.extent(0) == 0) {
    std::cerr << "empty boundary: unable to compute normal\n";
    std::exit(1);
  }

  const Kokkos::View<const R2*> xr = mesh.xr();

  R2 xmin(std::numeric_limits<double>::max(),
          std::numeric_limits<double>::max());

  R2 xmax(-std::numeric_limits<double>::max(),
          -std::numeric_limits<double>::max());

  // lexicographic min/max: ties on x are broken with y, so a vertical
  // boundary still yields two distinct end points
  for (size_t r=0; r<m_node_list.extent(0); ++r) {
    const R2& x = xr[m_node_list[r]];
    if ((x[0]<xmin[0]) or ((x[0] == xmin[0]) and (x[1] < xmin[1]))) {
      xmin = x;
    }
    if ((x[0]>xmax[0]) or ((x[0] == xmax[0]) and (x[1] > xmax[1]))) {
      xmax = x;
    }
  }

  if (xmin == xmax) {
    std::cerr << "xmin==xmax (" << xmin << "==" << xmax << ") unable to compute normal\n";
    std::exit(1);
  }

  // (dx,dx) is used as the squared length of dx, so dx becomes a unit vector
  R2 dx = xmax-xmin;
  dx *= 1./std::sqrt((dx,dx));

  R2 normal(-dx[1], dx[0]);

  std::cout << "xmin=" << xmin << " xmax=" << xmax << " normal=" << normal << '\n';

  return normal;
}


#endif // MESH_BOUNDARY_HPP