#ifndef NO_SPLITTING_HPP
#define NO_SPLITTING_HPP

// --- Header includes ---

#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <limits>

#include <Kokkos_Core.hpp>

#include <rang.hpp>
#include <BlockPerfectGas.hpp>

#include <TinyVector.hpp>
#include <TinyMatrix.hpp>
#include <Mesh.hpp>
#include <MeshData.hpp>
#include <FiniteVolumesEulerUnknowns.hpp>

// ---------------------------------

// Definition of the NoSplitting class

// Lagrangian finite-volume scheme working on cell unknowns (rhoj, uj, Ej,
// ej, pj, cj, kj, PTj) and node positions xr.
//
// One time step (computeNextStep) does the following:
//   1. iterates on a modified pressure
//        PTj = pj - kj * sum_r (ur, Cjr) / Vj
//      and on the nodal velocities ur obtained from the node systems
//        Ar ur = br, Ar = sum_j Ajr, br = sum_j (Ajr uj + PTj Cjr),
//      with an under-relaxed fixed-point loop;
//   2. updates uj and Ej with the nodal fluxes Fjr, then ej;
//   3. moves the nodes with ur, refreshes the mesh data and recomputes rhoj.
//
// NOTE(review): as written several steps only handle the first component
// (inverse() inverts the (0,0) entry only, boundary conditions are imposed
// on the first and last node, host loops touch component [0]), so the class
// looks usable for one-dimensional meshes only -- confirm before using it
// with dimension > 1.
//
// NOTE(review): nodal data are read and written directly from host code, so
// the Views are assumed to live in a host-accessible memory space.
//
// The expression (a, b) applied to two vectors is presumably the scalar
// product provided by TinyVector through an overloaded comma operator --
// confirm in TinyVector.hpp.
template<typename MeshData>
class NoSplitting
{
  using MeshType = typename MeshData::MeshType;
  using UnknownsType = FiniteVolumesEulerUnknowns<MeshData>;

  MeshData& m_mesh_data;
  MeshType& m_mesh;
  const typename MeshType::Connectivity& m_connectivity;

  constexpr static size_t dimension = MeshType::dimension;

  using Rd = TinyVector<dimension>;
  using Rdd = TinyMatrix<dimension>;

private:

  // Reduction functor returning the minimum entry of a View of doubles.
  struct ReduceMin
  {
  private:
    const Kokkos::View<const double*> x_;

  public:
    using value_type = Kokkos::View<const double*>::non_const_value_type;
    using size_type = Kokkos::View<const double*>::size_type;

    explicit ReduceMin(const Kokkos::View<const double*>& x) : x_ (x) {}

    // Folds entry i into the running minimum.
    KOKKOS_INLINE_FUNCTION void
    operator() (const size_type i, value_type& update) const
    {
      if (x_(i) < update) {
        update = x_(i);
      }
    }

    // Merges two partial minima (signature used by older Kokkos releases).
    KOKKOS_INLINE_FUNCTION void
    join (volatile value_type& dst,
          const volatile value_type& src) const
    {
      if (src < dst) {
        dst = src;
      }
    }

    // Same merge with plain references: newer Kokkos releases call join()
    // without volatile qualifiers.
    KOKKOS_INLINE_FUNCTION void
    join (value_type& dst,
          const value_type& src) const
    {
      if (src < dst) {
        dst = src;
      }
    }

    // Starts from the identity of the min reduction (the value returned by
    // Kokkos::reduction_identity<value_type>::min()).
    KOKKOS_INLINE_FUNCTION void
    init (value_type& dst) const
    {
      dst = Kokkos::reduction_identity<value_type>::min();
    }
  };

  // Fills m_rhocj with rhoj*cj + kj/Vj for every cell and returns it.
  const Kokkos::View<const double*>
  computeRhoCj(const Kokkos::View<const double*>& kj,
               const Kokkos::View<const double*>& Vj,
               const Kokkos::View<const double*>& rhoj,
               const Kokkos::View<const double*>& cj)
  {
    // Local handle so that the kernel does not capture `this`.
    const Kokkos::View<double*> rhocj = m_rhocj;

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        // variant without the kj contribution: rhocj[j] = rhoj[j]*cj[j];
        rhocj[j] = rhoj[j]*cj[j] + kj[j]/Vj[j];
      });

    return m_rhocj;
  }

  // Fills m_Ajr(j,r) with tensorProduct(rhocj(j)*Cjr(j,r), Cjr(j,r)) for
  // every node r of every cell j and returns it.
  const Kokkos::View<const Rdd**>
  computeAjr(const Kokkos::View<const double*>& rhocj,
             const Kokkos::View<const Rd**>& Cjr)
  {
    const Kokkos::View<const unsigned short*> cell_nb_nodes
      = m_connectivity.cellNbNodes();
    // Local handle so that the kernel does not capture `this`.
    const Kokkos::View<Rdd**> Ajr = m_Ajr;

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        for (int r=0; r<cell_nb_nodes[j]; ++r) {
          Ajr(j,r) = tensorProduct(rhocj(j)*Cjr(j,r), Cjr(j,r));
        }
      });

    return m_Ajr;
  }

  // Fills m_Ar(r) with the sum of Ajr over the cells connected to node r
  // and returns it.
  const Kokkos::View<const Rdd*>
  computeAr(const Kokkos::View<const Rdd**>& Ajr)
  {
    const Kokkos::View<const unsigned int**> node_cells
      = m_connectivity.nodeCells();
    const Kokkos::View<const unsigned short**> node_cell_local_node
      = m_connectivity.nodeCellLocalNode();
    const Kokkos::View<const unsigned short*> node_nb_cells
      = m_connectivity.nodeNbCells();
    // Local handle so that the kernel does not capture `this`.
    const Kokkos::View<Rdd*> Ar = m_Ar;

    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        Rdd sum = zero;
        for (int j=0; j<node_nb_cells(r); ++j) {
          const int J = node_cells(r,j);           // cell index
          const int R = node_cell_local_node(r,j); // index of node r inside cell J
          sum += Ajr(J,R);
        }
        Ar(r) = sum;
      });

    return m_Ar;
  }

  // Fills m_br(r) with the sum over the cells J connected to node r of
  // Ajr(J,R)*uj(J) + pj(J)*Cjr(J,R) and returns it.
  // The time t is currently not used.
  const Kokkos::View<const Rd*>
  computeBr(const Kokkos::View<const Rdd**>& Ajr,
            const Kokkos::View<const Rd**>& Cjr,
            const Kokkos::View<const Rd*>& uj,
            const Kokkos::View<const double*>& pj,
            const double t)
  {
    const Kokkos::View<const unsigned int**> node_cells
      = m_connectivity.nodeCells();
    const Kokkos::View<const unsigned short**> node_cell_local_node
      = m_connectivity.nodeCellLocalNode();
    const Kokkos::View<const unsigned short*> node_nb_cells
      = m_connectivity.nodeNbCells();
    // Local handle so that the kernel does not capture `this`.
    const Kokkos::View<Rd*> br_values = m_br;

    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        Rd& br = br_values(r);
        br = zero;
        for (int j=0; j<node_nb_cells(r); ++j) {
          const int J = node_cells(r,j);
          const int R = node_cell_local_node(r,j);
          br += Ajr(J,R)*uj(J) + pj(J)*Cjr(J,R);
        }
      });

    return m_br;
  }

  // Computes the nodal velocities m_ur(r) = Ar(r)^{-1} br(r), then imposes
  // a zero velocity on the first and on the last node, and returns m_ur.
  // The time t is only needed by the disabled Kidder boundary condition.
  Kokkos::View<Rd*>
  computeUr(const Kokkos::View<const Rdd*>& Ar,
            const Kokkos::View<const Rd*>& br,
            const double& t)
  {
    inverse(Ar, m_inv_Ar);
    const Kokkos::View<const Rdd*> invAr = m_inv_Ar;
    // Local handle so that the kernel does not capture `this`.
    const Kokkos::View<Rd*> ur = m_ur;

    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        ur[r] = invAr(r)*br(r);
      });

    // --- Boundary conditions ---

    m_ur[0] = zero;
    m_ur[m_mesh.numberOfNodes()-1] = zero;

    // Disabled alternatives. They need
    //   Kokkos::View<Rd*> x0 = m_mesh.x0();
    //   Kokkos::View<Rd*> xmax = m_mesh.xmax();
    //
    //m_ur[0] = x0;
    //m_ur[m_mesh.numberOfNodes()-1] = xmax[0];

    // Kidder boundary condition
    /*
    double h = std::sqrt(1. - (t*t)/(50./9.));
    m_ur[0]=(-t/((50./9.)-t*t))*h*x0[0];
    m_ur[m_mesh.numberOfNodes()-1] = (-t/((50./9.)-t*t))*h*xmax[0];
    */
    // ---------------------------

    return m_ur;
  }

  // Fills m_Fjr(j,r) with Ajr(j,r)*(uj(j)-ur(node)) + pj(j)*Cjr(j,r), where
  // node is the global index of the r-th node of cell j, and returns it.
  Kokkos::View<Rd**>
  computeFjr(const Kokkos::View<const Rdd**>& Ajr,
             const Kokkos::View<const Rd*>& ur,
             const Kokkos::View<const Rd**>& Cjr,
             const Kokkos::View<const Rd*>& uj,
             const Kokkos::View<const double*>& pj)
  {
    const Kokkos::View<const unsigned int**> cell_nodes
      = m_connectivity.cellNodes();
    const Kokkos::View<const unsigned short*> cell_nb_nodes
      = m_connectivity.cellNbNodes();
    // Local handle so that the kernel does not capture `this`.
    const Kokkos::View<Rd**> Fjr = m_Fjr;

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        for (int r=0; r<cell_nb_nodes[j]; ++r) {
          Fjr(j,r) = Ajr(j,r)*(uj(j)-ur(cell_nodes(j,r)))+pj(j)*Cjr(j,r);
        }
      });

    return m_Fjr;
  }

  // Computes the inverses of a list of matrices (for now only
  // $R^{1\times 1}$: only the (0,0) entry is read).
  void inverse(const Kokkos::View<const Rdd*>& A,
               Kokkos::View<Rdd*>& inv_A) const
  {
    Kokkos::parallel_for(A.size(), KOKKOS_LAMBDA(const int& r) {
        inv_A(r) = Rdd{1./(A(r)(0,0))};
      });
  }

  // Computes the inverses of a list of reals.
  void inverse(const Kokkos::View<const double*>& x,
               Kokkos::View<double*>& inv_x) const
  {
    Kokkos::parallel_for(x.size(), KOKKOS_LAMBDA(const int& r) {
        inv_x(r) = 1./x(r);
      });
  }

  // Chains the operations needed to compute the fluxes (m_Fjr and m_ur)
  // used by the scheme. The results are left in the member Views.
  // xr and xj are currently not used.
  void computeExplicitFluxes(const Kokkos::View<const Rd*>& xr,
                             const Kokkos::View<const Rd*>& xj,
                             const Kokkos::View<const double*>& rhoj,
                             const Kokkos::View<const double*>& kj,
                             const Kokkos::View<const Rd*>& uj,
                             const Kokkos::View<const double*>& pj,
                             const Kokkos::View<const double*>& cj,
                             const Kokkos::View<const double*>& Vj,
                             const Kokkos::View<const Rd**>& Cjr,
                             const double& t)
  {
    const Kokkos::View<const double*> rhocj = computeRhoCj(kj, Vj, rhoj, cj);
    const Kokkos::View<const Rdd**> Ajr = computeAjr(rhocj, Cjr);

    const Kokkos::View<const Rdd*> Ar = computeAr(Ajr);
    const Kokkos::View<const Rd*> br = computeBr(Ajr, Cjr, uj, pj, t);

    const Kokkos::View<const Rd*> ur = computeUr(Ar, br, t);
    computeFjr(Ajr, ur, Cjr, uj, pj);
  }

  Kokkos::View<Rd*> m_br;
  Kokkos::View<Rdd**> m_Ajr;
  Kokkos::View<Rdd*> m_Ar;
  Kokkos::View<Rdd*> m_inv_Ar;
  Kokkos::View<Rd**> m_Fjr;
  Kokkos::View<Rd*> m_ur;
  Kokkos::View<Rd*> m_ur0;   // nodal velocities of the previous fixed-point iterate
  Kokkos::View<double*> m_rhocj;
  Kokkos::View<double*> m_Vj_over_cj;

public:
  // Binds the scheme to mesh_data and allocates the work arrays: node-sized
  // (br, Ar, inv_Ar, ur, ur0), cell-sized (rho_c, Vj_over_cj) and
  // cell-by-local-node sized (Ajr, Fjr).
  // unknowns is currently not used by the constructor.
  NoSplitting(MeshData& mesh_data,
              UnknownsType& unknowns)
    : m_mesh_data(mesh_data),
      m_mesh(mesh_data.mesh()),
      m_connectivity(m_mesh.connectivity()),
      m_br("br", m_mesh.numberOfNodes()),
      m_Ajr("Ajr", m_mesh.numberOfCells(), m_connectivity.maxNbNodePerCell()),
      m_Ar("Ar", m_mesh.numberOfNodes()),
      m_inv_Ar("inv_Ar", m_mesh.numberOfNodes()),
      m_Fjr("Fjr", m_mesh.numberOfCells(), m_connectivity.maxNbNodePerCell()),
      m_ur("ur", m_mesh.numberOfNodes()),
      m_ur0("ur0", m_mesh.numberOfNodes()),
      m_rhocj("rho_c", m_mesh.numberOfCells()),
      m_Vj_over_cj("Vj_over_cj", m_mesh.numberOfCells())
  {
  }

  // Returns an estimate of the time step for a CFL condition of the type
  // c*dt/dx<1 with a modified sound speed: the minimum over the cells of
  // Vj / (cj + kj/(rhoj*Vj)). No safety factor is applied here.
  double nosplitting_dt(const Kokkos::View<const double*>& Vj,
                        const Kokkos::View<const double*>& cj,
                        const Kokkos::View<const double*>& rhoj,
                        const Kokkos::View<const double*>& kj) const
  {
    // Local handle so that the kernel does not capture `this`.
    const Kokkos::View<double*> Vj_over_cj = m_Vj_over_cj;

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        Vj_over_cj[j] = Vj[j]/(cj[j]+kj[j]/(rhoj[j]*Vj[j]));
      });

    double dt = std::numeric_limits<double>::max();
    Kokkos::parallel_reduce(m_mesh.numberOfCells(), ReduceMin(Vj_over_cj), dt);

    return dt;
  }


  // Advances the unknowns over one time step dt starting at time t.
  //
  // Side effects: updates PTj, uj, Ej, ej and rhoj in unknowns, moves the
  // mesh nodes, refreshes m_mesh_data, prints the fixed-point residual on
  // std::cout at each iteration and rewrites the files "uj" and "ur202" in
  // the current directory.
  void computeNextStep(const double& t, const double& dt,
                       UnknownsType& unknowns)
  {
    Kokkos::View<double*> rhoj = unknowns.rhoj();
    Kokkos::View<Rd*> uj = unknowns.uj();
    Kokkos::View<double*> Ej = unknowns.Ej();

    Kokkos::View<double*> ej = unknowns.ej();
    Kokkos::View<double*> pj = unknowns.pj();
    Kokkos::View<double*> cj = unknowns.cj();
    Kokkos::View<double*> kj = unknowns.kj();
    Kokkos::View<double*> PTj = unknowns.PTj();

    const Kokkos::View<const Rd*> xj = m_mesh_data.xj();
    const Kokkos::View<const double*> Vj = m_mesh_data.Vj();
    const Kokkos::View<const Rd**> Cjr = m_mesh_data.Cjr();
    Kokkos::View<Rd*> xr = m_mesh.xr();

    const Kokkos::View<const unsigned int**> cell_nodes
      = m_connectivity.cellNodes();
    const Kokkos::View<const unsigned short*> cell_nb_nodes
      = m_connectivity.cellNbNodes();

    // Read-only handles on the member arrays: they alias m_Fjr and m_ur and
    // therefore see every later update of these arrays.
    const Kokkos::View<const Rd**> Fjr = m_Fjr;
    const Kokkos::View<const Rd*> ur = m_ur;

    // The disabled experiments below additionally need
    //   Kokkos::View<Rd*> uL = unknowns.uL();
    //   Kokkos::View<Rd*> uR = unknowns.uR();
    //   const Kokkos::View<const double*> Vl = m_mesh_data.Vl();

    /*
    // Computation of PT (1st attempt)
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        double sum = 0;
        double sum1 = 0;
        for (int m=0; m<cell_nb_nodes(j); ++m) {
          sum += (uj(cell_nodes(j,m)), Cjr(cell_nodes(j,m), m));
          sum1 += Vj(cell_nodes(j,m));
        }
        if (j == 0) {
          //PTj(j) = pj(j) - kj(j)*(uj[j][0]-uL[0][0])/Vl(0);
          PTj(j) = pj(j) + kj(j)*(t/((50./9.)-t*t));
        } else if (j == m_mesh.numberOfCells()-1) {
          PTj(j) = pj(j) + kj(j)*(t/((50./9.)-t*t));
          //PTj(j) = pj(j) - kj(j)*(uR[0][0]-uj[j][0])/Vl(m_mesh.numberOfFaces()-1);
        } else {
          PTj(j) = pj(j) - kj(j)*2.*sum/sum1;
        }

      });
    */
    /*
    // Computation of PT (2nd attempt, symmetrization)
    const Kokkos::View<const unsigned int**>& face_cells = m_connectivity.faceCells();
    const Kokkos::View<const unsigned short*> face_nb_cells
      = m_connectivity.faceNbCells();
    const Kokkos::View<const unsigned int**>& cell_faces = m_connectivity.cellFaces();
    const Kokkos::View<const unsigned short*> cell_nb_faces
      = m_connectivity.cellNbFaces();

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {

        std::vector<double> stock(2);
        for (int l=0; l<cell_nb_faces(j); ++l) {
          double sum = 0;
          double sum2 = 0;
          int k = cell_faces(j,l);
          for (int i=0; i<face_nb_cells(k); ++i) {
            int cell_here = face_cells(k,i);
            sum += (1./Vj(cell_here))*uj[cell_here][0];
            sum2 += 1./Vj(cell_here);
          }
          stock[l] = sum/sum2;
        }
        if (j == 0) {
          PTj(j) = pj(j) - kj(j)*(uj[j][0]-uL[0][0])/Vl(0);
          //PTj(j) = pj(j) + kj(j)*(t/((50./9.)-t*t));
        } else if (j == m_mesh.numberOfCells()-1) {
          //PTj(j) = pj(j) + kj(j)*(t/((50./9.)-t*t));
          PTj(j) = pj(j) - kj(j)*(uR[0][0]-uj[j][0])/Vl(m_mesh.numberOfFaces()-1);
        } else {
          PTj(j) = pj(j) - kj(j)*(stock[1]-stock[0])/Vj(j);
        }

      });


    std::ofstream fout2("pj");
    fout2.precision(15);
    for (size_t j=0; j<m_mesh.numberOfCells(); ++j) {
      fout2 << xj[j][0] << ' ' << pj[j] << '\n';
    }
    std::ofstream fout3("pTj");
    fout3.precision(15);
    for (size_t j=0; j<m_mesh.numberOfCells(); ++j) {
      fout3 << xj[j][0] << ' ' << PTj[j] << '\n';
    }
    */

    // Computation of PT (3rd attempt, using the nodal velocity ur given by
    // the Riemann solver). PTj depends on ur and ur depends on PTj, hence
    // the under-relaxed fixed-point loop below.
    constexpr int max_fixed_point_iterations = 100;
    constexpr double new_iterate_weight = 0.7;
    constexpr double old_iterate_weight = 0.3;
    constexpr double convergence_threshold = 1.e-6;

    const size_t number_of_nodes = m_mesh.numberOfNodes();

    for (int itconv=0; itconv<max_fixed_point_iterations; ++itconv) {

      // computeExplicitFluxes(xr, xj, rhoj, kj, uj, pj, cj, Vj, Cjr, t);

      Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {

          // Disabled special treatment of the first and last cells:
          // if (j == 0) {
          //   PTj(j) = pj(j) - kj(j)*(uj[j][0]-uL[0][0])/Vl(0);
          //   //PTj(j) = pj(j) + kj(j)*(t/((50./9.)-t*t));
          //   //PTj(j) = pj(j) - kj(j)*(ur[cell_nodes(j,1)][0])/Vj(j);
          // } else if (j == m_mesh.numberOfCells()-1) {
          //   PTj(j) = pj(j) - kj(j)*(uR[0][0]-uj[j][0])/Vl(m_mesh.numberOfFaces()-1);
          //   //PTj(j) = pj(j) + kj(j)*(t/((50./9.)-t*t));
          //   //PTj(j) = pj(j) - kj(j)*(-ur[cell_nodes(j,0)][0])/Vj(j);
          // } else {
          double sum = 0;
          for (int k=0; k<cell_nb_nodes(j); ++k) {
            const int node_here = cell_nodes(j,k);
            sum += (ur(node_here), Cjr(j,k));
          }
          PTj(j) = pj(j) - kj(j)*sum/Vj(j);
          // }

        });

      // Saves the current iterate. An element-wise copy is used on purpose:
      // assigning the Views (m_ur0 = m_ur) would presumably only make both
      // handles share the same data. Only the first component is copied.
      for (size_t inode=0; inode<number_of_nodes; ++inode) {
        m_ur0[inode][0] = m_ur[inode][0];
      }

      // Computes the fluxes with the modified pressure
      computeExplicitFluxes(xr, xj, rhoj, kj, uj, PTj, cj, Vj, Cjr, t);

      // Under-relaxation of the nodal velocity.
      // NOTE(review): m_Fjr was computed just above with the unrelaxed
      // velocity and is not recomputed after this blend -- confirm that this
      // is intended.
      for (size_t inode=0; inode<number_of_nodes; ++inode) {
        m_ur[inode][0] = new_iterate_weight*m_ur[inode][0]
          + old_iterate_weight*m_ur0[inode][0];
      }

      // Mean absolute change of the nodal velocity over this iteration
      double sum = 0.;
      for (size_t inode=0; inode<number_of_nodes; ++inode) {
        sum += std::abs(m_ur0[inode][0]-m_ur[inode][0]);
      }
      sum /= static_cast<double>(number_of_nodes);

      std::cout << " it " << itconv << " sum " << sum << std::endl;
      if (sum<convergence_threshold) break;
    }

    // Update of the velocity and of the specific total energy
    const Kokkos::View<const double*> inv_mj = unknowns.invMj();
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        Rd momentum_fluxes = zero;
        double energy_fluxes = 0;
        for (int R=0; R<cell_nb_nodes[j]; ++R) {
          const int r = cell_nodes(j,R);
          momentum_fluxes += Fjr(j,R);
          energy_fluxes   += (Fjr(j,R), ur[r]);
        }
        uj[j] -= (dt*inv_mj[j]) * momentum_fluxes;
        Ej[j] -= (dt*inv_mj[j]) * energy_fluxes;

        // additional source term for Kidder (constant k)
        //Ej[j] -= (dt*inv_mj[j])*Vj(j)*((kj(j)*t*t)/(((50./9.)-t*t)*((50./9.)-t*t)));
        // additional source term for Kidder (k = x)
        //uj[j][0] += (dt*inv_mj[j])*Vj(j)*(t/((50./9.)-t*t));
        //Ej[j] -= (dt*inv_mj[j])*Vj(j)*((2.*xj[j][0]*t*t)/(((50./9.)-t*t)*((50./9.)-t*t)));
      });

    // Computation of e with the formula e = E-0.5 u^2
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        ej[j] = Ej[j] - 0.5 * (uj[j],uj[j]);
      });

    // Moves the mesh (its vertices) using the velocity given by the scheme
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        xr[r] += dt*ur[r];
      });

    // Updates the (geometric) quantities associated with the mesh
    m_mesh_data.updateAllData();

    // Computation of rho with the formula Mj = Vj rhoj
    const Kokkos::View<const double*> mj = unknowns.mj();
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        rhoj[j] = mj[j]/Vj[j];
      });

    // gnuplot output for the cell velocity (file rewritten at every call)
    std::ofstream fout("uj");
    fout.precision(15);
    for (size_t j=0; j<m_mesh.numberOfCells(); ++j) {
      fout << xj[j][0] << ' ' << uj[j][0] << '\n';
    }

    // gnuplot output for the nodal (Riemann) velocity (file rewritten at
    // every call)
    std::ofstream fout1("ur202");
    fout1.precision(15);
    for (size_t j=0; j<number_of_nodes; ++j) {
      fout1 << xr[j][0] << ' ' << ur[j][0] << '\n';
    }

    // Update of k
    /*
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        kj(j) = xj[j][0];
      });
    */

    // Stores the velocity for the next iteration (disabled; needs a
    // writable handle ur0 on m_ur0)
    // Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r){
    //     ur0[r][0] = ur[r][0];
    //   });
    /*
    // gnuplot output for the stored nodal velocity
    std::ofstream fout1("ur0");
    fout1.precision(15);
    for (size_t j=0; j<m_mesh.numberOfNodes(); ++j) {
      fout1 << xr[j][0] << ' ' << ur0[j][0] << '\n';
    }
    */
  }
};



#endif // NO_SPLITTING_HPP