#ifndef MESH_DATA_HPP
#define MESH_DATA_HPP

#include <Kokkos_Core.hpp>
#include <TinyVector.hpp>

#include <ItemValue.hpp>
#include <SubItemValuePerItem.hpp>

#include <limits>
#include <map>
#include <vector>

// Geometric data computed from a mesh:
//  - Cjr: per (cell j, node r) vectors, used here to compute cell volumes,
//  - ljr: l2 norm of Cjr,
//  - njr: Cjr / ljr,
//  - xj:  cell centers (isobarycenter of the cell vertices),
//  - Vj:  cell volumes.
//
// The mesh is stored by reference and is NOT owned: it must outlive this
// object. All quantities are computed from the current node coordinates
// m_mesh.xr() at construction, and recomputed by updateAllData().
template <typename M>
class MeshData
{
 public:
  using  MeshType = M;

  static constexpr size_t dimension = MeshType::dimension;
  static_assert(dimension>0, "dimension must be strictly positive");
  static_assert((dimension<=3), "only 1d, 2d and 3d are implemented");

  using Rd = TinyVector<dimension>;

  static constexpr double inv_dimension = 1./dimension;

 private:
  const MeshType& m_mesh;                // borrowed, must outlive *this
  NodeValuePerCell<const Rd> m_Cjr;      // per (cell, node) vectors
  NodeValuePerCell<const double> m_ljr;  // l2Norm(Cjr)
  NodeValuePerCell<const Rd> m_njr;      // Cjr / ljr
  CellValue<const Rd> m_xj;              // cell centers
  CellValue<const double> m_Vj;          // cell volumes

  // Computes cell centers as the isobarycenter of the cell vertices.
  PASTIS_INLINE
  void _updateCenter()
  {
    const NodeValue<const Rd>& xr = m_mesh.xr();
    const auto& cell_to_node_matrix
        = m_mesh.connectivity().cellToNodeMatrix();

    CellValue<Rd> xj(m_mesh.connectivity());
    if constexpr (dimension == 1) {
      // A 1d cell has two nodes: its center is their midpoint.
      Kokkos::parallel_for(m_mesh.numberOfCells(), PASTIS_LAMBDA(const CellId& j){
          const auto& cell_nodes = cell_to_node_matrix[j];
          xj[j] = 0.5*(xr[cell_nodes[0]]+xr[cell_nodes[1]]);
        });
    } else {
      const CellValue<const double>& inv_cell_nb_nodes
          = m_mesh.connectivity().invCellNbNodes();

      Kokkos::parallel_for(m_mesh.numberOfCells(), PASTIS_LAMBDA(const CellId& j){
          Rd X = zero;
          const auto& cell_nodes = cell_to_node_matrix[j];
          for (size_t R=0; R<cell_nodes.size(); ++R) {
            X += xr[cell_nodes[R]];
          }
          xj[j] = inv_cell_nb_nodes[j]*X;
        });
    }
    m_xj = xj;
  }

  // Computes cell volumes as Vj = (1/dimension) * sum_r (xr, Cjr).
  // Reads m_Cjr, so it must run after _updateCjr().
  PASTIS_INLINE
  void _updateVolume()
  {
    const NodeValue<const Rd>& xr = m_mesh.xr();
    const auto& cell_to_node_matrix
        = m_mesh.connectivity().cellToNodeMatrix();

    CellValue<double> Vj(m_mesh.connectivity());
    Kokkos::parallel_for(m_mesh.numberOfCells(), PASTIS_LAMBDA(const CellId& j){
        double sum_cjr_xr = 0;
        const auto& cell_nodes = cell_to_node_matrix[j];

        for (size_t R=0; R<cell_nodes.size(); ++R) {
          // NOTE(review): `(a, b)` is TinyVector's overloaded operator, which
          // presumably is the inner product (its result is accumulated into a
          // double) -- confirm in TinyVector.hpp.
          sum_cjr_xr += (xr[cell_nodes[R]], m_Cjr(j,R));
        }
        Vj[j] = inv_dimension * sum_cjr_xr;
      });
    m_Vj = Vj;
  }

  // Derives m_ljr = l2Norm(Cjr) and m_njr = Cjr / ljr from the current m_Cjr.
  // Shared by the 2d and 3d paths of _updateCjr().
  // NOTE(review): a zero-length Cjr (degenerate cell) yields a non-finite njr.
  PASTIS_INLINE
  void _updateLjrAndNjr()
  {
    {
      NodeValuePerCell<double> ljr(m_mesh.connectivity());
      Kokkos::parallel_for(m_Cjr.numberOfValues(), PASTIS_LAMBDA(const size_t& jr){
          ljr[jr] = l2Norm(m_Cjr[jr]);
        });
      m_ljr = ljr;
    }

    {
      NodeValuePerCell<Rd> njr(m_mesh.connectivity());
      Kokkos::parallel_for(m_Cjr.numberOfValues(), PASTIS_LAMBDA(const size_t& jr){
          njr[jr] = (1./m_ljr[jr])*m_Cjr[jr];
        });
      m_njr = njr;
    }
  }

  // Recomputes m_Cjr (and from it m_ljr and m_njr) from the current node
  // coordinates. In 1d these are constant and set once by the constructor.
  PASTIS_INLINE
  void _updateCjr() {
    if constexpr (dimension == 1) {
      // Cjr/njr/ljr are constant over time: computed once in the constructor
    } else if constexpr (dimension == 2) {
      const NodeValue<const Rd>& xr = m_mesh.xr();
      const auto& cell_to_node_matrix
          = m_mesh.connectivity().cellToNodeMatrix();

      NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
      Kokkos::parallel_for(m_mesh.numberOfCells(), PASTIS_LAMBDA(const CellId& j){
          const auto& cell_nodes = cell_to_node_matrix[j];
          const size_t nb_nodes = cell_nodes.size();
          for (size_t R=0; R<nb_nodes; ++R) {
            // indices of the next and previous nodes of the cell (cyclic)
            const size_t Rp1 = (R+1)%nb_nodes;
            const size_t Rm1 = (R+nb_nodes-1)%nb_nodes;
            // Cjr is half the difference of the neighbouring nodes, rotated
            // by +90 degrees: (x, y) -> (-y, x)
            Rd half_xrp_xrm = 0.5*(xr[cell_nodes[Rp1]]-xr[cell_nodes[Rm1]]);
            Cjr(j,R) = Rd{-half_xrp_xrm[1], half_xrp_xrm[0]};
          }
        });
      m_Cjr = Cjr;

      this->_updateLjrAndNjr();
    } else if constexpr (dimension == 3) {
      const NodeValue<const Rd>& xr = m_mesh.xr();

      // Nlr: contribution of node r of face l, expressed with the face's own
      // orientation; it is summed into Cjr for every cell sharing the face.
      NodeValuePerFace<Rd> Nlr(m_mesh.connectivity());
      const auto& face_to_node_matrix
          = m_mesh.connectivity().faceToNodeMatrix();

      Kokkos::parallel_for(m_mesh.numberOfFaces(), PASTIS_LAMBDA(const FaceId& l) {
          const auto& face_nodes = face_to_node_matrix[l];
          const size_t nb_nodes = face_nodes.size();
          std::vector<Rd> dxr(nb_nodes);
          for (size_t r=0; r<nb_nodes; ++r) {
            dxr[r]
                = xr[face_nodes[(r+1)%nb_nodes]]
                - xr[face_nodes[(r+nb_nodes-1)%nb_nodes]];
          }
          const double inv_12_nb_nodes = 1./(12.*nb_nodes);
          for (size_t r=0; r<nb_nodes; ++r) {
            Rd Nr = zero;
            const Rd two_dxr = 2*dxr[r];
            for (size_t s=0; s<nb_nodes; ++s) {
              Nr += crossProduct((two_dxr - dxr[s]), xr[face_nodes[s]]);
            }
            Nr *= inv_12_nb_nodes;
            Nr -= (1./6.)*crossProduct(dxr[r], xr[face_nodes[r]]);
            Nlr(l,r) = Nr;
          }
        });

      const auto& cell_to_node_matrix
          = m_mesh.connectivity().cellToNodeMatrix();

      const auto& cell_to_face_matrix
          = m_mesh.connectivity().cellToFaceMatrix();

      const auto& cell_face_is_reversed = m_mesh.connectivity().cellFaceIsReversed();

      NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
      Kokkos::parallel_for(Cjr.numberOfValues(), PASTIS_LAMBDA(const size_t& jr){
          Cjr[jr] = zero;
        });

      Kokkos::parallel_for(m_mesh.numberOfCells(), PASTIS_LAMBDA(const CellId& j) {
          const auto& cell_nodes = cell_to_node_matrix[j];

#warning should this lambda be replaced by a precomputed correspondance?
          // Local index of a node within cell j. Returns
          // std::numeric_limits<size_t>::max() when the node is not a node of
          // the cell, which would mean the connectivity is inconsistent.
          // A plain lambda (not std::function): no type erasure, inlinable,
          // built once per cell instead of once per face.
          const auto local_node_number_in_cell
              = [&cell_nodes](const NodeId& node_number) -> size_t {
                  for (size_t i_node=0; i_node<cell_nodes.size(); ++i_node) {
                    if (node_number == cell_nodes[i_node]) {
                      return i_node;
                    }
                  }
                  return std::numeric_limits<size_t>::max();
                };

          const auto& cell_faces = cell_to_face_matrix[j];
          const auto& face_is_reversed = cell_face_is_reversed.itemValues(j);

          for (size_t L=0; L<cell_faces.size(); ++L) {
            const FaceId& l = cell_faces[L];
            const auto& face_nodes = face_to_node_matrix[l];

            // Nlr follows the face orientation: subtract it when the face is
            // reversed with respect to cell j.
            if (face_is_reversed[L]) {
              for (size_t rl = 0; rl<face_nodes.size(); ++rl) {
                const size_t R = local_node_number_in_cell(face_nodes[rl]);
                Cjr(j, R) -= Nlr(l,rl);
              }
            } else {
              for (size_t rl = 0; rl<face_nodes.size(); ++rl) {
                const size_t R = local_node_number_in_cell(face_nodes[rl]);
                Cjr(j, R) += Nlr(l,rl);
              }
            }
          }
        });

      m_Cjr = Cjr;

      this->_updateLjrAndNjr();
    }
  }

 public:
  // The mesh these data were computed from.
  const MeshType& mesh() const
  {
    return m_mesh;
  }

  // Per (cell, node) vectors Cjr.
  const NodeValuePerCell<const Rd>& Cjr() const
  {
    return m_Cjr;
  }

  // l2 norms of the Cjr vectors.
  const NodeValuePerCell<const double>& ljr() const
  {
    return m_ljr;
  }

  // Cjr / ljr (in 1d, the same data as Cjr).
  const NodeValuePerCell<const Rd>& njr() const
  {
    return m_njr;
  }

  // Cell centers (isobarycenter of the cell vertices).
  const CellValue<const Rd>& xj() const
  {
    return m_xj;
  }

  // Cell volumes.
  const CellValue<const double>& Vj() const
  {
    return m_Vj;
  }

  // Recomputes every quantity from the current node coordinates.
  // Order matters: _updateVolume() reads m_Cjr computed by _updateCjr().
  void updateAllData()
  {
    this->_updateCjr();
    this->_updateCenter();
    this->_updateVolume();
  }

  // Builds all geometric data for `mesh`, which must outlive this object.
  MeshData(const MeshType& mesh)
      : m_mesh(mesh)
  {
    if constexpr (dimension==1) {
      // In 1d, Cjr, ljr and njr do not depend on node positions: they are
      // computed once for all. Assumes each 1d cell has exactly two nodes.
      {
        NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
        Kokkos::parallel_for(m_mesh.numberOfCells(), PASTIS_LAMBDA(const CellId& j) {
            Cjr(j,0)=-1;
            Cjr(j,1)= 1;
          });
        m_Cjr = Cjr;
      }
      // in 1d njr=Cjr (here performs a shallow copy)
      m_njr = m_Cjr;
      {
        NodeValuePerCell<double> ljr(m_mesh.connectivity());
        Kokkos::parallel_for(ljr.numberOfValues(), PASTIS_LAMBDA(const size_t& jr){
            ljr[jr] = 1;
          });
        m_ljr = ljr;
      }
    }
    this->updateAllData();
  }

  ~MeshData() = default;
};

#endif // MESH_DATA_HPP