#ifndef ACOUSTIC_SOLVER_HPP
#define ACOUSTIC_SOLVER_HPP

#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

#include <Kokkos_Core.hpp>
#include <rang.hpp>

#include <BlockPerfectGas.hpp>
#include <PastisAssert.hpp>

#include <TinyVector.hpp>
#include <TinyMatrix.hpp>
#include <Mesh.hpp>
#include <MeshData.hpp>
#include <FiniteVolumesEulerUnknowns.hpp>
#include <BoundaryCondition.hpp>

#include <SubItemValuePerItem.hpp>

/**
 * Explicit acoustic solver acting on cell-centered Euler unknowns.
 *
 * One call to computeNextStep():
 *  - builds the per cell/node matrices  Ajr = (rho_j c_j) Cjr (x) njr,
 *  - assembles per node  Ar = sum_j Ajr  and  br = sum_j (Ajr u_j + p_j Cjr),
 *  - applies the boundary conditions to Ar and br,
 *  - computes the node velocities  ur = Ar^{-1} br,
 *  - computes the fluxes  Fjr = Ajr (u_j - ur) + p_j Cjr,
 *  - updates uj, Ej, ej, the node positions xr and finally rhoj.
 *
 * All intermediate arrays are members allocated once in the constructor and
 * overwritten at every step.
 *
 * NOTE(review): the KOKKOS_LAMBDA bodies below read members (m_Ar, m_br, ...)
 * through the implicitly captured `this`. Presumably this is only exercised
 * with host execution spaces -- confirm before enabling a device backend.
 *
 * @tparam MeshData  type providing MeshType, mesh(), ljr(), Cjr(), njr(),
 *                   xj(), Vj() and updateAllData() as used below.
 */
template<typename MeshData>
class AcousticSolver
{
  using MeshType = typename MeshData::MeshType;
  using UnknownsType = FiniteVolumesEulerUnknowns<MeshData>;

  MeshData& m_mesh_data;
  const MeshType& m_mesh;
  const typename MeshType::Connectivity& m_connectivity;
  const std::vector<BoundaryConditionHandler>& m_boundary_condition_list;

  constexpr static size_t dimension = MeshType::dimension;

  using Rd = TinyVector<dimension>;
  using Rdd = TinyMatrix<dimension>;

private:
  /**
   * Reduction functor for Kokkos::parallel_reduce computing the minimum of
   * the entries of a view of doubles.
   */
  struct ReduceMin
  {
  private:
    const Kokkos::View<const double*> x_;

  public:
    using value_type = Kokkos::View<const double*>::non_const_value_type;
    using size_type = Kokkos::View<const double*>::size_type;

    explicit ReduceMin(const Kokkos::View<const double*>& x) : x_ (x) {}

    /// Folds entry @a i into the running minimum @a update.
    KOKKOS_INLINE_FUNCTION
    void operator() (const size_type i, value_type& update) const
    {
      if (x_(i) < update) {
        update = x_(i);
      }
    }

    /// Merges two partial minima: @a dst becomes min(dst, src).
    KOKKOS_INLINE_FUNCTION
    void join (volatile value_type& dst,
               const volatile value_type& src) const
    {
      if (src < dst) {
        dst = src;
      }
    }

    /// Initializes a partial result with the identity of the min reduction.
    KOKKOS_INLINE_FUNCTION void
    init (value_type& dst) const
    { // Start from the value Kokkos defines as neutral for min, so that any
      // entry compared against it can only lower the result.
      dst= Kokkos::reduction_identity<value_type>::min();
    }
  };

  /**
   * Fills m_rhocj[j] = rhoj[j]*cj[j] for every cell.
   *
   * Host-side dispatcher: it launches a kernel, hence it is not marked
   * KOKKOS_INLINE_FUNCTION.
   *
   * @return a read-only view on m_rhocj.
   */
  const Kokkos::View<const double*>
  computeRhoCj(const Kokkos::View<const double*>& rhoj,
               const Kokkos::View<const double*>& cj)
  {
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        m_rhocj[j] = rhoj[j]*cj[j];
      });
    return m_rhocj;
  }

  /**
   * Fills m_Ajr(j,r) = tensorProduct(rhocj(j)*Cjr(j,r), njr(j,r)) for every
   * cell j and each of its sub-values r.
   *
   * @param ljr  currently unused; kept so the call sites stay unchanged.
   */
  void computeAjr(const Kokkos::View<const double*>& rhocj,
                  const SubItemValuePerItem<Rd>& Cjr,
                  const SubItemValuePerItem<double>& ljr,
                  const SubItemValuePerItem<Rd>& njr)
  {
    Kokkos::parallel_for("new nested Ajr", m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        const size_t& nb_nodes =m_Ajr.numberOfSubValues(j);
        const double& rho_c = rhocj(j);
        for (size_t r=0; r<nb_nodes; ++r) {
          m_Ajr(j,r) = tensorProduct(rho_c*Cjr(j,r), njr(j,r));
        }
      });
  }

  /**
   * Assembles, for every node r, m_Ar(r) as the sum of Ajr(J,R) over the
   * (cell J, local node R) pairs listed by the node-to-cell connectivity.
   *
   * @return a read-only view on m_Ar.
   */
  const Kokkos::View<const Rdd*>
  computeAr(const SubItemValuePerItem<Rdd>& Ajr) {
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        Rdd sum = zero;
        const auto& node_to_cell = m_connectivity.m_node_to_cell_matrix.rowConst(r);
        const auto& node_to_cell_local_node = m_connectivity.m_node_to_cell_local_node_matrix.rowConst(r);
        for (size_t j=0; j<node_to_cell.length; ++j) {
          const unsigned int J = node_to_cell(j);
          const unsigned int R = node_to_cell_local_node(j);
          sum += Ajr(J,R);
        }
        m_Ar(r) = sum;
      });

    return m_Ar;
  }

  /**
   * Assembles, for every node r, m_br(r) as the sum of
   * Ajr(J,R)*uj(J) + pj(J)*Cjr(J,R) over the (cell J, local node R) pairs
   * listed by the node-to-cell connectivity.
   *
   * @return a read-only view on m_br.
   */
  const Kokkos::View<const Rd*>
  computeBr(const SubItemValuePerItem<Rdd>& Ajr,
            const SubItemValuePerItem<Rd>& Cjr,
            const Kokkos::View<const Rd*>& uj,
            const Kokkos::View<const double*>& pj) {
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        Rd& br = m_br(r);
        br = zero;
        const auto& node_to_cell = m_connectivity.m_node_to_cell_matrix.rowConst(r);
        const auto& node_to_cell_local_node = m_connectivity.m_node_to_cell_local_node_matrix.rowConst(r);
        for (size_t j=0; j<node_to_cell.length; ++j) {
          const unsigned int J = node_to_cell(j);
          const unsigned int R = node_to_cell_local_node(j);
          br += Ajr(J,R)*uj(J) + pj(J)*Cjr(J,R);
        }
      });

    return m_br;
  }

  /**
   * Modifies m_Ar and m_br in place according to every registered boundary
   * condition.
   *
   * Only symmetry conditions are implemented: with n the outgoing normal and
   * P = I - n(x)n, each listed node gets Ar <- P Ar P + n(x)n and br <- P br.
   * Any other handled condition type prints a "NIY" message on std::cerr and
   * terminates the program with a failure status.
   */
  void applyBoundaryConditions()
  {
    for (const auto& handler : m_boundary_condition_list) {
      switch (handler.boundaryCondition().type()) {
      case BoundaryCondition::normal_velocity: {
        std::cerr << __FILE__ << ':' << __LINE__  << ": normal_velocity BC NIY\n";
        // Not implemented is an error: do not report success to the caller.
        std::exit(EXIT_FAILURE);
        break;
      }
      case BoundaryCondition::velocity: {
        std::cerr << __FILE__ << ':' << __LINE__  << ": velocity BC NIY\n";
        std::exit(EXIT_FAILURE);
        break;
      }
      case BoundaryCondition::pressure: {
        // const PressureBoundaryCondition& pressure_bc
        //   = dynamic_cast<const PressureBoundaryCondition&>(handler.boundaryCondition());
        std::cerr << __FILE__ << ':' << __LINE__ << ": pressure BC NIY\n";
        std::exit(EXIT_FAILURE);
        break;
      }
      case BoundaryCondition::symmetry: {
        const SymmetryBoundaryCondition<dimension>& symmetry_bc
            = dynamic_cast<const SymmetryBoundaryCondition<dimension>&>(handler.boundaryCondition());
        const Rd& n = symmetry_bc.outgoingNormal();

        const Rdd I = identity;
        const Rdd nxn = tensorProduct(n,n);
        const Rdd P = I-nxn;

        Kokkos::parallel_for(symmetry_bc.numberOfNodes(), KOKKOS_LAMBDA(const int& r_number) {
            const int r = symmetry_bc.nodeList()[r_number];

            m_Ar(r) = P*m_Ar(r)*P + nxn;
            m_br(r) = P*m_br(r);
          });
        break;
      }
      }
    }
  }

  /**
   * Computes the node velocities m_ur[r] = Ar(r)^{-1} * br(r).
   *
   * The inverses are stored in m_inv_Ar as a side effect.
   *
   * @return a view on m_ur.
   */
  Kokkos::View<Rd*>
  computeUr(const Kokkos::View<const Rdd*>& Ar,
            const Kokkos::View<const Rd*>& br)
  {
    inverse(Ar, m_inv_Ar);
    const Kokkos::View<const Rdd*> invAr = m_inv_Ar;
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        m_ur[r]=invAr(r)*br(r);
      });

    return m_ur;
  }

  /**
   * Fills m_Fjr(j,r) = Ajr(j,r)*(uj(j) - ur(node)) + pj(j)*Cjr(j,r), where
   * `node` is the r-th entry of the cell-to-node row of cell j.
   */
  void
  computeFjr(const SubItemValuePerItem<Rdd>& Ajr,
             const Kokkos::View<const Rd*>& ur,
             const SubItemValuePerItem<Rd>& Cjr,
             const Kokkos::View<const Rd*>& uj,
             const Kokkos::View<const double*>& pj)
  {
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        const auto& cell_nodes = m_mesh.connectivity().m_cell_to_node_matrix.rowConst(j);

        for (size_t r=0; r<cell_nodes.length; ++r) {
          m_Fjr(j,r) = Ajr(j,r)*(uj(j)-ur(cell_nodes(r)))+pj(j)*Cjr(j,r);
        }
      });
  }

  /**
   * Stores in inv_A(r) the inverse of A(r) for every node r, using the free
   * function ::inverse.
   */
  void inverse(const Kokkos::View<const Rdd*>& A,
               Kokkos::View<Rdd*>& inv_A) const
  {
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r) {
        inv_A(r) = ::inverse(A(r));
      });
  }

  /**
   * Runs the whole flux pipeline: rho*c, Ajr, Ar, br, boundary conditions,
   * ur and finally Fjr. On return m_ur and m_Fjr hold the node velocities
   * and the cell/node fluxes.
   *
   * @param xr, xj, Vj  currently unused; kept so the call sites stay
   *                    unchanged.
   */
  void computeExplicitFluxes(const Kokkos::View<const Rd*>& xr,
                             const Kokkos::View<const Rd*>& xj,
                             const Kokkos::View<const double*>& rhoj,
                             const Kokkos::View<const Rd*>& uj,
                             const Kokkos::View<const double*>& pj,
                             const Kokkos::View<const double*>& cj,
                             const Kokkos::View<const double*>& Vj,
                             const SubItemValuePerItem<Rd>& Cjr,
                             const SubItemValuePerItem<double>& ljr,
                             const SubItemValuePerItem<Rd>& njr) {
    const Kokkos::View<const double*> rhocj  = computeRhoCj(rhoj, cj);
    computeAjr(rhocj, Cjr, ljr, njr);

    // Ar and br are views on m_Ar and m_br: the boundary conditions below
    // modify those members in place, so the views observe the updated values.
    const Kokkos::View<const Rdd*> Ar = computeAr(m_Ajr);
    const Kokkos::View<const Rd*> br = computeBr(m_Ajr, Cjr, uj, pj);

    this->applyBoundaryConditions();

    const Kokkos::View<const Rd*> ur = computeUr(Ar, br);
    computeFjr(m_Ajr, ur, Cjr, uj, pj);
  }

  // Work arrays, allocated once in the constructor. The declaration order
  // matches the constructor's initializer list.
  Kokkos::View<Rd*> m_br;            // per node
  SubItemValuePerItem<Rdd> m_Ajr;    // per cell/node
  Kokkos::View<Rdd*> m_Ar;           // per node
  Kokkos::View<Rdd*> m_inv_Ar;       // per node
  SubItemValuePerItem<Rd> m_Fjr;     // per cell/node
  Kokkos::View<Rd*> m_ur;            // per node
  Kokkos::View<double*> m_rhocj;     // per cell
  Kokkos::View<double*> m_Vj_over_cj; // per cell

public:
  /**
   * Binds the solver to a mesh and a boundary condition list and allocates
   * the work arrays.
   *
   * @param mesh_data  mesh data; a reference is stored.
   * @param unknowns   currently unused by the constructor.
   * @param bc_list    boundary conditions; a reference is stored, so the
   *                   vector must outlive the solver.
   */
  AcousticSolver(MeshData& mesh_data,
                 UnknownsType& unknowns,
                 const std::vector<BoundaryConditionHandler>& bc_list)
    : m_mesh_data(mesh_data),
      m_mesh(mesh_data.mesh()),
      m_connectivity(m_mesh.connectivity()),
      m_boundary_condition_list(bc_list),
      m_br("br", m_mesh.numberOfNodes()),
      m_Ajr(m_connectivity.m_node_id_per_cell_matrix),
      m_Ar("Ar", m_mesh.numberOfNodes()),
      m_inv_Ar("inv_Ar", m_mesh.numberOfNodes()),
      m_Fjr(m_connectivity.m_node_id_per_cell_matrix),
      m_ur("ur", m_mesh.numberOfNodes()),
      m_rhocj("rho_c", m_mesh.numberOfCells()),
      m_Vj_over_cj("Vj_over_cj", m_mesh.numberOfCells())
  {
    ;
  }

  /**
   * Computes the minimum over all cells of 2*Vj[j] / (S_j*cj[j]), where S_j
   * is the sum of ljr(j,r) over the nodes of cell j.
   *
   * The per-cell values are stored in m_Vj_over_cj as a side effect.
   *
   * @param Vj  per-cell values used as numerators.
   * @param cj  per-cell values used in the denominators.
   * @return the minimum described above.
   */
  double acoustic_dt(const Kokkos::View<const double*>& Vj,
                     const Kokkos::View<const double*>& cj) const
  {
    const SubItemValuePerItem<double>& ljr = m_mesh_data.ljr();

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j){
        const auto& cell_nodes = m_mesh.connectivity().m_cell_to_node_matrix.rowConst(j);

        double S = 0;
        for (size_t r=0; r<cell_nodes.length; ++r) {
          S += ljr(j,r);
        }
        m_Vj_over_cj[j] = 2*Vj[j]/(S*cj[j]);
      });

    double dt = std::numeric_limits<double>::max();
    Kokkos::parallel_reduce(m_mesh.numberOfCells(), ReduceMin(m_Vj_over_cj), dt);

    return dt;
  }


  /**
   * Advances the unknowns and the mesh by one explicit step of size @a dt.
   *
   * Sequence: compute fluxes and node velocities, update uj and Ej from the
   * accumulated fluxes, recompute ej = Ej - 0.5*(uj,uj), move the nodes
   * (xr += dt*ur), refresh the mesh data and finally set rhoj = mj/Vj.
   *
   * @param t         currently unused.
   * @param dt        step size.
   * @param unknowns  provides rhoj, uj, Ej, ej, pj, cj, invMj and mj; rhoj,
   *                  uj, Ej and ej are updated in place.
   */
  void computeNextStep(const double& t, const double& dt,
                       UnknownsType& unknowns)
  {
    Kokkos::View<double*> rhoj = unknowns.rhoj();
    Kokkos::View<Rd*> uj = unknowns.uj();
    Kokkos::View<double*> Ej = unknowns.Ej();

    Kokkos::View<double*> ej = unknowns.ej();
    Kokkos::View<double*> pj = unknowns.pj();
    Kokkos::View<double*> cj = unknowns.cj();

    const Kokkos::View<const Rd*> xj = m_mesh_data.xj();
    const Kokkos::View<const double*> Vj = m_mesh_data.Vj();
    const SubItemValuePerItem<Rd>& Cjr = m_mesh_data.Cjr();
    const SubItemValuePerItem<double>& ljr = m_mesh_data.ljr();
    const SubItemValuePerItem<Rd>& njr = m_mesh_data.njr();
    Kokkos::View<Rd*> xr = m_mesh.xr();

    computeExplicitFluxes(xr, xj, rhoj, uj, pj, cj, Vj, Cjr, ljr, njr);

    const SubItemValuePerItem<Rd>& Fjr = m_Fjr;
    const Kokkos::View<const Rd*> ur = m_ur;

    const Kokkos::View<const double*> inv_mj = unknowns.invMj();
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        const auto& cell_nodes = m_mesh.connectivity().m_cell_to_node_matrix.rowConst(j);

        Rd momentum_fluxes = zero;
        double energy_fluxes = 0;
        for (size_t R=0; R<cell_nodes.length; ++R) {
          const unsigned int r=cell_nodes(R);
          momentum_fluxes +=  Fjr(j,R);
          // (a, b) presumably relies on an overloaded comma operator of
          // TinyVector acting as a scalar product -- TODO confirm.
          energy_fluxes   += (Fjr(j,R), ur[r]);
        }
        uj[j] -= (dt*inv_mj[j]) * momentum_fluxes;
        Ej[j] -= (dt*inv_mj[j]) * energy_fluxes;
      });

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j) {
        ej[j] = Ej[j] - 0.5 * (uj[j],uj[j]);
      });

    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const int& r){
        xr[r] += dt*ur[r];
      });

    // Must run after the nodes have moved and before rhoj is recomputed.
    // NOTE(review): Vj is presumably a view on storage that updateAllData()
    // refreshes in place -- confirm.
    m_mesh_data.updateAllData();

    const Kokkos::View<const double*> mj = unknowns.mj();
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const int& j){
        rhoj[j] = mj[j]/Vj[j];
      });
  }
};

#endif // ACOUSTIC_SOLVER_HPP