#ifndef ACOUSTIC_SOLVER_HPP
#define ACOUSTIC_SOLVER_HPP

#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <vector>

#include <rang.hpp>

#include <ArrayUtils.hpp>

#include <BlockPerfectGas.hpp>
#include <PastisAssert.hpp>

#include <TinyVector.hpp>
#include <TinyMatrix.hpp>
#include <Mesh.hpp>
#include <MeshData.hpp>
#include <FiniteVolumesEulerUnknowns.hpp>
#include <BoundaryCondition.hpp>

#include <SubItemValuePerItem.hpp>

/// Explicit solver that advances a set of FiniteVolumesEulerUnknowns by one
/// time step and moves the mesh nodes with the node velocities it computes.
///
/// One step (see computeNextStep) chains the private stages below:
///   rho*c per cell -> Ajr -> Ar, br -> boundary conditions -> ur -> Fjr,
/// then updates uj, Ej, ej, the node positions, the mesh data and rhoj.
///
/// All intermediate arrays are allocated once in the constructor and reused
/// at every step.
///
/// NOTE(review): the parallel lambdas reach the m_* scratch arrays through the
/// enclosing object rather than through local copies. Confirm this is
/// acceptable for every Kokkos back-end the project targets.
///
/// @tparam MeshData  mesh-data type; must expose MeshType, mesh(), xj(), Vj(),
///                   Cjr(), ljr(), njr() and updateAllData().
template<typename MeshData>
class AcousticSolver
{
  using MeshType = typename MeshData::MeshType;
  using UnknownsType = FiniteVolumesEulerUnknowns<MeshData>;

  // Non-owning references: the caller keeps these objects alive for the whole
  // lifetime of the solver.
  MeshData& m_mesh_data;
  const MeshType& m_mesh;
  const typename MeshType::Connectivity& m_connectivity;
  const std::vector<BoundaryConditionHandler>& m_boundary_condition_list;

  constexpr static size_t dimension = MeshType::dimension;

  using Rd = TinyVector<dimension>;
  using Rdd = TinyMatrix<dimension>;

  // Scratch arrays. Their declaration order matches the constructor's
  // initializer list; keep both in sync.
  NodeValue<Rd> m_br;             // per node: right-hand side of the node system
  NodeValuePerCell<Rdd> m_Ajr;    // per (cell, local node): matrix Ajr
  NodeValue<Rdd> m_Ar;            // per node: sum of the surrounding Ajr
  NodeValue<Rdd> m_inv_Ar;        // per node: inverse of m_Ar
  NodeValuePerCell<Rd> m_Fjr;     // per (cell, local node): flux Fjr
  NodeValue<Rd> m_ur;             // per node: node velocity
  CellValue<double> m_rhocj;      // per cell: rhoj*cj
  CellValue<double> m_Vj_over_cj; // per cell: value minimised by acoustic_dt

 private:
  /// Reports a boundary-condition type that is not implemented yet and
  /// terminates the program with a failure status.
  ///
  /// @param file     source file of the call site (pass __FILE__)
  /// @param line     source line of the call site (pass __LINE__)
  /// @param bc_name  name of the boundary-condition type, printed verbatim
  [[noreturn]] static void
  abortNotImplemented(const char* file, int line, const char* bc_name)
  {
    std::cerr << file << ':' << line << ": " << bc_name << " BC NIY\n";
    // This is an error path: do not report success to the caller's shell.
    std::exit(EXIT_FAILURE);
  }

  /// Fills m_rhocj with the cell-wise product rhoj*cj.
  ///
  /// @param rhoj  per-cell factor (first operand of the product)
  /// @param cj    per-cell factor (second operand of the product)
  /// @return a read-only handle on m_rhocj
  KOKKOS_INLINE_FUNCTION
  const CellValue<const double>
  computeRhoCj(const CellValue<const double>& rhoj,
               const CellValue<const double>& cj)
  {
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const CellId& j) {
        m_rhocj[j] = rhoj[j]*cj[j];
      });
    return m_rhocj;
  }

  /// Fills m_Ajr: for every cell j and local node r,
  /// Ajr = tensorProduct(rhocj[j]*Cjr(j,r), njr(j,r)).
  ///
  /// @param rhocj  per-cell scalar, as produced by computeRhoCj
  /// @param Cjr    per (cell, local node) vectors
  /// @param ljr    currently unused by this implementation
  /// @param njr    per (cell, local node) vectors
  KOKKOS_INLINE_FUNCTION
  void computeAjr(const CellValue<const double>& rhocj,
                  const NodeValuePerCell<const Rd>& Cjr,
                  const NodeValuePerCell<const double>& ljr,
                  const NodeValuePerCell<const Rd>& njr)
  {
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const CellId& j) {
        const size_t nb_nodes = m_Ajr.numberOfSubValues(j);
        const double rho_c = rhocj[j];
        for (size_t r=0; r<nb_nodes; ++r) {
          m_Ajr(j,r) = tensorProduct(rho_c*Cjr(j,r), njr(j,r));
        }
      });
  }

  /// Fills m_Ar: for every node, the sum of Ajr over the cells that contain
  /// the node.
  ///
  /// @param Ajr  per (cell, local node) matrices
  /// @return a read-only handle on m_Ar
  KOKKOS_INLINE_FUNCTION
  const NodeValue<const Rdd>
  computeAr(const NodeValuePerCell<const Rdd>& Ajr) {
    const auto& node_to_cell_matrix
        = m_connectivity.nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells
        = m_connectivity.nodeLocalNumbersInTheirCells();

    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const NodeId& r) {
        const auto& node_to_cell = node_to_cell_matrix[r];
        const auto& node_local_number_in_its_cells
            = node_local_numbers_in_their_cells.itemValues(r);

        Rdd sum = zero;
        for (size_t j=0; j<node_to_cell.size(); ++j) {
          // J: cell around the node, R: local number of the node inside J.
          const CellId J = node_to_cell[j];
          const unsigned int R = node_local_number_in_its_cells[j];
          sum += Ajr(J,R);
        }
        m_Ar[r] = sum;
      });

    return m_Ar;
  }

  /// Fills m_br: for every node, the sum over the cells that contain the node
  /// of Ajr(J,R)*uj[J] + pj[J]*Cjr(J,R).
  ///
  /// @param Ajr  per (cell, local node) matrices
  /// @param Cjr  per (cell, local node) vectors
  /// @param uj   per-cell vectors
  /// @param pj   per-cell scalars
  /// @return a read-only handle on m_br
  KOKKOS_INLINE_FUNCTION
  const NodeValue<const Rd>
  computeBr(const NodeValuePerCell<Rdd>& Ajr,
            const NodeValuePerCell<const Rd>& Cjr,
            const CellValue<const Rd>& uj,
            const CellValue<const double>& pj) {
    const auto& node_to_cell_matrix
        = m_connectivity.nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells
        = m_connectivity.nodeLocalNumbersInTheirCells();

    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const NodeId& r) {
        const auto& node_to_cell = node_to_cell_matrix[r];
        const auto& node_local_number_in_its_cells
            = node_local_numbers_in_their_cells.itemValues(r);

        // Accumulate locally and store once, as computeAr does.
        Rd sum = zero;
        for (size_t j=0; j<node_to_cell.size(); ++j) {
          const CellId J = node_to_cell[j];
          const unsigned int R = node_local_number_in_its_cells[j];
          sum += Ajr(J,R)*uj[J] + pj[J]*Cjr(J,R);
        }
        m_br[r] = sum;
      });

    return m_br;
  }

  /// Applies every registered boundary condition to m_Ar and m_br in place.
  ///
  /// Only the symmetry condition is implemented. normal_velocity, velocity and
  /// pressure conditions print a message on std::cerr and terminate the
  /// program with a failure status.
  void applyBoundaryConditions()
  {
    for (const auto& handler : m_boundary_condition_list) {
      switch (handler.boundaryCondition().type()) {
        case BoundaryCondition::normal_velocity: {
          abortNotImplemented(__FILE__, __LINE__, "normal_velocity"); // does not return
        }
        case BoundaryCondition::velocity: {
          abortNotImplemented(__FILE__, __LINE__, "velocity"); // does not return
        }
        case BoundaryCondition::pressure: {
          // const PressureBoundaryCondition& pressure_bc
          //   = dynamic_cast<const PressureBoundaryCondition&>(handler.boundaryCondition());
          abortNotImplemented(__FILE__, __LINE__, "pressure"); // does not return
        }
        case BoundaryCondition::symmetry: {
          const SymmetryBoundaryCondition<dimension>& symmetry_bc
              = dynamic_cast<const SymmetryBoundaryCondition<dimension>&>(handler.boundaryCondition());
          const Rd& n = symmetry_bc.outgoingNormal();

          // P = I - n(x)n, built from the outgoing normal of the boundary.
          const Rdd I = identity;
          const Rdd nxn = tensorProduct(n,n);
          const Rdd P = I-nxn;

          const Array<const NodeId>& node_list
              = symmetry_bc.nodeList();
          Kokkos::parallel_for(symmetry_bc.numberOfNodes(), KOKKOS_LAMBDA(const int& r_number) {
              const NodeId r = node_list[r_number];

              m_Ar[r] = P*m_Ar[r]*P + nxn;
              m_br[r] = P*m_br[r];
            });
          break;
        }
      }
    }
  }

  /// Fills m_inv_Ar with the inverse of Ar, then m_ur with inverse(Ar)*br at
  /// every node.
  ///
  /// @param Ar  per-node matrices
  /// @param br  per-node right-hand sides
  /// @return a handle on m_ur
  NodeValue<Rd>
  computeUr(const NodeValue<const Rdd>& Ar,
            const NodeValue<const Rd>& br)
  {
    inverse(Ar, m_inv_Ar);
    const NodeValue<const Rdd> invAr = m_inv_Ar;
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const NodeId& r) {
        m_ur[r]=invAr[r]*br[r];
      });

    return m_ur;
  }

  /// Fills m_Fjr: for every cell j and local node r,
  /// Fjr = Ajr(j,r)*(uj[j] - ur[node]) + pj[j]*Cjr(j,r),
  /// where node is the r-th node of cell j.
  ///
  /// @param Ajr  per (cell, local node) matrices
  /// @param ur   per-node vectors
  /// @param Cjr  per (cell, local node) vectors
  /// @param uj   per-cell vectors
  /// @param pj   per-cell scalars
  void
  computeFjr(const NodeValuePerCell<Rdd>& Ajr,
             const NodeValue<const Rd>& ur,
             const NodeValuePerCell<const Rd>& Cjr,
             const CellValue<const Rd>& uj,
             const CellValue<const double>& pj)
  {
    const auto& cell_to_node_matrix
        = m_mesh.connectivity().cellToNodeMatrix();

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const CellId& j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        for (size_t r=0; r<cell_nodes.size(); ++r) {
          m_Fjr(j,r) = Ajr(j,r)*(uj[j]-ur[cell_nodes[r]])+pj[j]*Cjr(j,r);
        }
      });
  }

  /// Writes the inverse of every matrix of A into inv_A, node by node, using
  /// the free function ::inverse.
  ///
  /// @param A      per-node matrices to invert
  /// @param inv_A  per-node output; one entry per mesh node is written
  void inverse(const NodeValue<const Rdd>& A,
               NodeValue<Rdd>& inv_A) const
  {
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const NodeId& r) {
        inv_A[r] = ::inverse(A[r]);
      });
  }

  /// Runs the flux pipeline and leaves its results in m_ur and m_Fjr.
  ///
  /// Stage order matters: the boundary conditions modify m_Ar and m_br in
  /// place and must therefore run after computeAr/computeBr and before
  /// computeUr.
  ///
  /// @param xr    currently unused by this implementation
  /// @param xj    currently unused by this implementation
  /// @param rhoj  per-cell scalars, multiplied by cj
  /// @param uj    per-cell vectors
  /// @param pj    per-cell scalars
  /// @param cj    per-cell scalars, multiplied by rhoj
  /// @param Vj    currently unused by this implementation
  /// @param Cjr   per (cell, local node) vectors
  /// @param ljr   forwarded to computeAjr
  /// @param njr   per (cell, local node) vectors
  KOKKOS_INLINE_FUNCTION
  void computeExplicitFluxes(const NodeValue<const Rd>& xr,
                             const CellValue<const Rd>& xj,
                             const CellValue<const double>& rhoj,
                             const CellValue<const Rd>& uj,
                             const CellValue<const double>& pj,
                             const CellValue<const double>& cj,
                             const CellValue<const double>& Vj,
                             const NodeValuePerCell<const Rd>& Cjr,
                             const NodeValuePerCell<const double>& ljr,
                             const NodeValuePerCell<const Rd>& njr) {
    const CellValue<const double> rhocj = computeRhoCj(rhoj, cj);
    computeAjr(rhocj, Cjr, ljr, njr);

    NodeValuePerCell<const Rdd> Ajr = m_Ajr;
    const NodeValue<const Rdd> Ar = computeAr(Ajr);
    const NodeValue<const Rd> br = computeBr(m_Ajr, Cjr, uj, pj);

    // Ar and br are handles on m_Ar and m_br, which are modified here.
    this->applyBoundaryConditions();

    // computeUr already stores its result in m_ur; only a handle is needed.
    const NodeValue<Rd> ur = computeUr(Ar, br);
    computeFjr(m_Ajr, ur, Cjr, uj, pj);
  }

 public:
  /// Binds the solver to a mesh and allocates all scratch arrays on the
  /// connectivity of that mesh.
  ///
  /// @param mesh_data  mesh data, kept by reference and updated at each step
  /// @param unknowns   currently unused by the constructor
  /// @param bc_list    boundary conditions, kept by reference
  AcousticSolver(MeshData& mesh_data,
                 UnknownsType& unknowns,
                 const std::vector<BoundaryConditionHandler>& bc_list)
      : m_mesh_data(mesh_data),
        m_mesh(mesh_data.mesh()),
        m_connectivity(m_mesh.connectivity()),
        m_boundary_condition_list(bc_list),
        m_br(m_connectivity),
        m_Ajr(m_connectivity),
        m_Ar(m_connectivity),
        m_inv_Ar(m_connectivity),
        m_Fjr(m_connectivity),
        m_ur(m_connectivity),
        m_rhocj(m_connectivity),
        m_Vj_over_cj(m_connectivity)
  {
    ;
  }

  /// Returns the minimum over all cells of 2*Vj/(S*cj), where S is the sum of
  /// ljr over the nodes of the cell.
  ///
  /// The per-cell values are stored in m_Vj_over_cj. This method is const yet
  /// writes that array, which relies on the value containers allowing element
  /// writes through a const handle.
  ///
  /// @param Vj  per-cell scalars (numerator)
  /// @param cj  per-cell scalars (denominator factor)
  /// @return ReduceMin of the per-cell values
  KOKKOS_INLINE_FUNCTION
  double acoustic_dt(const CellValue<const double>& Vj,
                     const CellValue<const double>& cj) const
  {
    const NodeValuePerCell<const double>& ljr = m_mesh_data.ljr();
    const auto& cell_to_node_matrix
        = m_mesh.connectivity().cellToNodeMatrix();

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const CellId& j){
        const auto& cell_nodes = cell_to_node_matrix[j];

        double S = 0;
        for (size_t r=0; r<cell_nodes.size(); ++r) {
          S += ljr(j,r);
        }
        m_Vj_over_cj[j] = 2*Vj[j]/(S*cj[j]);
      });

    return ReduceMin(m_Vj_over_cj);
  }

  /// Advances the unknowns by one step of size dt and moves the mesh.
  ///
  /// In order: computes the node velocities and fluxes, updates uj and Ej
  /// from the fluxes, recomputes ej from Ej and uj, moves every node by
  /// dt*ur, refreshes the mesh data, then recomputes rhoj as mj/Vj.
  /// pj and cj are read but not updated here.
  ///
  /// @param t         currently unused by this implementation
  /// @param dt        time-step size
  /// @param unknowns  unknowns to update in place
  void computeNextStep(const double& t, const double& dt,
                       UnknownsType& unknowns)
  {
    CellValue<double>& rhoj = unknowns.rhoj();
    CellValue<Rd>& uj = unknowns.uj();
    CellValue<double>& Ej = unknowns.Ej();

    CellValue<double>& ej = unknowns.ej();
    CellValue<double>& pj = unknowns.pj();
    CellValue<double>& cj = unknowns.cj();

    const CellValue<const Rd>& xj = m_mesh_data.xj();
    const CellValue<const double>& Vj = m_mesh_data.Vj();
    const NodeValuePerCell<const Rd>& Cjr = m_mesh_data.Cjr();
    const NodeValuePerCell<const double>& ljr = m_mesh_data.ljr();
    const NodeValuePerCell<const Rd>& njr = m_mesh_data.njr();
    const NodeValue<const Rd>& xr = m_mesh.xr();

    computeExplicitFluxes(xr, xj, rhoj, uj, pj, cj, Vj, Cjr, ljr, njr);

    const NodeValuePerCell<Rd>& Fjr = m_Fjr;
    const NodeValue<const Rd> ur = m_ur;
    const auto& cell_to_node_matrix
        = m_mesh.connectivity().cellToNodeMatrix();

    const CellValue<const double> inv_mj = unknowns.invMj();
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const CellId& j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        Rd momentum_fluxes = zero;
        double energy_fluxes = 0;
        for (size_t R=0; R<cell_nodes.size(); ++R) {
          const NodeId r=cell_nodes[R];
          momentum_fluxes +=  Fjr(j,R);
          // The parenthesised pair relies on the comma operator as overloaded
          // by the vector type and yields a scalar (presumably the dot
          // product -- confirm in TinyVector.hpp).
          energy_fluxes   += (Fjr(j,R), ur[r]);
        }
        const double dt_over_mj = dt*inv_mj[j];
        uj[j] -= dt_over_mj * momentum_fluxes;
        Ej[j] -= dt_over_mj * energy_fluxes;
      });

    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const CellId& j) {
        ej[j] = Ej[j] - 0.5 * (uj[j],uj[j]);
      });

    NodeValue<Rd> mutable_xr = m_mesh.mutableXr();
    Kokkos::parallel_for(m_mesh.numberOfNodes(), KOKKOS_LAMBDA(const NodeId& r){
        mutable_xr[r] += dt*ur[r];
      });
    // The nodes have moved: refresh the mesh data before Vj is read below.
    m_mesh_data.updateAllData();

    const CellValue<const double> mj = unknowns.mj();
    Kokkos::parallel_for(m_mesh.numberOfCells(), KOKKOS_LAMBDA(const CellId& j){
        rhoj[j] = mj[j]/Vj[j];
      });
  }
};

#endif // ACOUSTIC_SOLVER_HPP