#include <AcousticSolverWithMesh.hpp>
#include <rang.hpp>

#include <BlockPerfectGas.hpp>

// Read-only scalar type used as the element type of the views consumed by
// the reduction functor below.
using my_double = const double;

// Reduction functor for Kokkos::parallel_reduce that computes the minimum of
// the entries of a rank-1 view of (const) doubles.
struct AcousticSolverWithMesh::ReduceMin
{
public:
  // Type of the reduced value (the non-const element type of the view).
  using value_type = Kokkos::View<my_double*>::non_const_value_type;
  // Index type used to address the view.
  using size_type = Kokkos::View<my_double*>::size_type;

private:
  // Values over which the minimum is taken.
  const Kokkos::View<my_double*> x_;

public:
  ReduceMin(const Kokkos::View<my_double*>& x) : x_ (x) {}

  // Per-entry contribution: keep the smaller of the running value and x_(i).
  KOKKOS_INLINE_FUNCTION void
  operator() (const size_type i, value_type& update) const
  {
    const value_type candidate = x_(i);
    if (candidate < update) {
      update = candidate;
    }
  }

  // Merges two partial results: dst receives the smaller of dst and src.
  KOKKOS_INLINE_FUNCTION void
  join (volatile value_type& dst,
        const volatile value_type& src) const
  {
    if (src < dst) {
      dst = src;
    }
  }

  // Starts a partial result from the identity element of the min-reduction,
  // as provided by Kokkos::reduction_identity.
  KOKKOS_INLINE_FUNCTION void
  init (value_type& dst) const
  {
    dst = Kokkos::reduction_identity<value_type>::min();
  }
};

// Fills m_rhocj with the cell-wise product rhoj[j]*cj[j] for the m_nj cells
// and returns it as a read-only view.
//
// rhoj: per-cell values, first factor of the product.
// cj:   per-cell values, second factor of the product.
KOKKOS_INLINE_FUNCTION
const Kokkos::View<const double*>
AcousticSolverWithMesh::computeRhoCj(const Kokkos::View<const double*>& rhoj,
                                     const Kokkos::View<const double*>& cj)
{
  // Work on a local handle so that the lambda captures the view by value
  // rather than implicitly capturing `this` (which is not usable from a
  // device kernel and is deprecated for [=] captures in C++20).
  const auto rhocj = m_rhocj;

  Kokkos::parallel_for(m_nj, KOKKOS_LAMBDA(const int& j) {
      rhocj[j] = rhoj[j]*cj[j];
    });

  return m_rhocj;
}

// Fills m_Ajr(j,r) = tensorProduct(rhocj(j)*Cjr(j,r), Cjr(j,r)) for every
// cell j and each of its two local nodes r, and returns it as a read-only
// view.
//
// rhocj: per-cell scalar weights (see computeRhoCj).
// Cjr:   per-cell, per-local-node vectors.
KOKKOS_INLINE_FUNCTION
const Kokkos::View<const AcousticSolverWithMesh::Rdd*[2]>
AcousticSolverWithMesh::computeAjr(const Kokkos::View<const double*>& rhocj,
                                   const Kokkos::View<const Rd*[2]>& Cjr)
{
  // Local handle: the lambda must capture the view itself, not `this`.
  const auto Ajr = m_Ajr;

  Kokkos::parallel_for(m_nj, KOKKOS_LAMBDA(const int& j) {
      for (int r=0; r<2; ++r) {
        Ajr(j,r) = tensorProduct(rhocj(j)*Cjr(j,r), Cjr(j,r));
      }
    });

  return m_Ajr;
}

// Fills m_Ar(r) with the sum of the Ajr matrices of the cells attached to
// node r, and returns it as a read-only view.
//
// Ajr:                  per-cell, per-local-node matrices.
// node_cells:           node_cells(r,j) is the j-th cell attached to node r.
// node_cell_local_node: local index of node r inside that j-th cell.
// node_nb_cells:        number of cells attached to node r.
KOKKOS_INLINE_FUNCTION
const Kokkos::View<const AcousticSolverWithMesh::Rdd*>
AcousticSolverWithMesh::computeAr(const Kokkos::View<const Rdd*[2]>& Ajr,
                                  const Kokkos::View<const int*[2]>& node_cells,
                                  const Kokkos::View<const int*[2]>& node_cell_local_node,
                                  const Kokkos::View<const int*>& node_nb_cells)
{
  // Local handle: the lambda must capture the view itself, not `this`.
  const auto Ar = m_Ar;

  Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r) {
      Rdd sum = zero;
      for (int j=0; j<node_nb_cells(r); ++j) {
        const int J = node_cells(r,j);
        const int R = node_cell_local_node(r,j);
        sum += Ajr(J,R);
      }
      Ar(r) = sum;
    });

  return m_Ar;
}

// Fills m_br(r) with the sum, over the cells J attached to node r, of
// Ajr(J,R)*uj(J) + pj(J)*Cjr(J,R), where R is the local index of node r in
// cell J. Returns m_br as a read-only view.
//
// Ajr:                  per-cell, per-local-node matrices.
// Cjr:                  per-cell, per-local-node vectors.
// uj, pj:               per-cell vector and scalar values.
// node_cells:           node_cells(r,j) is the j-th cell attached to node r.
// node_cell_local_node: local index of node r inside that j-th cell.
// node_nb_cells:        number of cells attached to node r.
KOKKOS_INLINE_FUNCTION
const Kokkos::View<const AcousticSolverWithMesh::Rd*>
AcousticSolverWithMesh::computeBr(const Kokkos::View<const Rdd*[2]>& Ajr,
                                  const Kokkos::View<const Rd*[2]>& Cjr,
                                  const Kokkos::View<const Rd*>& uj,
                                  const Kokkos::View<const double*>& pj,
                                  const Kokkos::View<const int*[2]>& node_cells,
                                  const Kokkos::View<const int*[2]>& node_cell_local_node,
                                  const Kokkos::View<const int*>& node_nb_cells)
{
  // Local handle: the lambda must capture the view itself, not `this`.
  const auto br_view = m_br;

  Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r) {
      // Accumulate locally and store once, as computeAr does.
      Rd sum = zero;
      for (int j=0; j<node_nb_cells(r); ++j) {
        const int J = node_cells(r,j);
        const int R = node_cell_local_node(r,j);
        sum += Ajr(J,R)*uj(J) + pj(J)*Cjr(J,R);
      }
      br_view(r) = sum;
    });

  return m_br;
}

// Fills m_ur(r) = inverse(Ar)(r) * br(r) for every node, except for the first
// and last nodes whose value is forced to zero (presumably a fixed-wall
// boundary condition -- confirm with the scheme's author). m_inv_Ar is
// updated as a side effect. Returns m_ur.
//
// Ar: per-node matrices (see computeAr).
// br: per-node right-hand sides (see computeBr).
KOKKOS_INLINE_FUNCTION
Kokkos::View<AcousticSolverWithMesh::Rd*>
AcousticSolverWithMesh::computeUr(const Kokkos::View<const Rdd*>& Ar,
                                  const Kokkos::View<const Rd*>& br)
{
  inverse(Ar, m_inv_Ar);
  const Kokkos::View<const Rdd*> invAr = m_inv_Ar;

  // Local copies: the lambda must capture these by value, not through `this`.
  const auto ur = m_ur;
  const int last_node = static_cast<int>(m_nr) - 1;

  // The boundary values are set inside the kernel: writing m_ur[0] and
  // m_ur[m_nr-1] from the host right after the launch would neither be
  // ordered with respect to an asynchronous kernel nor valid for
  // device-resident memory.
  Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r) {
      if ((r == 0) or (r == last_node)) {
        ur[r] = zero;
      } else {
        ur[r] = invAr(r)*br(r);
      }
    });

  return m_ur;
}

// Fills m_Fjr(j,r) = Ajr(j,r)*(uj(j) - ur(node)) + pj(j)*Cjr(j,r), where
// node = cell_nodes(j,r), for every cell j and each of its two local nodes r.
// Returns m_Fjr.
//
// Ajr:        per-cell, per-local-node matrices.
// ur:         per-node vectors (see computeUr).
// Cjr:        per-cell, per-local-node vectors.
// uj, pj:     per-cell vector and scalar values.
// cell_nodes: cell_nodes(j,r) is the global index of local node r of cell j.
KOKKOS_INLINE_FUNCTION
Kokkos::View<AcousticSolverWithMesh::Rd*[2]>
AcousticSolverWithMesh::computeFjr(const Kokkos::View<const Rdd*[2]>& Ajr,
                                   const Kokkos::View<const Rd*>& ur,
                                   const Kokkos::View<const Rd*[2]>& Cjr,
                                   const Kokkos::View<const Rd*>& uj,
                                   const Kokkos::View<const double*>& pj,
                                   const Kokkos::View<const int*[2]>& cell_nodes)
{
  // Local handle: the lambda must capture the view itself, not `this`.
  const auto Fjr = m_Fjr;

  Kokkos::parallel_for(m_nj, KOKKOS_LAMBDA(const int& j) {
      for (int r=0; r<2; ++r) {
        Fjr(j,r) = Ajr(j,r)*(uj(j)-ur(cell_nodes(j,r)))+pj(j)*Cjr(j,r);
      }
    });

  return m_Fjr;
}

// Returns the minimum over all cells of Vj[j]/cj[j]. The per-cell ratios are
// stored in m_Vj_over_cj as a side effect.
//
// Vj: per-cell values, numerator of the ratio.
// cj: per-cell values, denominator of the ratio.
KOKKOS_INLINE_FUNCTION
double AcousticSolverWithMesh::
acoustic_dt(const Kokkos::View<const double*>& Vj,
            const Kokkos::View<const double*>& cj) const
{
  // Local handle: the lambda must capture the view itself, not `this`.
  const auto Vj_over_cj = m_Vj_over_cj;

  Kokkos::parallel_for(m_nj, KOKKOS_LAMBDA(const int& j){
      Vj_over_cj[j] = Vj[j]/cj[j];
    });

  double dt = std::numeric_limits<double>::max();
  Kokkos::parallel_reduce(m_nj, ReduceMin(Vj_over_cj), dt);

  return dt;
}

// Entry-wise reciprocal of a scalar view: inv_x(i) = 1/x(i) for every entry
// of x. inv_x is indexed over the same range as x.
KOKKOS_INLINE_FUNCTION
void
AcousticSolverWithMesh::inverse(const Kokkos::View<const double*>& x,
                                Kokkos::View<double*>& inv_x) const
{
  const auto entry_count = x.size();

  Kokkos::parallel_for(entry_count, KOKKOS_LAMBDA(const int& i) {
      inv_x(i) = 1./x(i);
    });
}

// Entry-wise inverse of a view of Rdd matrices. Only the (0,0) coefficient of
// each matrix is read: inv_A(i) is built from 1/A(i)(0,0), so this is only a
// true inverse for 1x1 matrices.
KOKKOS_INLINE_FUNCTION
void
AcousticSolverWithMesh::inverse(const Kokkos::View<const Rdd*>& A,
                                Kokkos::View<Rdd*>& inv_A) const
{
  const auto entry_count = A.size();

  Kokkos::parallel_for(entry_count, KOKKOS_LAMBDA(const int& i) {
      inv_A(i) = Rdd{1./(A(i)(0,0))};
    });
}


// Chains the individual steps that produce the node values `ur` and the
// per-cell, per-node fluxes `Fjr`:
//   rho*c  ->  Ajr  ->  (Ar, br)  ->  ur  ->  Fjr
//
// xr, xj and Vj are accepted but not used by this function.
// On return, `ur` and `Fjr` refer to the views returned by computeUr and
// computeFjr respectively.
KOKKOS_INLINE_FUNCTION
void AcousticSolverWithMesh::computeExplicitFluxes(const Kokkos::View<const Rd*>& xr,
                                                   const Kokkos::View<const Rd*>& xj,
                                                   const Kokkos::View<const double*>& rhoj,
                                                   const Kokkos::View<const Rd*>& uj,
                                                   const Kokkos::View<const double*>& pj,
                                                   const Kokkos::View<const double*>& cj,
                                                   const Kokkos::View<const double*>& Vj,
                                                   const Kokkos::View<const Rd*[2]>& Cjr,
                                                   const Kokkos::View<const int*[2]>& cell_nodes,
                                                   const Kokkos::View<const int*[2]>& node_cells,
                                                   const Kokkos::View<const int*>& node_nb_cells,
                                                   const Kokkos::View<const int*[2]>& node_cell_local_node,
                                                   Kokkos::View<Rd*>& ur,
                                                   Kokkos::View<Rd*[2]>& Fjr)
{
  // Cell-based quantities.
  const Kokkos::View<const double*> rho_c = computeRhoCj(rhoj, cj);
  const Kokkos::View<const Rdd*[2]> A_jr = computeAjr(rho_c, Cjr);

  // Node-based quantities assembled from the attached cells.
  const Kokkos::View<const Rdd*> A_r
    = computeAr(A_jr, node_cells, node_cell_local_node, node_nb_cells);
  const Kokkos::View<const Rd*> b_r
    = computeBr(A_jr, Cjr, uj, pj,
                node_cells, node_cell_local_node, node_nb_cells);

  // Node values first, then the fluxes that depend on them.
  ur = computeUr(A_r, b_r);
  Fjr = computeFjr(A_jr, ur, Cjr, uj, pj, cell_nodes);
}

// Builds a 1D mesh of `nj` cells (nj+1 nodes) with uniform spacing 1/nj,
// initializes a two-state problem split at x = 0.5 (rho = 1 / 0.125,
// p = 1 / 0.1, u = 0, gamma = 1.4 -- values matching the Sod shock-tube
// setup), then runs the whole time loop up to t = 0.2 and prints the final
// time and iteration count on std::cout.
//
// nj: number of cells. Expected to be >= 1 (delta_x is 1/nj).
//
// Note: expressions of the form `(a, b)` on Rd operands rely on Rd's
// overloaded comma operator -- presumably the scalar product; confirm in the
// Rd (TinyVector) definition.
AcousticSolverWithMesh::AcousticSolverWithMesh(const long int& nj)
  : m_nj(nj),
    m_nr(nj+1),
    m_br("br", m_nr),
    m_Ajr("Ajr", m_nj),
    m_Ar("Ar", m_nr),
    m_inv_Ar("inv_Ar", m_nr),
    m_Fjr("Fjr", m_nj),
    m_ur("ur", m_nr),
    m_rhocj("rho_c", m_nj),
    m_Vj_over_cj("Vj_over_cj", m_nj)
{
  // ---- Cell-based unknowns ------------------------------------------------
  Kokkos::View<Rd*> xj("xj",m_nj);
  Kokkos::View<double*> rhoj("rhoj",m_nj);

  Kokkos::View<Rd*> uj("uj",m_nj);

  Kokkos::View<double*> Ej("Ej",m_nj);
  Kokkos::View<double*> ej("ej",m_nj);
  Kokkos::View<double*> pj("pj",m_nj);
  Kokkos::View<double*> Vj("Vj",m_nj);
  Kokkos::View<double*> gammaj("gammaj",m_nj);
  Kokkos::View<double*> cj("cj",m_nj);
  Kokkos::View<double*> mj("mj",m_nj);

  // ---- Node positions -----------------------------------------------------
  Kokkos::View<Rd*>  xr("xr", m_nr);

  const double delta_x = 1./m_nj;

  Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r){
      xr[r][0] = r*delta_x;
    });

  // ---- Connectivity -------------------------------------------------------
  // Each view gets its own label (they previously all shared "node_cells",
  // which made them indistinguishable in Kokkos diagnostics).
  Kokkos::View<int*[2]> cell_nodes("cell_nodes",m_nj,2);
  Kokkos::View<int*[2]> node_cells("node_cells",m_nr,2);
  Kokkos::View<int*[2]> node_cell_local_node("node_cell_local_node",m_nr,2);
  Kokkos::View<int*> node_nb_cells("node_nb_cells",m_nr);

  // Plain-int copies so that the kernels below capture values, not `this`.
  const int last_node = static_cast<int>(m_nr) - 1;
  const int last_cell = static_cast<int>(m_nj) - 1;

  // Boundary nodes touch one cell, interior nodes two. The boundary values
  // are set inside the kernel instead of by host-side element writes, which
  // would not be ordered with respect to an asynchronous kernel.
  Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r){
      node_nb_cells(r) = ((r == 0) or (r == last_node)) ? 1 : 2;
    });

  // Node r is shared by cells r-1 and r; the first node only belongs to
  // cell 0 and the last node only to the last cell.
  Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r){
      if (r == last_node) {
        node_cells(r,0) = last_cell;
      } else if (r == 0) {
        node_cells(r,0) = 0;
      } else {
        node_cells(r,0) = r-1;
        node_cells(r,1) = r;
      }
    });

  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      cell_nodes(j,0) = j;
      cell_nodes(j,1) = j+1;
    });

  // For each (node, attached cell) pair, find the local index of the node
  // inside that cell.
  Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r){
      for (int J=0; J<node_nb_cells(r); ++J) {
        int j = node_cells(r,J);
        for (int R=0; R<2; ++R) {
          if (cell_nodes(j,R) == r) {
            node_cell_local_node(r,J) = R;
          }
        }
      }
    });

  // ---- Geometry -----------------------------------------------------------
  // Cell centers: midpoint of the two nodes of the cell.
  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      xj[j] = 0.5*(xr[cell_nodes(j,0)]+xr[cell_nodes(j,1)]);
    });

  // Per-cell, per-local-node vectors: -1 at local node 0, +1 at local node 1.
  Kokkos::View<Rd*[2]> Cjr("Cjr",m_nj);
  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j) {
      Cjr(j,0)=-1;
      Cjr(j,1)= 1;
    });

  // Cell volumes from node positions and Cjr.
  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      Vj[j] = (xr[cell_nodes(j,1)], Cjr(j,1))
            + (xr[cell_nodes(j,0)], Cjr(j,0));
    });

  // ---- Initial state ------------------------------------------------------
  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      if (xj[j][0]<0.5) {
        rhoj[j]=1;
      } else {
        rhoj[j]=0.125;
      }
    });

  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      if (xj[j][0]<0.5) {
        pj[j]=1;
      } else {
        pj[j]=0.1;
      }
    });

  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      uj[j] = zero;
    });

  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      gammaj[j] = 1.4;
    });

  BlockPerfectGas block_eos(rhoj, ej, pj, gammaj, cj);

  block_eos.updateEandCFromRhoP();

  // Total energy = internal energy + 0.5 * (uj, uj).
  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      Ej[j] = ej[j]+0.5*(uj[j],uj[j]);
    });

  // Cell masses; they are not modified by the time loop below.
  Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
      mj[j] = rhoj[j] * Vj[j];
    });

  Kokkos::View<double*> inv_mj("inv_mj",m_nj);
  inverse(mj, inv_mj);

  // ---- Time loop ----------------------------------------------------------
  const double tmax=0.2;
  double t=0;

  int itermax=std::numeric_limits<int>::max();
  int iteration=0;

  while((t<tmax) and (iteration<itermax)) {
    // Time step: 0.4 times the acoustic bound, clipped to land on tmax.
    double dt = 0.4*acoustic_dt(Vj, cj);
    if (t+dt<tmax) {
      t+=dt;
    } else {
      dt=tmax-t;
      t=tmax;
    }

    computeExplicitFluxes(xr, xj,
                          rhoj, uj, pj, cj, Vj, Cjr,
                          cell_nodes, node_cells, node_nb_cells, node_cell_local_node,
                          m_ur, m_Fjr);

    // Local read-only handles so the kernels do not capture `this`.
    const Kokkos::View<const Rd*[2]> Fjr = m_Fjr;
    const Kokkos::View<const Rd*> ur = m_ur;

    // Update cell velocity and total energy from the nodal fluxes.
    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j) {
        Rd momentum_fluxes = zero;
        double energy_fluxes = 0;
        for (int R=0; R<2; ++R) {
          const int r=cell_nodes(j,R);
          momentum_fluxes +=  Fjr(j,R);
          energy_fluxes   += ((Fjr(j,R), ur[r]));
        }
        uj[j] -= (dt*inv_mj[j]) * momentum_fluxes;
        Ej[j] -= (dt*inv_mj[j]) * energy_fluxes;
      });

    // Internal energy = total energy - 0.5 * (uj, uj).
    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j) {
        ej[j] = Ej[j] - 0.5 * (uj[j],uj[j]);
      });

    // Move the nodes with the node values ur.
    Kokkos::parallel_for(m_nr, KOKKOS_LAMBDA(const int& r){
        xr[r] += dt*ur[r];
      });

    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
        xj[j] = 0.5*(xr[j]+xr[j+1]);
      });

    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
        Vj[j] = (xr[cell_nodes(j,1)], Cjr(j,1))
              + (xr[cell_nodes(j,0)], Cjr(j,0));
      });

    // Density follows from the cell mass and the updated volume.
    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
        rhoj[j] = mj[j]/Vj[j];
      });

    block_eos.updatePandCFromRhoE();

    ++iteration;
  }

  // {
  //   std::ofstream fout("rho");
  //   for (int j=0; j<nj; ++j) {
  //     fout << xj[j][0] << ' ' << rhoj[j] << '\n';
  //   }
  // }

  std::cout << "* " << rang::style::underline << "Final time" << rang::style::reset
            << ":  " << rang::fgB::green << t << rang::fg::reset << " (" << iteration << " iterations)\n";

}