#ifndef ACOUSTIC_SOLVER_HPP
#define ACOUSTIC_SOLVER_HPP

#include <algebra/TinyMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <mesh/SubItemValuePerItem.hpp>
#include <scheme/AcousticSolverType.hpp>
#include <scheme/BlockPerfectGas.hpp>
#include <scheme/FiniteVolumesEulerUnknowns.hpp>
#include <utils/ArrayUtils.hpp>
#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>
#include <utils/PugsAssert.hpp>

#include <rang.hpp>

#include <iostream>
#include <limits>
#include <memory>
#include <type_traits>
#include <variant>
#include <vector>

template <typename MeshType>
class AcousticSolver
{
 public:
  class PressureBoundaryCondition;
  class SymmetryBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition =
    std::variant<PressureBoundaryCondition, SymmetryBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  constexpr static size_t Dimension = MeshType::Dimension;

 private:
  using MeshDataType = MeshData<Dimension>;
  using UnknownsType = FiniteVolumesEulerUnknowns<MeshType>;

  std::shared_ptr<const MeshType> m_mesh;
  const typename MeshType::Connectivity& m_connectivity;

  using Rd  = TinyVector<Dimension>;
  using Rdd = TinyMatrix<Dimension>;

  const BoundaryConditionList m_boundary_condition_list;

  const AcousticSolverType m_solver_type;

  void
  _applyPressureBC()
  {
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
            MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);
            if constexpr (Dimension == 1) {
              const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();

              const auto& node_to_cell_matrix               = m_connectivity.nodeToCellMatrix();
              const auto& node_local_numbers_in_their_cells = m_connectivity.nodeLocalNumbersInTheirCells();

              const auto& node_list  = bc.faceList();
              const auto& value_list = bc.valueList();
              parallel_for(
                node_list.size(), PUGS_LAMBDA(size_t i_node) {
                  const NodeId node_id       = node_list[i_node];
                  const auto& node_cell_list = node_to_cell_matrix[node_id];
                  Assert(node_cell_list.size() == 1);

                  CellId node_cell_id              = node_cell_list[0];
                  size_t node_local_number_in_cell = node_local_numbers_in_their_cells(node_id, 0);

                  m_br[node_id] -= value_list[i_node] * Cjr(node_cell_id, node_local_number_in_cell);
                });
            } else {
              const NodeValuePerFace<const Rd> Nlr = mesh_data.Nlr();

              const auto& face_to_cell_matrix               = m_connectivity.faceToCellMatrix();
              const auto& face_to_node_matrix               = m_connectivity.faceToNodeMatrix();
              const auto& face_local_numbers_in_their_cells = m_connectivity.faceLocalNumbersInTheirCells();
              const auto& face_cell_is_reversed             = m_connectivity.cellFaceIsReversed();

              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();
              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id       = face_list[i_face];
                const auto& face_cell_list = face_to_cell_matrix[face_id];
                Assert(face_cell_list.size() == 1);

                CellId face_cell_id              = face_cell_list[0];
                size_t face_local_number_in_cell = face_local_numbers_in_their_cells(face_id, 0);

                const double sign = face_cell_is_reversed(face_cell_id, face_local_number_in_cell) ? -1 : 1;

                const auto& face_nodes = face_to_node_matrix[face_id];

                for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                  NodeId node_id = face_nodes[i_node];
                  m_br[node_id] -= sign * value_list[i_face] * Nlr(face_id, i_node);
                }
              }
            }
          }
        },
        boundary_condition);
    }
  }

  void
  _applySymmetryBC()
  {
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<SymmetryBoundaryCondition, T>) {
            const Rd& n = bc.outgoingNormal();

            const Rdd I   = identity;
            const Rdd nxn = tensorProduct(n, n);
            const Rdd P   = I - nxn;

            const Array<const NodeId>& node_list = bc.nodeList();
            parallel_for(
              bc.numberOfNodes(), PUGS_LAMBDA(int r_number) {
                const NodeId r = node_list[r_number];

                m_Ar[r] = P * m_Ar[r] * P + nxn;
                m_br[r] = P * m_br[r];
              });
          }
        },
        boundary_condition);
    }
  }

  void
  _applyVelocityBC()
  {
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list  = bc.nodeList();
            const auto& value_list = bc.valueList();

            parallel_for(
              node_list.size(), PUGS_LAMBDA(size_t i_node) {
                NodeId node_id    = node_list[i_node];
                const auto& value = value_list[i_node];

                m_Ar[node_id] = identity;
                m_br[node_id] = value;
              });
          }
        },
        boundary_condition);
    }
  }

  PUGS_INLINE
  const CellValue<const double>
  computeRhoCj(const CellValue<const double>& rhoj, const CellValue<const double>& cj)
  {
    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { m_rhocj[j] = rhoj[j] * cj[j]; });
    return m_rhocj;
  }

  PUGS_INLINE
  void
  _computeGlaceAjr(const CellValue<const double>& rhocj)
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();
    const NodeValuePerCell<const Rd> njr = mesh_data.njr();

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const size_t& nb_nodes = m_Ajr.numberOfSubValues(j);
        const double& rho_c    = rhocj[j];
        for (size_t r = 0; r < nb_nodes; ++r) {
          m_Ajr(j, r) = tensorProduct(rho_c * Cjr(j, r), njr(j, r));
        }
      });
  }

  PUGS_INLINE
  void
  _computeEucclhydAjr(const CellValue<const double>& rhocj)
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValuePerFace<const Rd> Nlr = mesh_data.Nlr();
    const NodeValuePerFace<const Rd> nlr = mesh_data.nlr();

    const auto& face_to_node_matrix = m_connectivity.faceToNodeMatrix();
    const auto& cell_to_node_matrix = m_connectivity.cellToNodeMatrix();
    const auto& cell_to_face_matrix = m_connectivity.cellToFaceMatrix();

    parallel_for(
      m_Ajr.numberOfValues(), PUGS_LAMBDA(size_t jr) { m_Ajr[jr] = zero; });

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        const auto& cell_faces = cell_to_face_matrix[j];

        const double& rho_c = rhocj[j];

        for (size_t L = 0; L < cell_faces.size(); ++L) {
          const FaceId& l        = cell_faces[L];
          const auto& face_nodes = face_to_node_matrix[l];

          auto local_node_number_in_cell = [&](NodeId node_number) {
            for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
              if (node_number == cell_nodes[i_node]) {
                return i_node;
              }
            }
            return std::numeric_limits<size_t>::max();
          };

          for (size_t rl = 0; rl < face_nodes.size(); ++rl) {
            const size_t R = local_node_number_in_cell(face_nodes[rl]);
            m_Ajr(j, R) += tensorProduct(rho_c * Nlr(l, rl), nlr(l, rl));
          }
        }
      });
  }

  PUGS_INLINE
  void
  computeAjr(const CellValue<const double>& rhocj)
  {
    switch (m_solver_type) {
    case AcousticSolverType::Glace: {
      this->_computeGlaceAjr(rhocj);
      break;
    }
    case AcousticSolverType::Eucclhyd: {
      if constexpr (Dimension > 1) {
        this->_computeEucclhydAjr(rhocj);
      } else {
        this->_computeGlaceAjr(rhocj);
      }
      break;
    }
    }
  }

  PUGS_INLINE
  const NodeValue<const Rdd>
  computeAr(const NodeValuePerCell<const Rdd>& Ajr)
  {
    const auto& node_to_cell_matrix               = m_connectivity.nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells = m_connectivity.nodeLocalNumbersInTheirCells();

    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) {
        Rdd sum                                    = zero;
        const auto& node_to_cell                   = node_to_cell_matrix[r];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemValues(r);

        for (size_t j = 0; j < node_to_cell.size(); ++j) {
          const CellId J       = node_to_cell[j];
          const unsigned int R = node_local_number_in_its_cells[j];
          sum += Ajr(J, R);
        }
        m_Ar[r] = sum;
      });

    return m_Ar;
  }

  PUGS_INLINE
  const NodeValue<const Rd>
  computeBr(const CellValue<const Rd>& uj, const CellValue<const double>& pj)
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValuePerCell<const Rd>& Cjr  = mesh_data.Cjr();
    const NodeValuePerCell<const Rdd>& Ajr = m_Ajr;

    const auto& node_to_cell_matrix               = m_connectivity.nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells = m_connectivity.nodeLocalNumbersInTheirCells();

    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) {
        Rd& br                                     = m_br[r];
        br                                         = zero;
        const auto& node_to_cell                   = node_to_cell_matrix[r];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemValues(r);
        for (size_t j = 0; j < node_to_cell.size(); ++j) {
          const CellId J       = node_to_cell[j];
          const unsigned int R = node_local_number_in_its_cells[j];
          br += Ajr(J, R) * uj[J] + pj[J] * Cjr(J, R);
        }
      });

    return m_br;
  }

  void
  applyBoundaryConditions()
  {
    this->_applyPressureBC();
    this->_applySymmetryBC();
    this->_applyVelocityBC();
  }

  void
  computeUr()
  {
    const NodeValue<const Rdd> Ar = m_Ar;
    const NodeValue<const Rd> br  = m_br;

    inverse(Ar, m_inv_Ar);
    const NodeValue<const Rdd> invAr = m_inv_Ar;
    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { m_ur[r] = invAr[r] * br[r]; });
  }

  void
  computeFjr(const CellValue<const Rd>& uj, const CellValue<const double>& pj)
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValuePerCell<const Rd> Cjr  = mesh_data.Cjr();
    const NodeValuePerCell<const Rdd> Ajr = m_Ajr;

    const NodeValue<const Rd> ur = m_ur;

    const auto& cell_to_node_matrix = m_mesh->connectivity().cellToNodeMatrix();

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        for (size_t r = 0; r < cell_nodes.size(); ++r) {
          m_Fjr(j, r) = Ajr(j, r) * (uj[j] - ur[cell_nodes[r]]) + pj[j] * Cjr(j, r);
        }
      });
  }

  void
  inverse(const NodeValue<const Rdd>& A, NodeValue<Rdd>& inv_A) const
  {
    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { inv_A[r] = ::inverse(A[r]); });
  }

  PUGS_INLINE
  void
  computeExplicitFluxes(const CellValue<const double>& rhoj,
                        const CellValue<const Rd>& uj,
                        const CellValue<const double>& pj,
                        const CellValue<const double>& cj)
  {
    const CellValue<const double> rhocj = computeRhoCj(rhoj, cj);
    computeAjr(rhocj);

    NodeValuePerCell<const Rdd> Ajr = m_Ajr;
    this->computeAr(Ajr);
    this->computeBr(uj, pj);

    this->applyBoundaryConditions();

    synchronize(m_Ar);
    synchronize(m_br);

    computeUr();
    computeFjr(uj, pj);
  }

  NodeValue<Rd> m_br;
  NodeValuePerCell<Rdd> m_Ajr;
  NodeValue<Rdd> m_Ar;
  NodeValue<Rdd> m_inv_Ar;
  NodeValuePerCell<Rd> m_Fjr;
  NodeValue<Rd> m_ur;
  CellValue<double> m_rhocj;
  CellValue<double> m_Vj_over_cj;

 public:
  AcousticSolver(std::shared_ptr<const MeshType> p_mesh,
                 const BoundaryConditionList& bc_list,
                 const AcousticSolverType solver_type)
    : m_mesh(p_mesh),
      m_connectivity(m_mesh->connectivity()),
      m_boundary_condition_list(bc_list),
      m_solver_type(solver_type),
      m_br(m_connectivity),
      m_Ajr(m_connectivity),
      m_Ar(m_connectivity),
      m_inv_Ar(m_connectivity),
      m_Fjr(m_connectivity),
      m_ur(m_connectivity),
      m_rhocj(m_connectivity),
      m_Vj_over_cj(m_connectivity)
  {
    ;
  }

  PUGS_INLINE
  double
  acoustic_dt(const CellValue<const double>& Vj, const CellValue<const double>& cj) const
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValuePerCell<const double>& ljr = mesh_data.ljr();
    const auto& cell_to_node_matrix           = m_mesh->connectivity().cellToNodeMatrix();

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        double S = 0;
        for (size_t r = 0; r < cell_nodes.size(); ++r) {
          S += ljr(j, r);
        }
        m_Vj_over_cj[j] = 2 * Vj[j] / (S * cj[j]);
      });

    return min(m_Vj_over_cj);
  }

  [[nodiscard]] std::shared_ptr<const MeshType>
  computeNextStep(double dt, UnknownsType& unknowns)
  {
    CellValue<double>& rhoj = unknowns.rhoj();
    CellValue<Rd>& uj       = unknowns.uj();
    CellValue<double>& Ej   = unknowns.Ej();

    CellValue<double>& ej = unknowns.ej();
    CellValue<double>& pj = unknowns.pj();
    CellValue<double>& cj = unknowns.cj();

    MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const CellValue<const double> Vj = mesh_data.Vj();

    computeExplicitFluxes(rhoj, uj, pj, cj);

    const NodeValuePerCell<Rd>& Fjr = m_Fjr;
    const NodeValue<const Rd> ur    = m_ur;
    const auto& cell_to_node_matrix = m_mesh->connectivity().cellToNodeMatrix();

    const CellValue<const double> inv_mj = unknowns.invMj();
    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        Rd momentum_fluxes   = zero;
        double energy_fluxes = 0;
        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          const NodeId r = cell_nodes[R];
          momentum_fluxes += Fjr(j, R);
          energy_fluxes += (Fjr(j, R), ur[r]);
        }
        uj[j] -= (dt * inv_mj[j]) * momentum_fluxes;
        Ej[j] -= (dt * inv_mj[j]) * energy_fluxes;
      });

    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { ej[j] = Ej[j] - 0.5 * (uj[j], uj[j]); });

    NodeValue<Rd> new_xr = copy(m_mesh->xr());
    parallel_for(
      m_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { new_xr[r] += dt * ur[r]; });

    m_mesh                         = std::make_shared<MeshType>(m_mesh->shared_connectivity(), new_xr);
    CellValue<const double> new_Vj = MeshDataManager::instance().getMeshData(*m_mesh).Vj();

    const CellValue<const double> mj = unknowns.mj();
    parallel_for(
      m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { rhoj[j] = mj[j] / new_Vj[j]; });

    return m_mesh;
  }
};

/// Pressure boundary condition (generic version): a list of boundary faces and
/// the pressure imposed on each of them, valueList()[i] applying to
/// faceList()[i]. The 1D mesh type uses a specialization based on nodes.
template <typename MeshType>
class AcousticSolver<MeshType>::PressureBoundaryCondition
{
 private:
  const Array<const double> m_value_list;
  const Array<const FaceId> m_face_list;

 public:
  PressureBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
    : m_value_list{value_list}, m_face_list{face_list}
  {}

  ~PressureBoundaryCondition() = default;

  /// Pressure values, one per entry of faceList().
  const Array<const double>&
  valueList() const
  {
    return m_value_list;
  }

  /// Boundary faces on which the pressure is imposed.
  const Array<const FaceId>&
  faceList() const
  {
    return m_face_list;
  }
};

/// Pressure boundary condition, 1D specialization: in 1D the boundary "faces"
/// are nodes, so the condition stores a node list. faceList() keeps the
/// generic accessor name so that the solver code is shared across dimensions.
template <>
class AcousticSolver<Mesh<Connectivity<1>>>::PressureBoundaryCondition
{
 private:
  const Array<const double> m_value_list;
  const Array<const NodeId> m_node_list;

 public:
  PressureBoundaryCondition(const Array<const NodeId>& face_list, const Array<const double>& value_list)
    : m_value_list{value_list}, m_node_list{face_list}
  {}

  ~PressureBoundaryCondition() = default;

  /// Pressure values, one per entry of faceList().
  const Array<const double>&
  valueList() const
  {
    return m_value_list;
  }

  /// Boundary nodes on which the pressure is imposed.
  const Array<const NodeId>&
  faceList() const
  {
    return m_node_list;
  }
};

/// Velocity boundary condition: a list of nodes and the velocity imposed on
/// each of them, valueList()[i] applying to nodeList()[i].
template <typename MeshType>
class AcousticSolver<MeshType>::VelocityBoundaryCondition
{
 private:
  const Array<const TinyVector<Dimension>> m_value_list;
  const Array<const NodeId> m_node_list;

 public:
  VelocityBoundaryCondition(const Array<const NodeId>& node_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list{value_list}, m_node_list{node_list}
  {}

  ~VelocityBoundaryCondition() = default;

  /// Velocities, one per entry of nodeList().
  const Array<const TinyVector<Dimension>>&
  valueList() const
  {
    return m_value_list;
  }

  /// Nodes on which the velocity is imposed.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }
};

/// Symmetry boundary condition defined on a flat node boundary; exposes the
/// boundary's outgoing normal and node list.
template <typename MeshType>
class AcousticSolver<MeshType>::SymmetryBoundaryCondition
{
 public:
  static constexpr size_t Dimension = MeshType::Dimension;

  using Rd = TinyVector<Dimension, double>;

 private:
  const MeshFlatNodeBoundary<Dimension> m_mesh_flat_node_boundary;

 public:
  SymmetryBoundaryCondition(const MeshFlatNodeBoundary<Dimension>& mesh_flat_node_boundary)
    : m_mesh_flat_node_boundary(mesh_flat_node_boundary)
  {}

  ~SymmetryBoundaryCondition() = default;

  /// Nodes of the flat boundary.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_flat_node_boundary.nodeList();
  }

  /// Number of entries of nodeList().
  size_t
  numberOfNodes() const
  {
    return this->nodeList().size();
  }

  /// Outgoing normal of the flat boundary.
  const Rd&
  outgoingNormal() const
  {
    return m_mesh_flat_node_boundary.outgoingNormal();
  }
};

#endif   // ACOUSTIC_SOLVER_HPP
