#ifndef ACOUSTIC_SOLVER_HPP
#define ACOUSTIC_SOLVER_HPP

#include <rang.hpp>

#include <utils/ArrayUtils.hpp>

#include <scheme/BlockPerfectGas.hpp>
#include <utils/PugsAssert.hpp>

#include <scheme/BoundaryCondition.hpp>
#include <scheme/FiniteVolumesEulerUnknowns.hpp>

#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>

#include <algebra/TinyMatrix.hpp>
#include <algebra/TinyVector.hpp>

#include <mesh/ItemValueUtils.hpp>
#include <mesh/SubItemValuePerItem.hpp>

#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>

#include <iostream>

template <typename MeshType>
class AcousticSolver
{
  constexpr static size_t Dimension = MeshType::Dimension;

  using MeshDataType = MeshData<Dimension>;
  using UnknownsType = FiniteVolumesEulerUnknowns<MeshType>;

  std::shared_ptr<const MeshType> m_mesh;
  const typename MeshType::Connectivity& m_connectivity;
  const std::vector<BoundaryConditionHandler>& m_boundary_condition_list;

  using Rd  = TinyVector<Dimension>;
  using Rdd = TinyMatrix<Dimension>;

 private:
  PUGS_INLINE
  const CellValue<const double>
  computeRhoCj(const CellValue<const double>& rhoj, const CellValue<const double>& cj)
  {
    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { m_rhocj[j] = rhoj[j] * cj[j]; });
    return m_rhocj;
  }

  PUGS_INLINE
  void
  computeAjr(const CellValue<const double>& rhocj,
             const NodeValuePerCell<const Rd>& Cjr,
             const NodeValuePerCell<const double>& /* ljr */,
             const NodeValuePerCell<const Rd>& njr)
  {
    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const size_t& nb_nodes = m_Ajr.numberOfSubValues(j);
        const double& rho_c    = rhocj[j];
        for (size_t r = 0; r < nb_nodes; ++r) {
          m_Ajr(j, r) = tensorProduct(rho_c * Cjr(j, r), njr(j, r));
        }
      });
  }

  PUGS_INLINE
  const NodeValue<const Rdd>
  computeAr(const NodeValuePerCell<const Rdd>& Ajr)
  {
    const auto& node_to_cell_matrix               = m_connectivity.nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells = m_connectivity.nodeLocalNumbersInTheirCells();

    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) {
        Rdd sum                                    = zero;
        const auto& node_to_cell                   = node_to_cell_matrix[r];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemValues(r);

        for (size_t j = 0; j < node_to_cell.size(); ++j) {
          const CellId J       = node_to_cell[j];
          const unsigned int R = node_local_number_in_its_cells[j];
          sum += Ajr(J, R);
        }
        m_Ar[r] = sum;
      });

    return m_Ar;
  }

  PUGS_INLINE
  const NodeValue<const Rd>
  computeBr(const NodeValuePerCell<Rdd>& Ajr,
            const NodeValuePerCell<const Rd>& Cjr,
            const CellValue<const Rd>& uj,
            const CellValue<const double>& pj)
  {
    const auto& node_to_cell_matrix               = m_connectivity.nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells = m_connectivity.nodeLocalNumbersInTheirCells();

    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) {
        Rd& br                                     = m_br[r];
        br                                         = zero;
        const auto& node_to_cell                   = node_to_cell_matrix[r];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemValues(r);
        for (size_t j = 0; j < node_to_cell.size(); ++j) {
          const CellId J       = node_to_cell[j];
          const unsigned int R = node_local_number_in_its_cells[j];
          br += Ajr(J, R) * uj[J] + pj[J] * Cjr(J, R);
        }
      });

    return m_br;
  }

  void
  applyBoundaryConditions()
  {
    for (const auto& handler : m_boundary_condition_list) {
      switch (handler.boundaryCondition().type()) {
      case BoundaryCondition::normal_velocity: {
        throw NotImplementedError("normal_velocity BC");
      }
      case BoundaryCondition::velocity: {
        throw NotImplementedError("velocity BC");
      }
      case BoundaryCondition::pressure: {
        const PressureBoundaryCondition<Dimension>& pressure_bc =
          dynamic_cast<const PressureBoundaryCondition<Dimension>&>(handler.boundaryCondition());
        if constexpr (Dimension == 1) {
          MeshData<Dimension>& mesh_data       = MeshDataManager::instance().getMeshData(*m_mesh);
          const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();

          const auto& node_to_cell_matrix               = m_connectivity.nodeToCellMatrix();
          const auto& node_local_numbers_in_their_cells = m_connectivity.nodeLocalNumbersInTheirCells();

          const auto& node_list  = pressure_bc.faceList();
          const auto& value_list = pressure_bc.valueList();
          for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
            const NodeId node_id       = node_list[i_node];
            const auto& node_cell_list = node_to_cell_matrix[node_id];
            Assert(node_cell_list.size() == 1);

            CellId node_cell_id              = node_cell_list[0];
            size_t node_local_number_in_cell = node_local_numbers_in_their_cells(node_id, 0);

            m_br[node_id] -= value_list[i_node] * Cjr(node_cell_id, node_local_number_in_cell);
          }
        } else {
          throw NotImplementedError("pressure bc in dimension>1");
        }
        break;
      }
      case BoundaryCondition::symmetry: {
        const SymmetryBoundaryCondition<Dimension>& symmetry_bc =
          dynamic_cast<const SymmetryBoundaryCondition<Dimension>&>(handler.boundaryCondition());
        const Rd& n = symmetry_bc.outgoingNormal();

        const Rdd I   = identity;
        const Rdd nxn = tensorProduct(n, n);
        const Rdd P   = I - nxn;

        const Array<const NodeId>& node_list = symmetry_bc.nodeList();
        parallel_for(
          symmetry_bc.numberOfNodes(), PUGS_LAMBDA(int r_number) {
            const NodeId r = node_list[r_number];

            m_Ar[r] = P * m_Ar[r] * P + nxn;
            m_br[r] = P * m_br[r];
          });
        break;
      }
      }
    }
  }

  NodeValue<Rd>
  computeUr(const NodeValue<const Rdd>& Ar, const NodeValue<const Rd>& br)
  {
    inverse(Ar, m_inv_Ar);
    const NodeValue<const Rdd> invAr = m_inv_Ar;
    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { m_ur[r] = invAr[r] * br[r]; });

    return m_ur;
  }

  void
  computeFjr(const NodeValuePerCell<Rdd>& Ajr,
             const NodeValue<const Rd>& ur,
             const NodeValuePerCell<const Rd>& Cjr,
             const CellValue<const Rd>& uj,
             const CellValue<const double>& pj)
  {
    const auto& cell_to_node_matrix = m_mesh->connectivity().cellToNodeMatrix();

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        for (size_t r = 0; r < cell_nodes.size(); ++r) {
          m_Fjr(j, r) = Ajr(j, r) * (uj[j] - ur[cell_nodes[r]]) + pj[j] * Cjr(j, r);
        }
      });
  }

  void
  inverse(const NodeValue<const Rdd>& A, NodeValue<Rdd>& inv_A) const
  {
    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { inv_A[r] = ::inverse(A[r]); });
  }

  PUGS_INLINE
  void
  computeExplicitFluxes(const CellValue<const double>& rhoj,
                        const CellValue<const Rd>& uj,
                        const CellValue<const double>& pj,
                        const CellValue<const double>& cj,
                        const NodeValuePerCell<const Rd>& Cjr,
                        const NodeValuePerCell<const double>& ljr,
                        const NodeValuePerCell<const Rd>& njr)
  {
    const CellValue<const double> rhocj = computeRhoCj(rhoj, cj);
    computeAjr(rhocj, Cjr, ljr, njr);

    NodeValuePerCell<const Rdd> Ajr = m_Ajr;
    this->computeAr(Ajr);
    this->computeBr(m_Ajr, Cjr, uj, pj);

    this->applyBoundaryConditions();

    synchronize(m_Ar);
    synchronize(m_br);

    NodeValue<Rd>& ur = m_ur;
    ur                = computeUr(m_Ar, m_br);
    computeFjr(m_Ajr, ur, Cjr, uj, pj);
  }

  NodeValue<Rd> m_br;
  NodeValuePerCell<Rdd> m_Ajr;
  NodeValue<Rdd> m_Ar;
  NodeValue<Rdd> m_inv_Ar;
  NodeValuePerCell<Rd> m_Fjr;
  NodeValue<Rd> m_ur;
  CellValue<double> m_rhocj;
  CellValue<double> m_Vj_over_cj;

 public:
  AcousticSolver(std::shared_ptr<const MeshType> p_mesh, const std::vector<BoundaryConditionHandler>& bc_list)
    : m_mesh(p_mesh),
      m_connectivity(m_mesh->connectivity()),
      m_boundary_condition_list(bc_list),
      m_br(m_connectivity),
      m_Ajr(m_connectivity),
      m_Ar(m_connectivity),
      m_inv_Ar(m_connectivity),
      m_Fjr(m_connectivity),
      m_ur(m_connectivity),
      m_rhocj(m_connectivity),
      m_Vj_over_cj(m_connectivity)
  {
    ;
  }

  PUGS_INLINE
  double
  acoustic_dt(const CellValue<const double>& Vj, const CellValue<const double>& cj) const
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValuePerCell<const double>& ljr = mesh_data.ljr();
    const auto& cell_to_node_matrix           = m_mesh->connectivity().cellToNodeMatrix();

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        double S = 0;
        for (size_t r = 0; r < cell_nodes.size(); ++r) {
          S += ljr(j, r);
        }
        m_Vj_over_cj[j] = 2 * Vj[j] / (S * cj[j]);
      });

    return min(m_Vj_over_cj);
  }

  [[nodiscard]] std::shared_ptr<const MeshType>
  computeNextStep(double dt, UnknownsType& unknowns)
  {
    CellValue<double>& rhoj = unknowns.rhoj();
    CellValue<Rd>& uj       = unknowns.uj();
    CellValue<double>& Ej   = unknowns.Ej();

    CellValue<double>& ej = unknowns.ej();
    CellValue<double>& pj = unknowns.pj();
    CellValue<double>& cj = unknowns.cj();

    MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const CellValue<const double> Vj         = mesh_data.Vj();
    const NodeValuePerCell<const Rd> Cjr     = mesh_data.Cjr();
    const NodeValuePerCell<const double> ljr = mesh_data.ljr();
    const NodeValuePerCell<const Rd> njr     = mesh_data.njr();

    computeExplicitFluxes(rhoj, uj, pj, cj, Cjr, ljr, njr);

    const NodeValuePerCell<Rd>& Fjr = m_Fjr;
    const NodeValue<const Rd> ur    = m_ur;
    const auto& cell_to_node_matrix = m_mesh->connectivity().cellToNodeMatrix();

    const CellValue<const double> inv_mj = unknowns.invMj();
    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        Rd momentum_fluxes   = zero;
        double energy_fluxes = 0;
        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          const NodeId r = cell_nodes[R];
          momentum_fluxes += Fjr(j, R);
          energy_fluxes += (Fjr(j, R), ur[r]);
        }
        uj[j] -= (dt * inv_mj[j]) * momentum_fluxes;
        Ej[j] -= (dt * inv_mj[j]) * energy_fluxes;
      });

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { ej[j] = Ej[j] - 0.5 * (uj[j], uj[j]); });

    NodeValue<Rd> new_xr = copy(m_mesh->xr());
    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { new_xr[r] += dt * ur[r]; });

    m_mesh                         = std::make_shared<MeshType>(m_mesh->shared_connectivity(), new_xr);
    CellValue<const double> new_Vj = MeshDataManager::instance().getMeshData(*m_mesh).Vj();

    const CellValue<const double> mj = unknowns.mj();
    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { rhoj[j] = mj[j] / new_Vj[j]; });

    return m_mesh;
  }
};

#endif   // ACOUSTIC_SOLVER_HPP