#include <cmath>
#include <fstream>
#include <iostream>
#include <limits>

#include <Kokkos_Core.hpp>
#include <RevisionInfo.hpp>
#include <rang.hpp>

#include <CLI/CLI.hpp>

/// Specific internal energy obtained by inverting the gamma-law relation
/// p = (gamma-1)*rho*e.
///
/// @param rho    density
/// @param p      pressure
/// @param gamma  adiabatic exponent
/// @return p / (rho*(gamma-1)); no guard against rho == 0 or gamma == 1.
///
/// NOTE(review): this function is called from inside KOKKOS_LAMBDA bodies in
/// main(); device backends would presumably need it to be declared
/// KOKKOS_INLINE_FUNCTION -- confirm against the targeted backends.
inline double e(double rho, double p, double gamma)
{
  const double rho_gamma_m1 = rho*(gamma-1);
  return p/rho_gamma_m1;
}

/// Pressure given by the gamma-law relation p = (gamma-1)*rho*e.
///
/// @param rho    density
/// @param e      specific internal energy
/// @param gamma  adiabatic exponent
/// @return (gamma-1)*rho*e
///
/// NOTE(review): this function is called from inside KOKKOS_LAMBDA bodies in
/// main(); device backends would presumably need it to be declared
/// KOKKOS_INLINE_FUNCTION -- confirm against the targeted backends.
inline double p(double rho, double e, double gamma)
{
  const double gamma_m1_rho = (gamma-1)*rho;
  return gamma_m1_rho*e;
}

// Element type of the views read by ReduceMin: the functor never writes to
// its input, so the view is declared over const data.
using my_double = const double;

/// Kokkos reduction functor returning the smallest entry of a 1D view.
///
/// Usage: Kokkos::parallel_reduce(n, ReduceMin(view), result);
struct ReduceMin {
private:
  // Input values; read-only.
  const Kokkos::View<my_double*> x_;

public:
  // Type of the reduction result (the non-const element type, i.e. double).
  using value_type = Kokkos::View<my_double*>::non_const_value_type;
  // Index type handed to operator().
  using size_type = Kokkos::View<my_double*>::size_type;

  /// @param x view whose minimum is computed; a reference-counted handle to
  ///          it is stored, the data is not copied.
  ReduceMin(const Kokkos::View<my_double*>& x) : x_ (x) {}

  /// Per-element contribution: keeps the smaller of the running value and
  /// x_(i).
  KOKKOS_INLINE_FUNCTION void
  operator() (const size_type i, value_type& update) const
  {
    const value_type candidate = x_(i);
    if (candidate < update) {
      update = candidate;
    }
  }

  /// Combines two partial minima into dst.
  ///
  /// Recent Kokkos releases call join with plain references.
  KOKKOS_INLINE_FUNCTION void
  join (value_type& dst,
        const value_type& src) const
  {
    if (src < dst) {
      dst = src;
    }
  }

  /// Same combination for older Kokkos releases, which pass
  /// volatile-qualified arguments to join.  Kept so that the functor works
  /// with both calling conventions.
  KOKKOS_INLINE_FUNCTION void
  join (volatile value_type& dst,
        const volatile value_type& src) const
  {
    if (src < dst) {
      dst = src;
    }
  }

  /// Initializes a partial result with the identity of the min reduction,
  /// as provided by Kokkos::reduction_identity<value_type>::min().
  KOKKOS_INLINE_FUNCTION void
  init (value_type& dst) const
  {
    dst = Kokkos::reduction_identity<value_type>::min();
  }
};
    

/// Returns the minimum over all cells j of Vj[j]/cj[j].
///
/// The ratio is computed and reduced in a single kernel, so no scratch view
/// has to be allocated each time the function is called (main() calls it
/// once per time step).
///
/// @param Vj  per-cell values used as numerators
/// @param cj  per-cell values used as denominators; read at the same indices
///            as Vj, so it must hold at least Vj.size() entries
/// @return the smallest ratio, or std::numeric_limits<double>::max() when Vj
///         is empty
double acoustic_dt(const Kokkos::View<double*>& Vj,
                   const Kokkos::View<double*>& cj)
{
  const size_t nj = Vj.size();
  double dt = std::numeric_limits<double>::max();

  // Kokkos::Min initializes every partial result with the identity of the
  // min reduction and stores the final value in dt.
  Kokkos::parallel_reduce(nj, KOKKOS_LAMBDA(const int& j, double& local_min) {
      const double ratio = Vj[j]/cj[j];
      if (ratio < local_min) {
        local_min = ratio;
      }
    }, Kokkos::Min<double>(dt));

  return dt;
}


/// Computes the node velocities ur and node pressures pr from the cell
/// states.
///
/// With nj = uj.size() cells, nodes are numbered 0..nj and node r lies
/// between cells r-1 and r.  Both end nodes get a zero velocity; an interior
/// node gets a combination of the velocities and pressures of its two
/// neighbouring cells, weighted by rho*c of each cell.  The node pressure is
/// then deduced from the state of the cell on the left of the node (cell 0
/// for node 0).
///
/// All entries are written from inside kernels, so the function does not
/// rely on the views being addressable from host code.
///
/// @param xr    unused here, kept for interface compatibility
/// @param xj    unused here, kept for interface compatibility
/// @param rhoj  cell densities
/// @param uj    cell velocities; its size defines the number of cells
/// @param pj    cell pressures
/// @param cj    cell sound speeds
/// @param Vj    unused here, kept for interface compatibility
/// @param ur    [out] node velocities, needs at least uj.size()+1 entries
/// @param pr    [out] node pressures, needs at least uj.size()+1 entries
void computeExplicitFluxes(const Kokkos::View<double*>& xr,
                           const Kokkos::View<double*>& xj,
                           const Kokkos::View<double*>& rhoj,
                           const Kokkos::View<double*>& uj,
                           const Kokkos::View<double*>& pj,
                           const Kokkos::View<double*>& cj,
                           const Kokkos::View<double*>& Vj,
                           Kokkos::View<double*>& ur,
                           Kokkos::View<double*>& pr)
{
  const size_t nj = uj.size();

  // An empty mesh has nothing to update; returning early also avoids reading
  // cell 0 of empty views below.
  if (nj == 0) {
    return;
  }

  // Index of the last node.
  const int last_r = static_cast<int>(nj);

  // Node velocities.
  Kokkos::parallel_for(nj+1, KOKKOS_LAMBDA(const int& r) {
    if ((r == 0) or (r == last_r)) {
      // Both end nodes are kept at rest.
      ur[r] = 0;
    } else {
      const int j = r-1; // cell on the left of node r
      const int k = r;   // cell on the right of node r
      const double ujr = uj[j];
      const double ukr = uj[k];
      const double pjr = pj[j];
      const double pkr = pj[k];

      ur[r]=(rhoj[j]*cj[j]*ujr + rhoj[k]*cj[k]*ukr + pjr-pkr)/(rhoj[j]*cj[j]+rhoj[k]*cj[k]);
    }
    });

  // Node pressures.  Launched as a second kernel after the velocities; each
  // entry only reads the velocity of its own node.
  Kokkos::parallel_for(nj+1, KOKKOS_LAMBDA(const int& r) {
    if (r == 0) {
      // Node 0 has no left cell: use cell 0, seen from its left side.
      pr[0] = pj[0] + rhoj[0]*cj[0]*(ur[0] - uj[0]);
    } else {
      const int j = r-1;

      const double ujr = uj[j];
      const double pjr = pj[j];

      pr[r]=pjr+rhoj[j]*cj[j]*(ujr-ur[r]);
    }
    });
}


/// Program entry point.
///
/// Parses the command line, prints version/git information, then runs a 1D
/// simulation on `number` cells covering [0,1]: the initial state has
/// (rho, p) = (1, 1) where the cell centre is below 0.5 and (0.125, 0.1)
/// elsewhere, with gamma = 1.4 and zero velocity.  The state is advanced
/// until t = 0.2 and the final (cell centre, density) pairs are written to
/// the file "rho".
///
/// @return 0 on success, 1 when --number is out of range; CLI11_PARSE
///         returns on its own when command-line parsing does not succeed.
int main(int argc, char *argv[])
{
  CLI::App app{"Pastis help"};

  long number = 1000;
  app.add_option("-n,--number", number, "A big integer");
  // NOTE(review): --threads is parsed but its value is not forwarded to
  // Kokkos anywhere in this function.
  int threads=-1;
  app.add_option("--threads", threads, "Number of Kokkos threads");

  CLI11_PARSE(app, argc, argv);

  // The cell count sizes every view, divides the mesh step, and number+1
  // nodes must still be indexable with an int.
  if ((number <= 0) or (number >= std::numeric_limits<int>::max())) {
    std::cerr << "error: --number must be a positive integer smaller than "
              << std::numeric_limits<int>::max() << '\n';
    return 1;
  }

  std::cout << "Code version: "
            << rang::style::bold << RevisionInfo::version() << rang::style::reset << '\n';

  std::cout << "-------------------- "
            << rang::fg::green
            << "git info"
            << rang::fg::reset
            <<" -------------------------"
            << '\n';
  std::cout << "tag:  " << rang::fg::reset
            << rang::style::bold << RevisionInfo::gitTag() << rang::style::reset << '\n';
  std::cout << "HEAD: " << rang::style::bold << RevisionInfo::gitHead() << rang::style::reset << '\n';
  std::cout << "hash: " << rang::style::bold << RevisionInfo::gitHash() << rang::style::reset << "  (";
  if (RevisionInfo::gitIsClean()) {
    std::cout << rang::fgB::green << "clean" << rang::fg::reset;
  } else {
    std::cout << rang::fgB::red << "dirty" << rang::fg::reset;
  }
  std::cout << ")\n";
  std::cout << "-------------------------------------------------------\n";

  Kokkos::initialize(argc, argv);
  Kokkos::DefaultExecutionSpace::print_configuration(std::cout);

  // Every view lives inside this scope so that it is destroyed before
  // Kokkos::finalize() is called.
  {
    const long nj = number;
    const int nr = static_cast<int>(nj+1);

    // Cell-centred quantities.
    Kokkos::View<double*> xj("xj", nj);
    Kokkos::View<double*> rhoj("rhoj", nj);
    // uj is never set explicitly: it keeps the initial value Kokkos gives to
    // a freshly allocated view (zero by default).
    Kokkos::View<double*> uj("uj", nj);
    Kokkos::View<double*> Ej("Ej", nj);
    Kokkos::View<double*> ej("ej", nj);
    Kokkos::View<double*> pj("pj", nj);
    Kokkos::View<double*> Vj("Vj", nj);
    Kokkos::View<double*> gammaj("gammaj", nj);
    Kokkos::View<double*> cj("cj", nj);
    Kokkos::View<double*> mj("mj", nj);

    // Node quantities.  ur and pr are fully overwritten by
    // computeExplicitFluxes() at every step, so they are allocated once here
    // instead of once per iteration.
    Kokkos::View<double*> xr("xr", nr);
    Kokkos::View<double*> ur("ur", nr);
    Kokkos::View<double*> pr("pr", nr);

    const double delta_x = 1./nj;
    Kokkos::Timer timer;
    timer.reset();

    // Uniform mesh on [0,1].
    Kokkos::parallel_for(nr, KOKKOS_LAMBDA(const int& r){
        xr[r] = r*delta_x;
      });

    // Cell centres and sizes.
    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
        xj[j] = 0.5*(xr[j]+xr[j+1]);
        Vj[j] = xr[j+1]-xr[j];
      });

    // Two-state initial condition split at x = 0.5.
    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
        if (xj[j]<0.5) {
          rhoj[j]=1;
          pj[j]=1;
        } else {
          rhoj[j]=0.125;
          pj[j]=0.1;
        }
        gammaj[j] = 1.4;
      });

    // Quantities derived from the initial state.
    Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
        ej[j] = e(rhoj[j],pj[j],gammaj[j]);
        Ej[j] = ej[j]+0.5*uj[j]*uj[j];
        cj[j] = std::sqrt(gammaj[j]*pj[j]/rhoj[j]);
        mj[j] = rhoj[j] * Vj[j];
      });

    const double tmax=0.2;
    double t=0;

    int itermax=std::numeric_limits<int>::max();
    int iteration=0;

    while((t<tmax) and (iteration<itermax)) {
      // Time step: 0.4 times the smallest Vj/cj, shortened so that the last
      // step lands exactly on tmax.
      double dt = 0.4*acoustic_dt(Vj, cj);
      if (t+dt<tmax) {
        t+=dt;
      } else {
        dt=tmax-t;
        t=tmax;
      }

      if (iteration%100 == 0) {
        std::cout << "dt=" << dt << " t=" << t << " i=" << iteration << '\n';
      }

      computeExplicitFluxes(xr, xj,
                            rhoj, uj, pj, cj, Vj,
                            ur, pr);

      // Update velocity and total energy from the node values on both sides
      // of each cell, then deduce the internal energy.
      Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
          const int rm=j;
          const int rp=j+1;

          uj[j] += dt/mj[j]*(pr[rm]-pr[rp]);
          Ej[j] += dt/mj[j]*(pr[rm]*ur[rm]-pr[rp]*ur[rp]);
          ej[j] = Ej[j] - 0.5 * uj[j]*uj[j];
        });

      // Move the nodes with their velocity.
      Kokkos::parallel_for(nr, KOKKOS_LAMBDA(const int& r){
          xr[r] += dt*ur[r];
        });

      // Rebuild the geometry and the remaining cell quantities; the cell
      // masses mj are left unchanged.
      Kokkos::parallel_for(nj, KOKKOS_LAMBDA(const int& j){
          xj[j] = 0.5*(xr[j]+xr[j+1]);
          Vj[j] = xr[j+1]-xr[j];
          rhoj[j] = mj[j]/Vj[j];
          pj[j] = p(rhoj[j], ej[j], gammaj[j]);
          cj[j] = std::sqrt(gammaj[j]*pj[j]/rhoj[j]);
        });

      ++iteration;
    }

    // Wait for the last kernels before reading the timer.
    Kokkos::fence();

    std::cout << "* " << rang::style::underline << "Final time" << rang::style::reset
              << ":  " << rang::fgB::green << t << rang::fg::reset << " (" << iteration << " iterations)\n";
    double count_time = timer.seconds();
    std::cout << "* Execution time: " << rang::style::bold << count_time << rang::style::reset << '\n';

    {
      // Read the results through host mirrors so that the output does not
      // depend on the views being addressable from host code.
      auto xj_host = Kokkos::create_mirror_view(xj);
      auto rhoj_host = Kokkos::create_mirror_view(rhoj);
      Kokkos::deep_copy(xj_host, xj);
      Kokkos::deep_copy(rhoj_host, rhoj);

      std::ofstream fout("rho");

      for (long j=0; j<nj; ++j) {
        fout << xj_host[j] << ' ' << rhoj_host[j] << '\n';
      }
    }
  }
  Kokkos::finalize();

  return 0;
}