#ifndef MESH_DATA_HPP
#define MESH_DATA_HPP

#include <Kokkos_Core.hpp>
#include <TinyVector.hpp>

#include <cassert>
#include <map>

/// Geometric quantities computed from a mesh of dimension 1, 2 or 3.
///
/// For every cell j and each of its local nodes R it stores the vector Cjr,
/// its Euclidean norm ljr and the unit vector njr = Cjr/ljr. For every cell
/// it stores the vertex isobarycenter xj and the volume
/// Vj = (1/dimension) * sum_R (xr, Cjr).
///
/// The per-(cell,node) views are sized (numberOfCells, maxNbNodePerCell);
/// only the first cellNbNodes()[j] entries of row j are meaningful.
///
/// The mesh is held by reference and must outlive this object.
template <typename M>
class MeshData
{
public:
  typedef M MeshType;

  static constexpr size_t dimension = MeshType::dimension;
  static_assert(dimension>0, "dimension must be strictly positive");
  static_assert((dimension<=3), "only 1d, 2d and 3d are implemented");

  typedef TinyVector<dimension> Rd;

  static constexpr double inv_dimension = 1./dimension;

private:
  const MeshType& m_mesh;
  Kokkos::View<Rd**> m_Cjr;
  Kokkos::View<double**> m_ljr;
  Kokkos::View<Rd**> m_njr;
  Kokkos::View<Rd*> m_xj;
  Kokkos::View<double*> m_Vj;

  // The kernels below copy member views into local View handles before
  // launching. KOKKOS_LAMBDA captures by value, so the lambdas then hold
  // shallow copies of the views instead of dereferencing `this`.
  // These functions only launch kernels, so they are plain host functions
  // (no KOKKOS_INLINE_FUNCTION).

  /// Computes m_xj, the isobarycenter of the vertices of each cell.
  void _updateCenter()
  {
    const Kokkos::View<const Rd*> xr = m_mesh.xr();
    const Kokkos::View<const unsigned int**>& cell_nodes
      = m_mesh.connectivity().cellNodes();
    const Kokkos::View<Rd*> xj = m_xj;

    if constexpr (dimension == 1) {
      // assumes 1d cells have exactly two nodes (as the constructor does)
      Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j){
          xj[j] = 0.5*(xr[cell_nodes(j,0)]+xr[cell_nodes(j,1)]);
        });
    } else {
      const Kokkos::View<const unsigned short*>& cell_nb_nodes
        = m_mesh.connectivity().cellNbNodes();
      const Kokkos::View<const double*>& inv_cell_nb_nodes
        = m_mesh.connectivity().invCellNbNodes();
      Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j){
          Rd X = zero;
          for (int R=0; R<cell_nb_nodes[j]; ++R) {
            X += xr[cell_nodes(j,R)];
          }
          xj[j] = inv_cell_nb_nodes[j]*X;
        });
    }
  }

  /// Computes m_Vj = (1/dimension) * sum_R (xr, Cjr) for each cell.
  /// Reads m_Cjr, so it must run after _updateCjr().
  void _updateVolume()
  {
    const Kokkos::View<const unsigned int**>& cell_nodes
      = m_mesh.connectivity().cellNodes();
    const Kokkos::View<const unsigned short*> cell_nb_nodes
      = m_mesh.connectivity().cellNbNodes();

    const Kokkos::View<const Rd*> xr = m_mesh.xr();
    const Kokkos::View<const Rd**> Cjr = m_Cjr;
    const Kokkos::View<double*> Vj = m_Vj;

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j){
        double sum_cjr_xr = 0;
        for (int R=0; R<cell_nb_nodes[j]; ++R) {
          // (u, v) is TinyVector's overloaded operator, used as a scalar product
          sum_cjr_xr += (xr[cell_nodes(j,R)], Cjr(j,R));
        }
        Vj[j] = inv_dimension * sum_cjr_xr;
      });
  }

  /// From the current m_Cjr, computes m_ljr = l2Norm(Cjr) and
  /// m_njr = Cjr / ljr for every (cell, local node) pair.
  void _updateLjrNjr()
  {
    const Kokkos::View<const unsigned short*> cell_nb_nodes
      = m_mesh.connectivity().cellNbNodes();
    const Kokkos::View<const Rd**> Cjr = m_Cjr;
    const Kokkos::View<double**> ljr = m_ljr;
    const Kokkos::View<Rd**> njr = m_njr;

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j){
        for (int R=0; R<cell_nb_nodes[j]; ++R) {
          const Rd& C = Cjr(j,R);
          const double norm = l2Norm(C);
          ljr(j,R) = norm;
          njr(j,R) = (1./norm)*C;
        }
      });
  }

  /// Recomputes m_Cjr (and then m_ljr, m_njr) from the current node
  /// coordinates. In 1d these are set once by the constructor.
  void _updateCjr() {
    if constexpr (dimension == 1) {
      // Cjr/njr/ljr are constant over time (initialized in the constructor)
    } else if constexpr (dimension == 2) {
      const Kokkos::View<const unsigned int**>& cell_nodes
        = m_mesh.connectivity().cellNodes();
      const Kokkos::View<const unsigned short*> cell_nb_nodes
        = m_mesh.connectivity().cellNbNodes();

      const Kokkos::View<const Rd*> xr = m_mesh.xr();
      const Kokkos::View<Rd**> Cjr = m_Cjr;

      // Cjr is half of (x_{R+1} - x_{R-1}) rotated by +90 degrees.
      // NOTE(review): this points outward only if cell nodes are listed
      // clockwise. With counter-clockwise ordering it points inward and
      // _updateVolume yields negative volumes. Confirm against the
      // connectivity's node-ordering convention.
      Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j){
          for (int R=0; R<cell_nb_nodes[j]; ++R) {
            int Rp1 = (R+1)%cell_nb_nodes[j];
            int Rm1 = (R+cell_nb_nodes[j]-1)%cell_nb_nodes[j];
            Rd half_xrp_xrm = 0.5*(xr(cell_nodes(j,Rp1))-xr(cell_nodes(j,Rm1)));
            Cjr(j,R) = Rd{-half_xrp_xrm[1], half_xrp_xrm[0]};
          }
        });

      this->_updateLjrNjr();
    } else if constexpr (dimension == 3) {
      const Kokkos::View<const unsigned int**>& cell_nodes
        = m_mesh.connectivity().cellNodes();
      const Kokkos::View<const unsigned short*> cell_nb_nodes
        = m_mesh.connectivity().cellNbNodes();

      const Kokkos::View<const Rd*> xr = m_mesh.xr();

      // Nlr(l,r): contribution of local node r of face l, accumulated into
      // the Cjr of the cells sharing face l.
      Kokkos::View<Rd**> Nlr("Nlr", m_mesh.connectivity().numberOfFaces(), m_mesh.connectivity().maxNbNodePerFace());

      // NOTE: the 3d kernels read the face-to-node matrix through m_mesh
      // (capturing `this`), and this one also allocates a std::vector.
      // As written they are host-only.
      Kokkos::parallel_for(m_mesh.numberOfFaces(), KOKKOS_LAMBDA(const int& l) {
          const auto& face_nodes = m_mesh.connectivity().m_face_to_node_matrix.rowConst(l);
          const size_t nb_nodes = face_nodes.length;
          std::vector<Rd> dxr(nb_nodes);
          for (size_t r=0; r<nb_nodes; ++r) {
            dxr[r] = xr[face_nodes((r+1)%nb_nodes)] - xr[face_nodes((r+nb_nodes-1)%nb_nodes)];
          }
          const double inv_12_nb_nodes = 1./(12.*nb_nodes);
          for (size_t r=0; r<nb_nodes; ++r) {
            Rd Nr = zero;
            const Rd two_dxr = 2*dxr[r];
            for (size_t s=0; s<nb_nodes; ++s) {
              Nr += crossProduct((two_dxr - dxr[s]), xr[face_nodes(s)]);
            }
            Nr *= inv_12_nb_nodes;
            Nr -= (1./6.)*crossProduct(dxr[r], xr[face_nodes(r)]);
            Nlr(l,r) = Nr;
          }
        });

      const Kokkos::View<const unsigned int**> cell_faces
        = m_mesh.connectivity().cellFaces();
      const Kokkos::View<const bool**> cell_faces_is_reversed
        = m_mesh.connectivity().cellFacesIsReversed();

      const Kokkos::View<const unsigned short*> cell_nb_faces
        = m_mesh.connectivity().cellNbFaces();

      const Kokkos::View<Rd**> Cjr = m_Cjr;

      Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
          const unsigned short nb_nodes = cell_nb_nodes[j];
          for (unsigned short R=0; R<nb_nodes; ++R) {
            Cjr(j,R) = zero;
          }

          // Local index of a global node id within cell j (nb_nodes if not
          // found). A linear scan over the few cell nodes avoids building a
          // std::map (heap allocations) for every cell.
          const auto local_node_number = [&](const unsigned int node_id) {
            unsigned short R = 0;
            while ((R < nb_nodes) && (cell_nodes(j,R) != node_id)) {
              ++R;
            }
            return R;
          };

          for (size_t L=0; L<cell_nb_faces[j]; ++L) {
            const size_t l = cell_faces(j, L);
            const auto& face_nodes = m_mesh.connectivity().m_face_to_node_matrix.rowConst(l);
            // a face seen with reversed orientation contributes -Nlr
            const double orientation = cell_faces_is_reversed(j, L) ? -1. : 1.;
            for (size_t rl = 0; rl<face_nodes.length; ++rl) {
              const unsigned short R = local_node_number(face_nodes(rl));
              assert((R < nb_nodes) && "face node is not a node of the cell");
              if (R < nb_nodes) {
                Cjr(j,R) += orientation*Nlr(l,rl);
              }
            }
          }
        });

      this->_updateLjrNjr();
    }
  }

public:
  /// The mesh these data were computed from.
  const MeshType& mesh() const
  {
    return m_mesh;
  }

  /// Per-(cell, local node) vectors Cjr.
  const Kokkos::View<const Rd**> Cjr() const
  {
    return m_Cjr;
  }

  /// Per-(cell, local node) norms ljr of Cjr.
  const Kokkos::View<const double**> ljr() const
  {
    return m_ljr;
  }

  /// Per-(cell, local node) unit vectors njr = Cjr / ljr.
  const Kokkos::View<const Rd**> njr() const
  {
    return m_njr;
  }

  /// Per-cell vertex isobarycenters.
  const Kokkos::View<const Rd*> xj() const
  {
    return m_xj;
  }

  /// Per-cell volumes.
  const Kokkos::View<const double*> Vj() const
  {
    return m_Vj;
  }

  /// Recomputes all geometric data from the mesh's current node coordinates.
  /// Cjr is updated first because the volume computation reads it.
  void updateAllData()
  {
    this->_updateCjr();
    this->_updateCenter();
    this->_updateVolume();
  }

  MeshData(const MeshType& mesh)
    : m_mesh(mesh),
      m_Cjr("Cjr", mesh.numberOfCells(), mesh.connectivity().maxNbNodePerCell()),
      m_ljr("ljr", mesh.numberOfCells(), mesh.connectivity().maxNbNodePerCell()),
      m_njr("njr", mesh.numberOfCells(), mesh.connectivity().maxNbNodePerCell()),
      m_xj("xj", mesh.numberOfCells()),
      m_Vj("Vj", mesh.numberOfCells())
  {
    if constexpr (dimension==1) {
      // In 1d, Cjr and ljr are computed once and for all.
      const Kokkos::View<Rd**> Cjr = m_Cjr;
      const Kokkos::View<double**> ljr = m_ljr;
      Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
          Cjr(j,0) = -1;
          Cjr(j,1) =  1;
          ljr(j,0) = 1;
          ljr(j,1) = 1;
        });
      // In 1d njr == Cjr: m_njr becomes an alias of the same storage.
      m_njr=m_Cjr;
    }
    this->updateAllData();
  }
};

#endif // MESH_DATA_HPP