#include <scheme/ImplicitAcousticSolver.hpp>

#include <algebra/CRSMatrix.hpp>
#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LinearSolver.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/IZoneDescriptor.hpp>
#include <mesh/MeshCellZone.hpp>
#include <mesh/MeshFaceBoundary.hpp>
#include <mesh/MeshFlatNodeBoundary.hpp>
#include <mesh/MeshLineNodeBoundary.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <scheme/AxisBoundaryConditionDescriptor.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <mesh/ItemValueVariant.hpp>
#include <output/NamedItemValueVariant.hpp>
#include <output/VTKWriter.hpp>
#include <utils/Timer.hpp>

#include <variant>
#include <vector>

// VTKWriter vtk_writer("imp/debug", 0);

// File-scope instrumentation counters (mutable globals with external linkage).
// NOTE(review): count_newton, count_Djr and count_timestep are not modified in
// the part of the file reviewed here -- presumably they are updated by the
// solver loop further down; confirm.
int count_newton   = 0;
int count_Djr      = 0;
int count_timestep = 0;

// File-scope timers used for profiling.
// get_A_t is driven by _getA() and getGradF_t by _getHessianJ().
// NOTE(review): solver_t, getF_t and HJ_A_t are not used in the part of the
// file reviewed here; confirm where they are started/paused.
Timer solver_t;
Timer get_A_t;
Timer getF_t;
Timer getGradF_t;
Timer HJ_A_t;

// Call counters: count_getA is incremented on every _getA() call and
// count_getGradF on every _getHessianJ() call.
int count_getA     = 0;
int count_getGradF = 0;

// Computes, for every cell j of the mesh carrying the sound speed `c_v`, the
// local acoustic time step 2 * Vj / (Sj * c_j), where Vj and Sj are read from
// the mesh data (Vj() and sumOverRLjr()).
//
// Returns the result as a new P0 discrete function wrapped in a variant.
// Throws NormalError("unexpected mesh type") if the mesh is not polygonal.
std::shared_ptr<const DiscreteFunctionVariant>
local_acoustic_dt(const std::shared_ptr<const DiscreteFunctionVariant>& c_v)
{
  const auto& c = c_v->get<DiscreteFunctionP0<const double>>();

  auto build_local_dt = [&](auto&& p_mesh) -> std::shared_ptr<const DiscreteFunctionVariant> {
    const auto& mesh = *p_mesh;

    using MeshType = decltype(mesh);
    if constexpr (not is_polygonal_mesh_v<MeshType>) {
      throw NormalError("unexpected mesh type");
    } else {
      const auto cell_volume  = MeshDataManager::instance().getMeshData(mesh).Vj();
      const auto cell_surface = MeshDataManager::instance().getMeshData(mesh).sumOverRLjr();

      DiscreteFunctionP0<double> dt_per_cell(c.meshVariant());
      parallel_for(
        mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          dt_per_cell[cell_id] = 2 * cell_volume[cell_id] / (cell_surface[cell_id] * c[cell_id]);
        });

      return std::make_shared<const DiscreteFunctionVariant>(dt_per_cell);
    }
  };

  return std::visit(build_local_dt, c.meshVariant()->variant());
}

// Computes the acoustic time step restricted to the explicit zones.
//
// For every cell belonging to one of the zones in `explicit_zone_list`, the
// local time step 2 * Vj / (Sj * c_j) is evaluated; all other cells keep
// std::numeric_limits<double>::max(). The minimum over the whole mesh is
// returned, so an empty zone list yields std::numeric_limits<double>::max().
//
// Throws NormalError("unexpected mesh type") if the mesh is not polygonal.
double
acoustic_dt(const std::shared_ptr<const DiscreteFunctionVariant>& c_v,
            const std::vector<std::shared_ptr<const IZoneDescriptor>>& explicit_zone_list)
{
  const auto& c = c_v->get<DiscreteFunctionP0<const double>>();

  return std::visit(
    [&](auto&& p_mesh) -> double {
      const auto& mesh = *p_mesh;

      using MeshType = decltype(mesh);
      if constexpr (is_polygonal_mesh_v<MeshType>) {
        const auto Vj = MeshDataManager::instance().getMeshData(mesh).Vj();
        const auto Sj = MeshDataManager::instance().getMeshData(mesh).sumOverRLjr();

        // Cells outside every explicit zone must not constrain the time step.
        DiscreteFunctionP0<double> local_dt(p_mesh);
        local_dt.fill(std::numeric_limits<double>::max());

        // Iterate by const reference: copying a shared_ptr costs an atomic
        // reference-count update per zone for no benefit.
        for (const auto& explicit_zone : explicit_zone_list) {
          auto mesh_cell_zone   = getMeshCellZone(mesh, *explicit_zone);
          const auto& cell_list = mesh_cell_zone.cellList();
          parallel_for(
            cell_list.size(), PUGS_LAMBDA(size_t i_cell) {
              const CellId cell_id = cell_list[i_cell];
              local_dt[cell_id]    = 2 * Vj[cell_id] / (Sj[cell_id] * c[cell_id]);
            });
        }

        return min(local_dt);
      } else {
        throw NormalError("unexpected mesh type");
      }
    },
    c.meshVariant()->variant());
}

template <MeshConcept MeshType>
class ImplicitAcousticSolverHandler::ImplicitAcousticSolver final
  : public ImplicitAcousticSolverHandler::IImplicitAcousticSolver
{
 private:
  constexpr static size_t Dimension = MeshType::Dimension;
  using Rdxd                        = TinyMatrix<Dimension>;
  using Rd                          = TinyVector<Dimension>;

  using MeshDataType = MeshData<MeshType>;

  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  class AxisBoundaryCondition;
  class PressureBoundaryCondition;
  class SymmetryBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition = std::
    variant<AxisBoundaryCondition, PressureBoundaryCondition, SymmetryBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  const SolverType m_solver_type;
  BoundaryConditionList m_boundary_condition_list;
  const MeshType& m_mesh;

  CellValue<const double> m_inv_gamma;
  CellValue<const double> m_g_1_exp_S_Cv_inv_g;
  CellValue<const double> m_tau_iter;

  CellValue<const Rd> m_u;
  CellValue<const double> m_tau;
  CellValue<const double> m_Mj;

  NodeValuePerCell<const Rdxd> m_Ajr;
  NodeValue<const Rdxd> m_inv_Ar;

  NodeValuePerCell<const Rd> m_Djr;

  CellValue<const Rd> m_predicted_u;
  CellValue<const double> m_predicted_p;

  CellValue<const bool> m_is_implicit_cell;
  CellValue<const int> m_implicit_cell_index;
  size_t m_number_of_implicit_cells;

  NodeValue<const bool> m_is_implicit_node;
  NodeValue<const int> m_implicit_node_index;
  size_t m_number_of_implicit_nodes;

  // Returns, for each node, the arithmetic mean of rho * c over the cells
  // connected to that node. Both functions must live on the same mesh.
  NodeValue<const double>
  _getRhoCr(const DiscreteScalarFunction& rho, const DiscreteScalarFunction& c) const
  {
    Assert(rho.meshVariant()->id() == c.meshVariant()->id());

    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();

    NodeValue<double> node_rho_c{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
        const auto& cells_of_node       = node_to_cell_matrix[node_id];
        const size_t number_of_my_cells = cells_of_node.size();

        double accumulated = 0;
        for (size_t i_cell = 0; i_cell < number_of_my_cells; ++i_cell) {
          const CellId cell_id = cells_of_node[i_cell];
          accumulated += rho[cell_id] * c[cell_id];
        }

        node_rho_c[node_id] = accumulated * (1. / number_of_my_cells);
      });

    return node_rho_c;
  }

  // Builds the per-cell/per-node matrices Ajr according to the solver type:
  //  - Glace1State : (rho c)_r * Cjr (x) njr, with (rho c)_r from _getRhoCr,
  //  - Glace2States: rho_j c_j * Cjr (x) njr,
  //  - Eucclhyd    : sum over the faces l of cell j containing node r of
  //                  rho_j c_j * Nlr (x) nlr (only available for Dimension > 1).
  //
  // Throws NotImplementedError for Eucclhyd in 1d.
  NodeValuePerCell<const Rdxd>
  _computeAjr(const DiscreteScalarFunction& rho, const DiscreteScalarFunction& c) const
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

    const NodeValuePerCell<const Rd> Cjr_n = mesh_data.Cjr();
    const NodeValuePerCell<const Rd> njr_n = mesh_data.njr();

    NodeValuePerCell<Rdxd> Ajr{m_mesh.connectivity()};
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    switch (m_solver_type) {
    case SolverType::Glace1State: {
      const NodeValue<const double> node_rho_c = _getRhoCr(rho, c);

      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          const size_t number_of_cell_nodes = Ajr.numberOfSubValues(cell_id);
          for (size_t i_node = 0; i_node < number_of_cell_nodes; ++i_node) {
            const NodeId node_id = cell_to_node_matrix[cell_id][i_node];
            Ajr(cell_id, i_node) =
              tensorProduct(node_rho_c[node_id] * Cjr_n(cell_id, i_node), njr_n(cell_id, i_node));
          }
        });
      break;
    }
    case SolverType::Glace2States: {
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          const size_t number_of_cell_nodes = Ajr.numberOfSubValues(cell_id);
          for (size_t i_node = 0; i_node < number_of_cell_nodes; ++i_node) {
            Ajr(cell_id, i_node) =
              tensorProduct(rho[cell_id] * c[cell_id] * Cjr_n(cell_id, i_node), njr_n(cell_id, i_node));
          }
        });
      break;
    }
    case SolverType::Eucclhyd: {
      if constexpr (Dimension > 1) {
        const NodeValuePerFace<const Rd> Nlr = mesh_data.Nlr();
        const NodeValuePerFace<const Rd> nlr = mesh_data.nlr();
        const auto& face_to_node_matrix      = m_mesh.connectivity().faceToNodeMatrix();
        const auto& cell_to_face_matrix      = m_mesh.connectivity().cellToFaceMatrix();

        // Contributions are accumulated face by face, so start from zero.
        parallel_for(
          Ajr.numberOfValues(), PUGS_LAMBDA(size_t i_value) { Ajr[i_value] = zero; });

        parallel_for(
          m_mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
            const auto& cell_nodes = cell_to_node_matrix[cell_id];
            const auto& cell_faces = cell_to_face_matrix[cell_id];

            const double rho_c = rho[cell_id] * c[cell_id];

            // Position of a node inside the current cell, or size_t max if
            // the node does not belong to it.
            auto local_node_number_in_cell = [&](NodeId node_number) {
              for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
                if (node_number == cell_nodes[i_node]) {
                  return i_node;
                }
              }
              return std::numeric_limits<size_t>::max();
            };

            for (size_t i_face = 0; i_face < cell_faces.size(); ++i_face) {
              const FaceId face_id   = cell_faces[i_face];
              const auto& face_nodes = face_to_node_matrix[face_id];

              for (size_t i_face_node = 0; i_face_node < face_nodes.size(); ++i_face_node) {
                const size_t i_cell_node = local_node_number_in_cell(face_nodes[i_face_node]);
                Ajr(cell_id, i_cell_node) +=
                  tensorProduct(rho_c * Nlr(face_id, i_face_node), nlr(face_id, i_face_node));
              }
            }
          });

        break;
      } else {
        throw NotImplementedError("Eucclhyd switch is not implemented yet in 1d");
      }
    }
    }

    return Ajr;
  }

  // Assembles the nodal matrices Ar from the corner matrices Ajr and then
  // modifies them on nodes carrying boundary conditions.
  //
  // Step 1: Ar[r] = sum of Ajr(j, r) over the cells j connected to node r.
  // Step 2: boundary conditions are applied in three successive passes
  //         (velocity, then axis, then symmetry). A node already handled by an
  //         earlier pass is skipped by later ones (has_boundary_condition), so
  //         the order of the passes defines the priority between conditions.
  NodeValue<const Rdxd>
  _computeAr(const NodeValuePerCell<const Rdxd>& Ajr) const
  {
    const auto& node_to_cell_matrix               = m_mesh.connectivity().nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

    NodeValue<Rdxd> Ar{m_mesh.connectivity()};

    // Sum the corner matrices of all cells surrounding each node.
    parallel_for(
      m_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
        Rdxd sum                                   = zero;
        const auto& node_to_cell                   = node_to_cell_matrix[node_id];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

        for (size_t i_cell = 0; i_cell < node_to_cell.size(); ++i_cell) {
          // local index of this node inside the i_cell-th neighbouring cell
          const unsigned int i_node = node_local_number_in_its_cells[i_cell];
          const CellId cell_id      = node_to_cell[i_cell];
          sum += Ajr(cell_id, i_node);
        }
        Ar[node_id] = sum;
      });

    // Marks nodes whose Ar has already been overwritten by a boundary condition.
    NodeValue<bool> has_boundary_condition{m_mesh.connectivity()};
    has_boundary_condition.fill(false);

    // velocity bc: Ar is replaced by the identity on these nodes
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list = bc.nodeList();

            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];

                if (not has_boundary_condition[node_id]) {
                  Ar[node_id] = identity;

                  has_boundary_condition[node_id] = true;
                }
              });
          }
        },
        boundary_condition);
    }

    if constexpr (Dimension > 1) {
      // axis bc: with t the axis direction, Ar <- (t x t) Ar (t x t) + (I - t x t)
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<AxisBoundaryCondition, T>) {
              const Rd& t = bc.direction();

              const Rdxd I   = identity;
              const Rdxd txt = tensorProduct(t, t);

              const Array<const NodeId>& node_list = bc.nodeList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];
                if (not has_boundary_condition[node_id]) {
                  Ar[node_id] = txt * Ar[node_id] * txt + (I - txt);

                  has_boundary_condition[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
    }

    // symmetry bc: with n the outgoing normal and P = I - n x n,
    // Ar <- P Ar P + n x n
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<SymmetryBoundaryCondition, T>) {
            const Rd& n = bc.outgoingNormal();

            const Rdxd I   = identity;
            const Rdxd nxn = tensorProduct(n, n);
            const Rdxd P   = I - nxn;

            const Array<const NodeId>& node_list = bc.nodeList();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];
              if (not has_boundary_condition[node_id]) {
                Ar[node_id] = P * Ar[node_id] * P + nxn;

                has_boundary_condition[node_id] = true;
              }
            }
          }
        },
        boundary_condition);
    }

    return Ar;
  }

  // Returns the node-wise inverse of the nodal matrices Ar.
  NodeValue<const Rdxd>
  _computeInvAr(const NodeValue<const Rdxd>& Ar) const
  {
    NodeValue<Rdxd> inverted{m_mesh.connectivity()};

    NodeId r = 0;
    while (r < m_mesh.numberOfNodes()) {
      inverted[r] = inverse(Ar[r]);
      ++r;
    }

    return inverted;
  }

  // Row/column index of the pressure unknown of cell j in the global system:
  // pressure unknowns come first and are numbered by implicit cell index.
  int
  mapP(CellId j) const
  {
    const int pressure_index = m_implicit_cell_index[j];
    return pressure_index;
  }

  // Row/column index of the k-th velocity component of cell j in the global
  // system: velocity unknowns are stored after all pressure unknowns, grouped
  // per cell with Dimension consecutive entries.
  int
  mapU(size_t k, CellId j) const
  {
    // All pressure unknowns come first, hence the offset.
    return m_number_of_implicit_cells + k + m_implicit_cell_index[j] * Dimension;
  }

  // Assembles the sparse matrix A coupling pressure and velocity unknowns of
  // the implicit cells.
  //
  // The system has (Dimension + 1) * m_number_of_implicit_cells rows; row and
  // column indices are given by mapP (pressure) and mapU (velocity components).
  // For each node r and each pair of implicit cells (i, j) around it, the
  // vector Ajr(i, r) * inv_Ar[r] * Djr(j, r) is added to A(p_j, u_i) and
  // subtracted from A(u_i, p_j), component by component.
  // Interior (non-boundary) nodes are handled first, then nodes carrying axis
  // and symmetry boundary conditions with a projected inv_Ar.
  //
  // Side effects: increments the global count_getA and drives the global
  // timer get_A_t.
  CRSMatrixDescriptor<double>
  _getA() const
  {
    count_getA++;
    // On the very first call the timer is stopped before being started.
    // NOTE(review): presumably this resets a timer that started running at
    // construction -- confirm against Timer's semantics.
    static bool is_get_A_started = false;
    if (not is_get_A_started) {
      is_get_A_started = true;
      get_A_t.stop();
    }
    get_A_t.start();

    // Number of entries reserved for each row of the matrix.
    Array<int> non_zeros{(Dimension + 1) * m_number_of_implicit_cells};

    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();

    auto is_boundary_node = m_mesh.connectivity().isBoundaryNode();

    non_zeros.fill(0);
    for (CellId cell_j_id = 0; cell_j_id < m_mesh.numberOfCells(); ++cell_j_id) {
      if (m_is_implicit_cell[cell_j_id]) {
        const auto& cell_j_nodes = cell_to_node_matrix[cell_j_id];
        // Implicit indices of the implicit cells sharing a non-boundary node
        // with cell_j_id (duplicates are removed below).
        // NOTE(review): neighbours reached only through boundary nodes are not
        // counted here although the boundary-condition passes below add
        // entries for such nodes -- confirm the reserved sizes are sufficient.
        std::vector<size_t> nb_neighbors;
        nb_neighbors.reserve(30);

        for (size_t id_node = 0; id_node < cell_j_nodes.size(); ++id_node) {
          if (not is_boundary_node[cell_j_nodes[id_node]]) {
            NodeId node_id        = cell_j_nodes[id_node];
            const auto& node_cell = node_to_cell_matrix[node_id];
            for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
              CellId id_cell_i = node_cell[cell_i];
              if (m_is_implicit_cell[id_cell_i]) {
                nb_neighbors.push_back(m_implicit_cell_index[id_cell_i]);
              }
            }
          }
        }
        // sort + unique: keep each neighbouring cell once
        std::sort(nb_neighbors.begin(), nb_neighbors.end());
        auto last = std::unique(nb_neighbors.begin(), nb_neighbors.end());
        nb_neighbors.resize(std::distance(nb_neighbors.begin(), last));

        // Pressure row: Dimension velocity columns per neighbour, plus one.
        size_t line_index_p     = mapP(cell_j_id);
        non_zeros[line_index_p] = Dimension * nb_neighbors.size() + 1;

        // Velocity rows: one pressure column per neighbour, plus one.
        for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
          size_t line_index_u     = mapU(i_dimension, cell_j_id);
          non_zeros[line_index_u] = nb_neighbors.size() + 1;
        }

        nb_neighbors.clear();
      }
    }

    CRSMatrixDescriptor A{(Dimension + 1) * m_number_of_implicit_cells, non_zeros};
    const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

    // Interior nodes: couple every pair of implicit cells sharing the node.
    for (NodeId node_id = 0; node_id < m_mesh.numberOfNodes(); ++node_id) {
      if (not is_boundary_node[node_id]) {
        const auto& node_cell                      = node_to_cell_matrix[node_id];
        const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

        for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
          const CellId id_cell_i = node_cell[cell_i];
          if (m_is_implicit_cell[id_cell_i]) {
            const size_t node_nb_in_i = node_local_number_in_its_cells[cell_i];
            for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
              const CellId id_cell_j = node_cell[cell_j];
              if (m_is_implicit_cell[id_cell_j]) {
                const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];
                const int j_index_p       = mapP(id_cell_j);

                const Rd Bji = m_Ajr(id_cell_i, node_nb_in_i) * m_inv_Ar[node_id] * m_Djr(id_cell_j, node_nb_in_j);
                for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
                  const int i_index_u = mapU(i_dimension, id_cell_i);

                  // opposite signs on the (p, u) and (u, p) blocks
                  A(j_index_p, i_index_u) += Bji[i_dimension];
                  A(i_index_u, j_index_p) -= Bji[i_dimension];
                }
              }
            }
          }
        }
      }
    }

    // pressure bc
    // The whole treatment is commented out: pressure boundary conditions
    // currently contribute nothing to A.
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
            // const auto& face_list = bc.faceList();

            // const auto& face_local_numbers_in_their_cells = m_mesh.connectivity().faceLocalNumbersInTheirCells();

            // const auto& face_to_cell_matrix = m_mesh.connectivity().faceToCellMatrix();

            // for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
            //   for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
            //     const FaceId& face_id = face_list[i_face];
            //     // if (m_is_implicit_node[node_id]) {
            //     const auto& face_cell = face_to_cell_matrix[face_id];
            //     for (size_t cell_i = 0; cell_i < face_cell.size(); ++cell_i) {
            //       const CellId id_cell_i = face_cell[cell_i];
            //       if (m_is_implicit_cell[id_cell_i]) {
            //         // const auto& face_local_number_in_its_cells =
            //         // face_local_numbers_in_their_cells.itemArray(face_id); const size_t node_nb_in_i =
            //         // face_local_number_in_its_cells[cell_i];
            //         int index_p = mapP(id_cell_i);
            //         int index_u = mapU(i_dimension, id_cell_i);

            //         // const Rd coef = m_Ajr(id_cell_i, face_local_number_in_its_cells[node_nb_in_i]) *
            //         // m_inv_Ar[face_id] *
            //         //                 m_Djr(id_cell_i, face_local_number_in_its_cells[node_nb_in_i]);

            //         const Rd coef = zero;
            //         A(index_p, index_u) += coef[i_dimension];
            //         A(index_u, index_p) -= coef[i_dimension];
            //         //}
            //       }
            //     }
            //   }
            // }
          }
        },
        boundary_condition);
    }

    // Marks nodes already claimed by a boundary condition so that later
    // passes skip them (priority: velocity, then axis, then symmetry).
    NodeValue<bool> has_boundary_condition{m_mesh.connectivity()};
    has_boundary_condition.fill(false);

    // velocity bc: nodes are only flagged, no entry is added to A for them
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list = bc.nodeList();

            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];
                if (not has_boundary_condition[node_id]) {
                  has_boundary_condition[node_id] = true;
                }
              });
          }
        },
        boundary_condition);
    }

    if constexpr (Dimension > 1) {
      // axis bc: same coupling as interior nodes, with inv_Ar replaced by
      // inv_Ar * (t x t), t being the axis direction
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<AxisBoundaryCondition, T>) {
              const auto& node_list = bc.nodeList();

              const Rd& t    = bc.direction();
              const Rdxd txt = tensorProduct(t, t);

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];
                if (not has_boundary_condition[node_id]) {
                  const Rdxd inverse_Ar_times_txt            = m_inv_Ar[node_id] * txt;
                  const auto& node_cell                      = node_to_cell_matrix[node_id];
                  const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

                  for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
                    for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
                      const CellId id_cell_i = node_cell[cell_i];
                      if (m_is_implicit_cell[id_cell_i]) {
                        const size_t node_nb_in_i = node_local_number_in_its_cells[cell_i];
                        const int i_index_u       = mapU(i_dimension, id_cell_i);
                        for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
                          const CellId id_cell_j = node_cell[cell_j];
                          if (m_is_implicit_cell[id_cell_j]) {
                            const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];
                            const int j_index_p       = mapP(id_cell_j);

                            const Rd coef =
                              m_Ajr(id_cell_i, node_nb_in_i) * inverse_Ar_times_txt * m_Djr(id_cell_j, node_nb_in_j);
                            A(j_index_p, i_index_u) += coef[i_dimension];
                            A(i_index_u, j_index_p) -= coef[i_dimension];
                          }
                        }
                      }
                    }
                  }
                  has_boundary_condition[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
    }

    // symmetry bc: same coupling as interior nodes, with inv_Ar replaced by
    // inv_Ar * P, where P = I - n x n and n is the outgoing normal
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<SymmetryBoundaryCondition, T>) {
            const auto& node_list = bc.nodeList();

            const Rd& n    = bc.outgoingNormal();
            const Rdxd I   = identity;
            const Rdxd nxn = tensorProduct(n, n);
            const Rdxd P   = I - nxn;

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];
              if (not has_boundary_condition[node_id]) {
                const Rdxd inverse_ArxP                    = m_inv_Ar[node_id] * P;
                const auto& node_cell                      = node_to_cell_matrix[node_id];
                const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

                for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
                  for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
                    const CellId id_cell_i = node_cell[cell_i];
                    if (m_is_implicit_cell[id_cell_i]) {
                      const size_t node_nb_in_i = node_local_number_in_its_cells[cell_i];
                      const int i_index_u       = mapU(i_dimension, id_cell_i);
                      for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
                        const CellId id_cell_j = node_cell[cell_j];
                        if (m_is_implicit_cell[id_cell_j]) {
                          const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];
                          const int j_index_p       = mapP(id_cell_j);

                          const Rd coef =
                            m_Ajr(id_cell_i, node_nb_in_i) * inverse_ArxP * m_Djr(id_cell_j, node_nb_in_j);
                          A(j_index_p, i_index_u) += coef[i_dimension];
                          A(i_index_u, j_index_p) -= coef[i_dimension];
                        }
                      }
                    }
                  }
                }
                has_boundary_condition[node_id] = true;
              }
            }
          }
        },
        boundary_condition);
    }
    get_A_t.pause();

    return A;
  }

  // Packs the state of the implicit cells into the unknown vector of the
  // linear system: entry mapP(j) receives -p[j] and entries mapU(k, j)
  // receive the velocity components u[j][k].
  Vector<double>
  _getU(const CellValue<const double>& p, const CellValue<const Rd>& u)
  {
    Vector<double> packed_state{(Dimension + 1) * m_number_of_implicit_cells};

    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
        if (not m_is_implicit_cell[cell_id]) {
          return;
        }

        // pressure unknown (stored with opposite sign)
        const size_t pressure_row = mapP(cell_id);
        packed_state[pressure_row] = -p[cell_id];

        // velocity unknowns
        for (size_t component = 0; component < Dimension; ++component) {
          const size_t velocity_row = mapU(component, cell_id);
          packed_state[velocity_row] = u[cell_id][component];
        }
      });

    return packed_state;
  }

  // Evaluates the residual F(Uk) of the implicit system for the current
  // iterate Uk, given the state Un at the beginning of the time step.
  //
  //  - Un : packed state at time n (layout given by mapP / mapU),
  //  - Uk : packed current iterate (same layout; pressure stored as -p),
  //  - pi : scalar field added to the iterate pressure before the power law,
  //  - dt : time step.
  //
  // For every implicit cell j:
  //   F[p_j]   = Mj / dt * (tau_iter_j - tau_j) - sum_r Djr . ur
  //   F[u_j,k] = Mj / dt * (Uk[u_j,k] - Un[u_j,k]) + (sum_r Fjr)_k
  // with tau_iter_j = g_1_exp_S_Cv_inv_g_j * (-Uk[p_j] + pi_j)^(-1/gamma_j).
  //
  // Side effect: m_tau_iter is replaced by the freshly computed tau_iter
  // (entries of non-implicit cells are left unset).
  Vector<double>
  _getF(const Vector<double>& Un, const Vector<double>& Uk, const DiscreteScalarFunction& pi, const double dt)
  {
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    NodeValue<const Rd> ur         = this->_getUr();
    NodeValuePerCell<const Rd> Fjr = this->_getFjr(ur);

    // Specific-volume iterate, computed once and reused for the residual
    // below instead of evaluating the same std::pow expression twice.
    CellValue<double> tau_iter{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
        if (m_is_implicit_cell[j]) {
          const size_t k = mapP(j);
          tau_iter[j]    = m_g_1_exp_S_Cv_inv_g[j] * std::pow(-Uk[k] + pi[j], -m_inv_gamma[j]);
        }
      });
    m_tau_iter = tau_iter;

    // sum over the nodes of each cell of Djr . ur
    CellValue<double> volume_fluxes_sum{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
        double sum             = 0;
        const auto& cell_nodes = cell_to_node_matrix[cell_id];
        for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
          const NodeId node_id = cell_nodes[i_node];
          sum += dot(m_Djr(cell_id, i_node), ur[node_id]);
        }
        volume_fluxes_sum[cell_id] = sum;
      });

    // sum over the nodes of each cell of Fjr
    CellValue<Rd> momentum_fluxes_sum{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
        Rd sum                 = zero;
        const auto& cell_nodes = cell_to_node_matrix[cell_id];
        for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
          sum += Fjr(cell_id, i_node);
        }
        momentum_fluxes_sum[cell_id] = sum;
      });

    Vector<double> F{Un.size()};
    F.fill(0);

    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
        if (m_is_implicit_cell[cell_id]) {
          const size_t pj_index = mapP(cell_id);
          F[pj_index] = m_Mj[cell_id] / dt * (tau_iter[cell_id] - m_tau[cell_id]) - volume_fluxes_sum[cell_id];

          for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
            const size_t uj_i_index = mapU(i_dimension, cell_id);

            F[uj_i_index] =
              m_Mj[cell_id] / dt * (Uk[uj_i_index] - Un[uj_i_index]) + momentum_fluxes_sum[cell_id][i_dimension];
          }
        }
      });

    return F;
  }

  CRSMatrixDescriptor<double>
  _getHessianJ(const double dt, const DiscreteScalarFunction& pi, const Vector<double>& Uk)
  {
    auto is_boundary_node = m_mesh.connectivity().isBoundaryNode();
    Array<int> non_zeros{(Dimension + 1) * m_number_of_implicit_cells};

    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();

    non_zeros.fill(0);
    for (CellId cell_id = 0; cell_id < m_number_of_implicit_cells; ++cell_id) {
      if (m_is_implicit_cell[cell_id]) {
        const auto& cell_node = cell_to_node_matrix[cell_id];
        std::vector<size_t> nb_neighboors;
        nb_neighboors.reserve(30);

        for (size_t id_node = 0; id_node < cell_node.size(); ++id_node) {
          if (not is_boundary_node[cell_node[id_node]]) {
            NodeId node_id        = cell_node[id_node];
            const auto& node_cell = node_to_cell_matrix[node_id];
            for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
              CellId id_cell_i = node_cell[cell_i];
              if (m_is_implicit_cell[id_cell_i]) {
                nb_neighboors.push_back(m_implicit_cell_index[id_cell_i]);
              }
            }
          }
        }

        std::sort(nb_neighboors.begin(), nb_neighboors.end());
        auto last = std::unique(nb_neighboors.begin(), nb_neighboors.end());
        nb_neighboors.resize(std::distance(nb_neighboors.begin(), last));

        size_t line_index_p     = mapP(cell_id);
        non_zeros[line_index_p] = nb_neighboors.size();

        for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
          size_t line_index_u     = mapU(i_dimension, cell_id);
          non_zeros[line_index_u] = Dimension * nb_neighboors.size();
        }

        nb_neighboors.clear();
      }
    }

    CRSMatrixDescriptor Hess_J{(Dimension + 1) * m_number_of_implicit_cells,
                               (Dimension + 1) * m_number_of_implicit_cells, non_zeros};

    const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

    count_getGradF++;
    static bool getGradF_is_started = false;
    if (not getGradF_is_started) {
      getGradF_is_started = true;
      getGradF_t.stop();
    }

    const double inv_dt = 1 / dt;
    for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
      if (m_is_implicit_cell[cell_id]) {
        const int i_index_p = mapP(cell_id);
        Hess_J(i_index_p, i_index_p) =
          (inv_dt * m_Mj[cell_id]) * m_inv_gamma[cell_id] * m_tau_iter[cell_id] / (-Uk[i_index_p] + pi[cell_id]);

        for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
          const int i_index_u          = mapU(i_dimension, cell_id);
          Hess_J(i_index_u, i_index_u) = inv_dt * m_Mj[cell_id];
        }
      }
    }
    getGradF_t.start();

    for (NodeId node_id = 0; node_id < m_mesh.numberOfNodes(); ++node_id) {
      const auto& node_cell                      = node_to_cell_matrix[node_id];
      const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

      for (size_t j_cell = 0; j_cell < node_cell.size(); ++j_cell) {
        const CellId cell_id_j = node_cell[j_cell];
        if (m_is_implicit_cell[cell_id_j]) {
          const size_t node_nb_in_j_cell = node_local_number_in_its_cells[j_cell];

          const auto Ajr = m_Ajr(cell_id_j, node_nb_in_j_cell);

          for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
            const int i_index_u = mapU(i_dimension, cell_id_j);
            for (size_t j_dimension = 0; j_dimension < Dimension; ++j_dimension) {
              const int j_index_u = mapU(j_dimension, cell_id_j);
              Hess_J(i_index_u, j_index_u) += Ajr(i_dimension, j_dimension);
            }
          }
        }
      }

      if (not is_boundary_node[node_id]) {
        const auto& inv_Ar = m_inv_Ar[node_id];

        for (size_t j_cell = 0; j_cell < node_cell.size(); ++j_cell) {
          const CellId cell_id_j = node_cell[j_cell];
          if (m_is_implicit_cell[cell_id_j]) {
            const size_t node_nb_in_j_cell = node_local_number_in_its_cells[j_cell];
            const int j_index_p            = mapP(cell_id_j);

            const auto invArDjr = inv_Ar * m_Djr(cell_id_j, node_nb_in_j_cell);

            for (size_t i_cell = 0; i_cell < node_cell.size(); ++i_cell) {
              const CellId cell_id_i = node_cell[i_cell];
              if (m_is_implicit_cell[cell_id_i]) {
                const size_t node_nb_in_cell_i = node_local_number_in_its_cells[i_cell];
                const int i_index_p            = mapP(cell_id_i);
                Hess_J(i_index_p, j_index_p) += dot(m_Djr(cell_id_i, node_nb_in_cell_i), invArDjr);
              }
            }
          }
        }

        for (size_t j_cell = 0; j_cell < node_cell.size(); ++j_cell) {
          const CellId cell_id_j = node_cell[j_cell];
          if (m_is_implicit_cell[cell_id_j]) {
            const size_t node_nb_in_j_cell = node_local_number_in_its_cells[j_cell];

            const auto Ajr_invAr = m_Ajr(cell_id_j, node_nb_in_j_cell) * inv_Ar;

            for (size_t i_cell = 0; i_cell < node_cell.size(); ++i_cell) {
              const CellId cell_id_i = node_cell[i_cell];
              if (m_is_implicit_cell[cell_id_i]) {
                const size_t node_nb_in_i_cell = node_local_number_in_its_cells[i_cell];

                const auto Ajr_invAr_Air = Ajr_invAr * m_Ajr(cell_id_i, node_nb_in_i_cell);

                for (size_t j_dimension = 0; j_dimension < Dimension; ++j_dimension) {
                  const int index_u_j = mapU(j_dimension, cell_id_j);

                  for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
                    const int index_u_i = mapU(i_dimension, cell_id_i);
                    Hess_J(index_u_j, index_u_i) -= Ajr_invAr_Air(j_dimension, i_dimension);
                  }
                }
              }
            }
          }
        }
      }
    }
    getGradF_t.pause();

    // pressure bc
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
            //   const auto& node_list = bc.faceList();

            //   const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();
            //   const auto& node_to_cell_matrix               = m_mesh.connectivity().nodeToCellMatrix();

            //   for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
            //     const NodeId& node_id = node_list[i_node];
            //     // if (m_is_implicit_node[node_id]) {
            //     const auto& node_cell = node_to_cell_matrix[node_id];

            //     for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
            //       const CellId id_cell_i = node_cell[cell_i];
            //       if (m_is_implicit_cell[id_cell_i]) {
            //         const auto& node_local_number_in_its_cells =
            //         node_local_numbers_in_their_cells.itemArray(node_id);

            //         const size_t node_nb_in_i = node_local_number_in_its_cells[cell_i];
            //         const int i_index_p       = mapP(id_cell_i);

            //         for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
            //           const CellId id_cell_j = node_cell[cell_j];
            //           if (m_is_implicit_cell[id_cell_j]) {
            //             const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];
            //             const int j_index_p       = mapP(id_cell_j);

            //             Hess_J(i_index_p, j_index_p) +=
            //               dot(m_Djr(id_cell_i, node_nb_in_i), m_inv_Ar[node_id] * m_Djr(id_cell_j, node_nb_in_j));
            //           }
            //         }
            //       }
            //     }
            //     for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
            //       const CellId id_cell_i = node_cell[cell_i];
            //       if (m_is_implicit_cell[id_cell_i]) {
            //         const auto& node_local_number_in_its_cells =
            //         node_local_numbers_in_their_cells.itemArray(node_id);

            //         const size_t node_nb_in_i = node_local_number_in_its_cells[cell_i];
            //         for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
            //           const int i_index_u = mapU(i_dimension, id_cell_i);

            //           Hess_J(i_index_u, i_index_u) += m_Ajr(id_cell_i, node_nb_in_i)(i_dimension, i_dimension);

            //           for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
            //             const CellId id_cell_j = node_cell[cell_j];
            //             if (m_is_implicit_cell[id_cell_j]) {
            //               const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];
            //               for (size_t j_dimension = 0; j_dimension < Dimension; ++j_dimension) {
            //                 const int j_index_u = mapU(j_dimension, id_cell_j);

            //                 Hess_J(i_index_u, j_index_u) += -(m_Ajr(id_cell_i, node_nb_in_i) * m_inv_Ar[node_id] *
            //                                                   m_Ajr(id_cell_j, node_nb_in_j))(i_dimension,
            //                                                   j_dimension);
            //               }
            //             }
            //           }
            //         }
            //       }
            //     }
            //   }
          }
        },
        boundary_condition);
    }

    NodeValue<bool> has_boundary_condition{m_mesh.connectivity()};
    has_boundary_condition.fill(false);

    // velocity bc
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list = bc.nodeList();
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId node_id            = node_list[i_node];
              has_boundary_condition[node_id] = true;
            }
          }
        },
        boundary_condition);
    }

    if constexpr (Dimension > 1) {
      // axis bc
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<AxisBoundaryCondition, T>) {
              const auto& node_list = bc.nodeList();

              const Rd& t    = bc.direction();
              const Rdxd txt = tensorProduct(t, t);

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];

                if (not has_boundary_condition[node_id]) {
                  const Rdxd inverse_Ar_times_txt = m_inv_Ar[node_id] * txt;

                  const auto& node_cell = node_to_cell_matrix[node_id];

                  for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
                    const CellId id_cell_i = node_cell[cell_i];
                    if (m_is_implicit_cell[id_cell_i]) {
                      const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
                      const size_t node_nb_in_i                  = node_local_number_in_its_cells[cell_i];

                      for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
                        const int i_index_u = mapU(i_dimension, id_cell_i);

                        for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
                          const CellId id_cell_j = node_cell[cell_j];
                          if (m_is_implicit_cell[id_cell_j]) {
                            const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];

                            for (size_t j_dimension = 0; j_dimension < Dimension; ++j_dimension) {
                              const int j_index_u = mapU(j_dimension, id_cell_j);
                              Hess_J(i_index_u, j_index_u) +=
                                (-m_Ajr(id_cell_i, node_nb_in_i) * inverse_Ar_times_txt *
                                 m_Ajr(id_cell_j, node_nb_in_j))(i_dimension, j_dimension);
                            }
                          }
                        }
                      }
                    }
                  }

                  for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
                    const CellId id_cell_i = node_cell[cell_i];
                    if (m_is_implicit_cell[id_cell_i]) {
                      const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
                      const size_t node_nb_in_i                  = node_local_number_in_its_cells[cell_i];

                      const int i_index_p = mapP(id_cell_i);

                      for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
                        const CellId id_cell_j = node_cell[cell_j];
                        if (m_is_implicit_cell[id_cell_j]) {
                          const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];
                          const int j_index_p       = mapP(id_cell_j);
                          Hess_J(i_index_p, j_index_p) +=
                            dot(m_Djr(id_cell_i, node_nb_in_i), inverse_Ar_times_txt * m_Djr(id_cell_j, node_nb_in_j));
                        }
                      }
                    }
                  }
                  has_boundary_condition[node_id] = true;
                }
              }
            }
          },
          boundary_condition);
      }
    }

    // symmetry bc
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<SymmetryBoundaryCondition, T>) {
            const auto& node_list = bc.nodeList();

            const Rd& n    = bc.outgoingNormal();
            const Rdxd I   = identity;
            const Rdxd nxn = tensorProduct(n, n);
            const Rdxd P   = I - nxn;

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];

              if (not has_boundary_condition[node_id]) {
                const Rdxd inverse_ArxP = m_inv_Ar[node_id] * P;

                const auto& node_cell = node_to_cell_matrix[node_id];

                for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
                  const CellId id_cell_i = node_cell[cell_i];
                  if (m_is_implicit_cell[id_cell_i]) {
                    const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
                    const size_t node_nb_in_i                  = node_local_number_in_its_cells[cell_i];

                    for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
                      const int i_index_u = mapU(i_dimension, id_cell_i);

                      for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
                        const CellId id_cell_j = node_cell[cell_j];
                        if (m_is_implicit_cell[id_cell_j]) {
                          const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];

                          for (size_t j_dimension = 0; j_dimension < Dimension; ++j_dimension) {
                            const int j_index_u = mapU(j_dimension, id_cell_j);
                            Hess_J(i_index_u, j_index_u) += (-m_Ajr(id_cell_i, node_nb_in_i) * inverse_ArxP *
                                                             m_Ajr(id_cell_j, node_nb_in_j))(i_dimension, j_dimension);
                          }
                        }
                      }
                    }
                  }
                }

                for (size_t cell_i = 0; cell_i < node_cell.size(); ++cell_i) {
                  const CellId id_cell_i = node_cell[cell_i];
                  if (m_is_implicit_cell[id_cell_i]) {
                    const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
                    const size_t node_nb_in_i                  = node_local_number_in_its_cells[cell_i];

                    const int i_index_p = mapP(id_cell_i);

                    for (size_t cell_j = 0; cell_j < node_cell.size(); ++cell_j) {
                      const CellId id_cell_j = node_cell[cell_j];
                      if (m_is_implicit_cell[id_cell_j]) {
                        const size_t node_nb_in_j = node_local_number_in_its_cells[cell_j];
                        const int j_index_p       = mapP(id_cell_j);
                        Hess_J(i_index_p, j_index_p) +=
                          dot(m_Djr(id_cell_i, node_nb_in_i), inverse_ArxP * m_Djr(id_cell_j, node_nb_in_j));
                      }
                    }
                  }
                }
                has_boundary_condition[node_id] = true;
              }
            }
          }
        },
        boundary_condition);
    }
    return Hess_J;
  }

  // Builds the matrix Hess_J - A.
  //
  // The Hessian descriptor returned by _getHessianJ(dt, pi, Uk) is converted
  // to CRS storage and A is subtracted from it. The time spent here is
  // accumulated in the file-scope timer HJ_A_t.
  //
  // A:  matrix subtracted from the Hessian
  // Uk: vector forwarded to _getHessianJ
  // pi: discrete scalar function forwarded to _getHessianJ
  // dt: time step forwarded to _getHessianJ
  //
  // Returns the CRS matrix Hess_J - A.
  CRSMatrix<double>
  _getGradientF(const CRSMatrix<double, int>& A,
                const Vector<double>& Uk,
                const DiscreteScalarFunction& pi,
                const double dt)
  {
    // First call only: the timer is stopped once before being started below.
    static bool timer_was_reset = false;
    if (not timer_was_reset) {
      HJ_A_t.stop();
      timer_was_reset = true;
    }

    HJ_A_t.start();

    CRSMatrixDescriptor<double> hessian_descriptor = this->_getHessianJ(dt, pi, Uk);
    CRSMatrix hessian                              = hessian_descriptor.getCRSMatrix();
    CRSMatrix hessian_minus_A                      = hessian - A;

    HJ_A_t.pause();

    return hessian_minus_A;
  }

  // Converts the boundary condition descriptors into the boundary condition
  // objects stored in the returned list, in the order of the descriptors.
  //
  // Handled descriptors:
  //  - axis: built from the mesh line node boundary (rejected in dimension 1)
  //  - symmetry: built from the mesh flat node boundary
  //  - dirichlet named "velocity": values interpolated at mesh->xr() on the
  //    boundary nodes
  //  - dirichlet named "pressure": values interpolated at mesh->xr() on the
  //    boundary nodes in dimension 1, at the mesh data xl() on the boundary
  //    faces otherwise
  //
  // Throws NormalError for an axis condition in dimension 1 and for every
  // other descriptor type or dirichlet name.
  BoundaryConditionList
  _getBCList(const std::shared_ptr<const MeshType>& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    // Single rejection path shared by all unsupported descriptors.
    const auto reject = [](const IBoundaryConditionDescriptor& descriptor) {
      std::ostringstream error_msg;
      error_msg << descriptor << " is an invalid boundary condition for acoustic solver";
      throw NormalError(error_msg.str());
    };

    BoundaryConditionList boundary_conditions;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::axis: {
        if constexpr (Dimension == 1) {
          throw NormalError("Axis boundary condition makes no sense in dimension 1");
        } else {
          const auto& axis_descriptor = dynamic_cast<const AxisBoundaryConditionDescriptor&>(*bc_descriptor);

          boundary_conditions.push_back(
            AxisBoundaryCondition{getMeshLineNodeBoundary(*mesh, axis_descriptor.boundaryDescriptor())});
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::symmetry: {
        const auto& symmetry_descriptor = dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);

        boundary_conditions.push_back(
          SymmetryBoundaryCondition{getMeshFlatNodeBoundary(*mesh, symmetry_descriptor.boundaryDescriptor())});
        break;
      }
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const auto& dirichlet_descriptor = dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);

        if (dirichlet_descriptor.name() == "velocity") {
          MeshNodeBoundary node_boundary = getMeshNodeBoundary(*mesh, dirichlet_descriptor.boundaryDescriptor());

          using VelocityInterpolator = InterpolateItemValue<TinyVector<Dimension>(TinyVector<Dimension>)>;

          Array<const TinyVector<Dimension>> velocity_values =
            VelocityInterpolator::template interpolate<ItemType::node>(dirichlet_descriptor.rhsSymbolId(), mesh->xr(),
                                                                       node_boundary.nodeList());
          boundary_conditions.push_back(VelocityBoundaryCondition{node_boundary.nodeList(), velocity_values});
        } else if (dirichlet_descriptor.name() == "pressure") {
          using PressureInterpolator = InterpolateItemValue<double(Rd)>;

          const FunctionSymbolId pressure_id = dirichlet_descriptor.rhsSymbolId();

          if constexpr (Dimension == 1) {
            // Dimension 1: the pressure is prescribed on boundary nodes.
            MeshNodeBoundary node_boundary = getMeshNodeBoundary(*mesh, bc_descriptor->boundaryDescriptor());

            Array<const double> node_values =
              PressureInterpolator::template interpolate<ItemType::node>(pressure_id, mesh->xr(),
                                                                         node_boundary.nodeList());

            boundary_conditions.emplace_back(PressureBoundaryCondition{node_boundary.nodeList(), node_values});
          } else {
            // Other dimensions: the pressure is prescribed on boundary faces.
            constexpr ItemType FaceType = (Dimension > 1) ? ItemType::face : ItemType::node;

            MeshFaceBoundary face_boundary = getMeshFaceBoundary(*mesh, bc_descriptor->boundaryDescriptor());

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

            Array<const double> face_values =
              PressureInterpolator::template interpolate<FaceType>(pressure_id, mesh_data.xl(),
                                                                   face_boundary.faceList());
            boundary_conditions.emplace_back(PressureBoundaryCondition{face_boundary.faceList(), face_values});
          }
        } else {
          reject(*bc_descriptor);
        }
        break;
      }
      default: {
        reject(*bc_descriptor);
      }
      }
    }

    return boundary_conditions;
  }

  void
  _computeGradJU_AU(const DiscreteVectorFunction& u,
                    const DiscreteScalarFunction& p,
                    const DiscreteScalarFunction& pi,
                    const double& dt)
  {
    // const auto node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();
    const auto cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    m_number_of_implicit_cells = 0;
    m_implicit_cell_index      = [&] {
      CellValue<int> implicit_cell_index(m_mesh.connectivity());
      implicit_cell_index.fill(-1);

      for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
        if (m_is_implicit_cell[cell_id]) {
          if (implicit_cell_index[cell_id] != -1) {
            throw NormalError("implicit cell has already an implicit number");
          }
          implicit_cell_index[cell_id] = m_number_of_implicit_cells;
          m_number_of_implicit_cells++;
        }
      }
      return implicit_cell_index;
    }();

    m_is_implicit_node = [&] {
      NodeValue<bool> is_implicit_node(m_mesh.connectivity());
      is_implicit_node.fill(false);

      for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
        if (m_is_implicit_cell[cell_id]) {
          const auto& cell_nodes = cell_to_node_matrix[cell_id];
          for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
            is_implicit_node[cell_nodes[i_node]] = true;
          }
        }
      }

      return is_implicit_node;
    }();

    m_number_of_implicit_nodes = 0;
    m_implicit_node_index      = [&] {
      NodeValue<int> implicit_node_index(m_mesh.connectivity());
      implicit_node_index.fill(-1);

      for (NodeId node_id = 0; node_id < m_mesh.numberOfNodes(); ++node_id) {
        if (m_is_implicit_node[node_id]) {
          if (implicit_node_index[node_id] != -1) {
            throw NormalError("implicit node has already an implicit number");
          }
          implicit_node_index[node_id] = m_number_of_implicit_nodes;
          m_number_of_implicit_nodes++;
        }
      }
      return implicit_node_index;
    }();

    std::cout << "building A: " << std::flush;
    const CRSMatrix A = this->_getA().getCRSMatrix();
    std::cout << "done\n" << std::flush;

    Vector<double> Un = this->_getU(p.cellValues(), u.cellValues());
    Vector<double> Uk = this->_getU(m_predicted_p, m_predicted_u);

    int nb_iter = 0;
    double norm_inf_sol;

    Array<const double> abs_Un = [&]() {
      Array<double> compute_abs_Un{Un.size()};
      parallel_for(
        Un.size(), PUGS_LAMBDA(size_t i) { compute_abs_Un[i] = std::abs(Un[i]); });
      return compute_abs_Un;
    }();

    double norm_inf_Un = max(abs_Un);
    if (m_number_of_implicit_cells > 0) {
      do {
        nb_iter++;

        Vector<double> f = this->_getF(Un, Uk, pi, dt);

        std::cout << "building gradf: " << std::flush;
        CRSMatrix<double> gradient_f = this->_getGradientF(A, Uk, pi, dt);
        std::cout << "done\n" << std::flush;

        auto l2Norm = [](const Vector<double>& x) {
          double sum2 = 0;
          for (size_t i = 0; i < x.size(); ++i) {
            sum2 += x[i] * x[i];
          }
          return std::sqrt(sum2);
        };

        Vector<double> sol{Uk.size()};
        sol.fill(0);

        static bool solver_is_init = false;
        if (not solver_is_init) {
          solver_t.stop();
          solver_is_init = true;
        }

        solver_t.start();
        LinearSolver solver;
        std::cout << "solving linear system: " << std::flush;
        solver.solveLocalSystem(gradient_f, sol, f);
        std::cout << "done\n" << std::flush;
        solver_t.pause();

        std::cout << rang::style::bold << "norm resid = " << l2Norm(f - gradient_f * sol) << rang::style::reset << '\n';
        double theta = 1;

        for (CellId cell_id = 0; cell_id < m_number_of_implicit_cells; ++cell_id) {
          size_t k = mapP(cell_id);
          if (-Uk[k] + theta * sol[k] < -pi[cell_id]) {
            double new_theta = 0.5 * (Uk[k] - pi[cell_id]) / sol[k];
            std::cout << "theta: " << theta << " -> " << new_theta << '\n';
            theta = new_theta;
          }
        }

        Vector<double> U_next = Uk - theta * sol;

        Array<const double> abs_sol = [&]() {
          Array<double> compute_abs_sol{sol.size()};
          parallel_for(
            sol.size(), PUGS_LAMBDA(size_t i) { compute_abs_sol[i] = std::abs(sol[i]); });
          return compute_abs_sol;
        }();

        norm_inf_sol = max(abs_sol);

        std::cout << "nb_iter= " << nb_iter << " | ratio Newton = " << norm_inf_sol / norm_inf_Un << "\n";

        // when it is a hard case we relax newton
        size_t neg_pressure_count = 0;
        double min_pressure       = 0;
        for (CellId cell_id = 0; cell_id < m_number_of_implicit_cells; ++cell_id) {
          size_t k = mapP(cell_id);
          if (-U_next[k] + pi[cell_id] < 0) {
            std::cout << " neg p: cell_id=" << cell_id << '\n';
            ++neg_pressure_count;
            min_pressure = std::min(min_pressure, -U_next[k] + pi[cell_id]);
            U_next[k]    = Uk[k];
          }
        }

        Uk = U_next;

        // if ((count_newton == 0) and (count_Djr == 0) and (count_timestep == 0)) {
        //   auto newt_mesh     = std::make_shared<MeshType>(m_mesh.shared_connectivity(), m_mesh.xr());
        //   double pseudo_time = 1E-8 * count_newton + 1E-3 * count_Djr + count_timestep;

        //   std::vector<std::shared_ptr<const INamedDiscreteData>> output_list;
        //   output_list.push_back(
        //     std::make_shared<NamedItemValueVariant>(std::make_shared<ItemValueVariant>(m_predicted_p),
        //     "predicted_p"));

        //   vtk_writer.writeOnMeshForced(newt_mesh, output_list, pseudo_time);
        // }

        m_predicted_u = [&] {
          CellValue<Rd> predicted_u = copy(u.cellValues());
          for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
            if (m_is_implicit_cell[cell_id]) {
              Rd vector_u = zero;
              for (size_t i_dimension = 0; i_dimension < Dimension; ++i_dimension) {
                vector_u[i_dimension] = Uk[mapU(i_dimension, cell_id)];
              }
              predicted_u[cell_id] = vector_u;
            }
          }
          return predicted_u;
        }();

        m_predicted_p = [&] {
          CellValue<double> predicted_p = copy(p.cellValues());
          for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
            if (m_is_implicit_cell[cell_id]) {
              predicted_p[cell_id] = -Uk[mapP(cell_id)];
            }
          }
          return predicted_p;
        }();

        NodeValue<const Rd> ur = _getUr();

        NodeValue<Rd> newton_xr{m_mesh.connectivity()};
        for (NodeId node_id = 0; node_id < m_mesh.numberOfNodes(); ++node_id) {
          newton_xr[node_id] = m_mesh.xr()[node_id] + dt * ur[node_id];
        }

        auto newt_mesh = std::make_shared<MeshType>(m_mesh.shared_connectivity(), newton_xr);
        ++count_newton;
        // double pseudo_time = 1E-4 * count_newton + 1E-2 * count_Djr + count_timestep;

        // std::vector<std::shared_ptr<const INamedDiscreteData>> output_list;
        // output_list.push_back(
        //   std::make_shared<NamedItemValueVariant>(std::make_shared<ItemValueVariant>(m_predicted_p), "predicted_p"));

        // vtk_writer.writeOnMeshForced(newt_mesh, output_list, pseudo_time);

        if (neg_pressure_count > 0) {
          std::cout << rang::fgB::magenta << "p est negatif sur " << neg_pressure_count
                    << " mailles min=" << min_pressure << rang::fg::reset << '\n';
        }
      } while ((norm_inf_sol > 1e-12 * norm_inf_Un) and (nb_iter < 1000));
    }

    for (CellId j = 0; j < m_mesh.numberOfCells(); ++j) {
      // faudrait mettre p+pi
      if (m_predicted_p[j] <= 0) {
        std::cout << "pression negative pour la maille" << j << '\n';
      }
    }
  }

  // Computes the node velocities from m_predicted_u and m_predicted_p.
  //
  // For each node, the contributions Ajr * u_j + p_j * Djr of the cells
  // sharing the node are summed. This right-hand side is then modified by the
  // boundary conditions, in this order: velocity (value imposed), pressure
  // (boundary term subtracted), axis and symmetry (projection, skipped on
  // nodes already handled by a previous velocity/axis/symmetry condition).
  // The result is multiplied by m_inv_Ar, and the axis, symmetry and velocity
  // conditions are applied once more to the resulting velocities.
  NodeValue<const Rd>
  _getUr() const
  {
    const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();
    const auto& node_to_cell_matrix               = m_mesh.connectivity().nodeToCellMatrix();

    NodeValue<Rd> rhs{m_mesh.connectivity()};

    // Sum of the cell contributions around each node.
    parallel_for(
      m_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
        const auto& local_numbers = node_local_numbers_in_their_cells.itemArray(node_id);
        const auto& node_cells    = node_to_cell_matrix[node_id];

        Rd node_rhs = zero;
        for (size_t i_cell = 0; i_cell < node_cells.size(); ++i_cell) {
          const CellId cell_id            = node_cells[i_cell];
          const unsigned int i_local_node = local_numbers[i_cell];
          node_rhs +=
            m_Ajr(cell_id, i_local_node) * m_predicted_u[cell_id] + m_predicted_p[cell_id] * m_Djr(cell_id, i_local_node);
        }

        rhs[node_id] = node_rhs;
      });

    // Marks the nodes already treated by a velocity, axis or symmetry
    // condition.
    NodeValue<bool> is_constrained{m_mesh.connectivity()};
    is_constrained.fill(false);

    // Velocity conditions: the right-hand side is replaced by the given value.
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, BCType>) {
            const auto& bc_values = bc.valueList();
            const auto& bc_nodes  = bc.nodeList();

            parallel_for(
              bc_nodes.size(), PUGS_LAMBDA(const size_t i_bc_node) {
                const NodeId node_id   = bc_nodes[i_bc_node];
                const auto& velocity   = bc_values[i_bc_node];
                rhs[node_id]           = velocity;
                is_constrained[node_id] = true;
              });
          }
        },
        boundary_condition);
    }

    // Pressure conditions: the boundary pressure term is subtracted.
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<PressureBoundaryCondition, BCType>) {
            MeshData<MeshType>& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);
            if constexpr (Dimension == 1) {
              const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();

              const auto& bc_values = bc.valueList();
              const auto& bc_nodes  = bc.faceList();
              parallel_for(
                bc_nodes.size(), PUGS_LAMBDA(size_t i_bc_node) {
                  const NodeId node_id   = bc_nodes[i_bc_node];
                  const auto& node_cells = node_to_cell_matrix[node_id];
                  Assert(node_cells.size() == 1);

                  const CellId cell_id      = node_cells[0];
                  const size_t i_local_node = node_local_numbers_in_their_cells(node_id, 0);

                  rhs[node_id] -= bc_values[i_bc_node] * Cjr(cell_id, i_local_node);
                });
            } else {
              const NodeValuePerFace<const Rd> Nlr = mesh_data.Nlr();

              const auto& face_local_numbers_in_their_cells = m_mesh.connectivity().faceLocalNumbersInTheirCells();
              const auto& face_cell_is_reversed             = m_mesh.connectivity().cellFaceIsReversed();
              const auto& face_to_cell_matrix               = m_mesh.connectivity().faceToCellMatrix();
              const auto& face_to_node_matrix               = m_mesh.connectivity().faceToNodeMatrix();

              const auto& bc_values = bc.valueList();
              const auto& bc_faces  = bc.faceList();
              for (size_t i_bc_face = 0; i_bc_face < bc_faces.size(); ++i_bc_face) {
                const FaceId face_id   = bc_faces[i_bc_face];
                const auto& face_cells = face_to_cell_matrix[face_id];
                Assert(face_cells.size() == 1);

                const CellId cell_id      = face_cells[0];
                const size_t i_local_face = face_local_numbers_in_their_cells(face_id, 0);

                // Nlr is flipped when the face is reversed in its cell.
                const double orientation = face_cell_is_reversed(cell_id, i_local_face) ? -1 : 1;

                const auto& nodes_of_face = face_to_node_matrix[face_id];

                for (size_t i_face_node = 0; i_face_node < nodes_of_face.size(); ++i_face_node) {
                  NodeId node_id = nodes_of_face[i_face_node];
                  rhs[node_id] -= orientation * bc_values[i_bc_face] * Nlr(face_id, i_face_node);
                }
              }
            }
          }
        },
        boundary_condition);
    }

    if constexpr (Dimension > 1) {
      // Axis conditions: the right-hand side is projected onto the axis
      // direction.
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using BCType = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<AxisBoundaryCondition, BCType>) {
              const Rd& direction         = bc.direction();
              const Rdxd axial_projection = tensorProduct(direction, direction);

              const auto& bc_nodes = bc.nodeList();
              parallel_for(
                bc.numberOfNodes(), PUGS_LAMBDA(const size_t i_bc_node) {
                  const NodeId node_id = bc_nodes[i_bc_node];
                  if (not is_constrained[node_id]) {
                    rhs[node_id]            = axial_projection * rhs[node_id];
                    is_constrained[node_id] = true;
                  }
                });
            }
          },
          boundary_condition);
      }
    }

    // Symmetry conditions: the normal component of the right-hand side is
    // removed.
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<SymmetryBoundaryCondition, BCType>) {
            const Rd& normal                 = bc.outgoingNormal();
            const Rdxd unit_matrix           = identity;
            const Rdxd normal_x_normal       = tensorProduct(normal, normal);
            const Rdxd tangential_projection = unit_matrix - normal_x_normal;

            const auto& bc_nodes = bc.nodeList();
            parallel_for(
              bc.numberOfNodes(), PUGS_LAMBDA(const size_t i_bc_node) {
                const NodeId node_id = bc_nodes[i_bc_node];
                if (not is_constrained[node_id]) {
                  rhs[node_id]            = tangential_projection * rhs[node_id];
                  is_constrained[node_id] = true;
                }
              });
          }
        },
        boundary_condition);
    }

    const NodeValue<Rd> node_velocity(m_mesh.connectivity());
    parallel_for(
      m_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { node_velocity[node_id] = m_inv_Ar[node_id] * rhs[node_id]; });

    // Boundary conditions applied on the velocities themselves, following the
    // order of the boundary condition list.
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;

          if constexpr ((Dimension > 1) and (std::is_same_v<AxisBoundaryCondition, BCType>)) {
            const Rd& direction         = bc.direction();
            const Rdxd axial_projection = tensorProduct(direction, direction);

            const Array<const NodeId>& bc_nodes = bc.nodeList();

            for (size_t i_bc_node = 0; i_bc_node < bc_nodes.size(); ++i_bc_node) {
              const NodeId node_id   = bc_nodes[i_bc_node];   // node r is fixed
              node_velocity[node_id] = axial_projection * node_velocity[node_id];
            }
          } else if constexpr (std::is_same_v<SymmetryBoundaryCondition, BCType>) {
            const Rd& normal = bc.outgoingNormal();

            const Rdxd unit_matrix           = identity;
            const Rdxd normal_x_normal       = tensorProduct(normal, normal);
            const Rdxd tangential_projection = unit_matrix - normal_x_normal;

            const Array<const NodeId>& bc_nodes = bc.nodeList();

            for (size_t i_bc_node = 0; i_bc_node < bc_nodes.size(); ++i_bc_node) {
              const NodeId node_id   = bc_nodes[i_bc_node];   // node r is fixed
              node_velocity[node_id] = tangential_projection * node_velocity[node_id];
            }
          } else if constexpr (std::is_same_v<VelocityBoundaryCondition, BCType>) {
            const Array<const NodeId>& bc_nodes = bc.nodeList();
            const Array<const Rd>& bc_values    = bc.valueList();

            for (size_t i_bc_node = 0; i_bc_node < bc_nodes.size(); ++i_bc_node) {
              const NodeId node_id   = bc_nodes[i_bc_node];
              node_velocity[node_id] = bc_values[i_bc_node];
            }
          }
        },
        boundary_condition);
    }

    return node_velocity;
  }

  // Evaluates the sub-cell quantity F_jr for every (cell, local node) pair from
  // the nodal velocities `ur`:
  //   F_jr = A_jr * (u_j - u_r) + p_j * D_jr
  // where u_j and p_j are read from m_predicted_u and m_predicted_p.
  NodeValuePerCell<const Rd>
  _getFjr(const NodeValue<const Rd>& ur) const
  {
    const NodeValuePerCell<Rd> Fjr(m_mesh.connectivity());
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
        const auto& node_list = cell_to_node_matrix[cell_id];

        for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
          const NodeId node_id = node_list[i_node];

          // velocity jump between the cell and the node, and pressure contribution
          const Rd velocity_jump = m_predicted_u[cell_id] - ur[node_id];
          const Rd pressure_term = m_predicted_p[cell_id] * m_Djr(cell_id, i_node);

          Fjr(cell_id, i_node) = m_Ajr(cell_id, i_node) * velocity_jump + pressure_term;
        }
      });

    return Fjr;
  }

  // Typed constructor (explicit zones given as zone descriptors).
  //
  // Stores the solver type, builds the boundary condition list from the
  // descriptors and keeps the mesh. Every cell is flagged implicit except the
  // cells belonging to one of the zones of `explicit_zone_list`.
  // The discrete function arguments and the final double are unused here.
  ImplicitAcousticSolver(const SolverType solver_type,
                         const std::shared_ptr<const MeshType>& p_mesh,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteVectorFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                         const std::vector<std::shared_ptr<const IZoneDescriptor>>& explicit_zone_list,
                         const double&)
    : m_solver_type{solver_type},
      m_boundary_condition_list{this->_getBCList(p_mesh, bc_descriptor_list)},
      m_mesh{*p_mesh}
  {
    // start from "everything is implicit" ...
    CellValue<bool> implicit_flag(m_mesh.connectivity());
    implicit_flag.fill(true);

    // ... then switch off the cells of each explicit zone
    for (const auto& zone_descriptor : explicit_zone_list) {
      auto explicit_zone     = getMeshCellZone(m_mesh, *zone_descriptor);
      const auto& zone_cells = explicit_zone.cellList();

      for (size_t i = 0; i < zone_cells.size(); ++i) {
        const CellId explicit_cell_id    = zone_cells[i];
        implicit_flag[explicit_cell_id] = false;
      }
    }

    m_is_implicit_cell = implicit_flag;
  }

  // Typed constructor (explicit zones given as an indicator function).
  //
  // Stores the solver type, builds the boundary condition list from the
  // descriptors and keeps the mesh. A cell is flagged implicit exactly when
  // `chi_explicit` is equal to 0 on that cell.
  // The other discrete function arguments and the final double are unused here.
  ImplicitAcousticSolver(const SolverType solver_type,
                         const std::shared_ptr<const MeshType>& p_mesh,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteVectorFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const DiscreteScalarFunction&,
                         const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                         const DiscreteScalarFunction& chi_explicit,
                         const double&)
    : m_solver_type{solver_type},
      m_boundary_condition_list{this->_getBCList(p_mesh, bc_descriptor_list)},
      m_mesh{*p_mesh}
  {
    CellValue<bool> implicit_flag(m_mesh.connectivity());

    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
        // any non-zero indicator value marks the cell as explicit
        const bool is_explicit = (chi_explicit[cell_id] != 0);
        implicit_flag[cell_id] = not is_explicit;
      });

    m_is_implicit_cell = implicit_flag;
  }

 public:
  // Advances the solution by one time step `dt` on the given mesh.
  //
  // Inputs are cell-wise (P0) quantities: density `rho`, velocity `u`, total
  // energy `E`, sound speed `c`, pressure `p`, `pi`, `gamma`, `Cv` and
  // `entropy`.
  //
  // Returns the tuple (new mesh, new rho, new u, new E), the discrete functions
  // being attached to the new mesh.
  //
  // Outline of what the body does:
  //  - initializes the solver members (m_Djr, m_Mj, m_u, m_predicted_u,
  //    m_predicted_p, m_tau, m_Ajr, m_inv_Ar, m_inv_gamma,
  //    m_g_1_exp_S_Cv_inv_g) from the inputs;
  //  - runs an iteration that calls _computeGradJU_AU, gets the node velocities
  //    and the Fjr, moves the nodes, updates u, E and rho, and recomputes
  //    m_Djr (in dimension 2 and 3);
  //  - the iteration is repeated only when Dimension > 1, while
  //    max_tau_error > 1e-4 and fewer than 100 iterations have been done.
  //
  // Side effects: modifies the file-level counters count_Djr, count_newton and
  // count_timestep, accumulates time in a function-local static Timer and
  // prints diagnostics to std::cout.
  std::tuple<std::shared_ptr<const MeshVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt,
        const std::shared_ptr<const MeshType>& mesh,
        const DiscreteScalarFunction& rho,
        const DiscreteVectorFunction& u,
        const DiscreteScalarFunction& E,

        const DiscreteScalarFunction& c,
        const DiscreteScalarFunction& p,
        const DiscreteScalarFunction& pi,
        const DiscreteScalarFunction& gamma,
        const DiscreteScalarFunction& Cv,
        const DiscreteScalarFunction& entropy)
  {
    // static: the same timer is shared by all calls
    static Timer implicit_t;
    implicit_t.start();

    MeshDataType& mesh_data                = MeshDataManager::instance().getMeshData(*mesh);
    const NodeValuePerCell<const Rd> Cjr_n = mesh_data.Cjr();

    count_Djr = 0;

    // first guess for Djr: the Cjr of the mesh at the beginning of the step
    m_Djr = copy(Cjr_n);

    CellValue<double> new_rho = copy(rho.cellValues());
    CellValue<Rd> new_u       = copy(u.cellValues());
    CellValue<double> new_E   = copy(E.cellValues());
    std::shared_ptr<const MeshType> new_mesh;

    // cell masses: Mj = rho_j * Vj
    m_Mj = [&]() {
      const CellValue<const double>& Vj = mesh_data.Vj();
      CellValue<double> computed_Mj(m_mesh.connectivity());

      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { computed_Mj[j] = rho[j] * Vj[j]; });
      return computed_Mj;
    }();

    m_u = u.cellValues();

    // predicted values start from the input velocity and pressure
    m_predicted_u = m_u;
    m_predicted_p = p.cellValues();

    // tau_j = 1 / rho_j
    m_tau = [&]() {
      CellValue<double> computed_tau(m_mesh.connectivity());

      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { computed_tau[j] = 1 / rho[j]; });
      return computed_tau;
    }();

    m_Ajr = this->_computeAjr(rho, c);

    NodeValue<const Rdxd> Ar = this->_computeAr(m_Ajr);

    m_inv_Ar = this->_computeInvAr(Ar);

    // 1 / gamma_j
    m_inv_gamma = [&]() {
      CellValue<double> computed_inv_gamma(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { computed_inv_gamma[j] = 1 / gamma[j]; });
      return computed_inv_gamma;
    }();

    // ((gamma_j - 1) * exp(entropy_j / Cv_j))^(1 / gamma_j), written for implicit cells only.
    // NOTE(review): entries of non-implicit cells are never assigned here; presumably they are
    // never read -- confirm.
    m_g_1_exp_S_Cv_inv_g = [&]() {
      CellValue<double> computed_g_1_exp_S_Cv_inv_g(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          if (m_is_implicit_cell[j]) {
            computed_g_1_exp_S_Cv_inv_g[j] = std::pow((gamma[j] - 1) * exp(entropy[j] / Cv[j]), m_inv_gamma[j]);
          }
        });
      return computed_g_1_exp_S_Cv_inv_g;
    }();

    // max_tau_error is assigned inside the loop before it is tested by the loop condition
    double max_tau_error;
    int number_iter = 0;
    do {
      number_iter++;

      count_newton = 0;

      this->_computeGradJU_AU(u, p, pi, dt);

      NodeValue<const Rd> ur         = this->_getUr();
      NodeValuePerCell<const Rd> Fjr = this->_getFjr(ur);

      // time n+1
      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

      // move the nodes: x_r^{n+1} = x_r^n + dt * u_r
      NodeValue<Rd> new_xr = copy(mesh->xr());
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { new_xr[r] += dt * ur[r]; });

      new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_xr);

      CellValue<const double> Vj = MeshDataManager::instance().getMeshData(*mesh).Vj();

      // restart from the input values at each iteration
      new_rho = copy(rho.cellValues());
      new_u   = copy(u.cellValues());
      new_E   = copy(E.cellValues());

      // velocity and energy updates from the sums of Fjr and Fjr.ur over the cell nodes
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
          const auto& cell_nodes = cell_to_node_matrix[j];

          Rd momentum_fluxes   = zero;
          double energy_fluxes = 0;
          for (size_t R = 0; R < cell_nodes.size(); ++R) {
            const NodeId r = cell_nodes[R];
            momentum_fluxes += Fjr(j, R);
            energy_fluxes += dot(Fjr(j, R), ur[r]);
          }
          const double dt_over_Mj = dt / m_Mj[j];
          new_u[j] -= dt_over_Mj * momentum_fluxes;
          new_E[j] -= dt_over_Mj * energy_fluxes;
        });

      // update Djr
      const NodeValuePerCell<const Rd> new_Cjr = MeshDataManager::instance().getMeshData(*new_mesh).Cjr();

      CellValue<double> new_Vj{m_mesh.connectivity()};

      // Compute the new cell volumes as (1/Dimension) * sum_r new_Cjr . new_xr. While the smallest
      // volume is negative, replace new_xr by the midpoint between new_xr and the initial
      // positions and try again.
      // NOTE(review): new_Cjr is obtained before this loop and is not recomputed after new_xr is
      // modified; new_mesh was also built before the modification -- confirm this is intended.
      bool is_positive = true;
      do {
        is_positive = true;
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
            new_Vj[j]       = 0;
            auto cell_nodes = cell_to_node_matrix[j];

            for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
              const NodeId node_id = cell_nodes[i_node];
              new_Vj[j] += dot(new_Cjr[j][i_node], new_xr[node_id]);
            }
            new_Vj[j] *= 1. / Dimension;
          });

        double m = min(new_Vj);
        if (m < 0) {
          std::cout << "negative volume\n";
          parallel_for(
            mesh->numberOfNodes(),
            PUGS_LAMBDA(const NodeId node_id) { new_xr[node_id] = 0.5 * (new_xr[node_id] + mesh->xr()[node_id]); });
          is_positive = false;
        }
      } while (not is_positive);

      // new density from the volume ratio: rho^{n+1} = rho^n * Vj / new_Vj
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { new_rho[j] *= Vj[j] / new_Vj[j]; });

      CellValue<double> new_tau(m_mesh.connectivity());

      // new_tau_j = tau_j + (dt / Mj) * sum_r Djr . ur (plain serial loop)
      for (CellId j = 0; j < mesh->numberOfCells(); ++j) {
        double new_tau_fluxes  = 0;
        const auto& cell_nodes = cell_to_node_matrix[j];
        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          const NodeId r = cell_nodes[R];
          new_tau_fluxes += dot(m_Djr(j, R), ur[r]);
        }
        new_tau[j] = m_tau[j] + (dt / m_Mj[j]) * new_tau_fluxes;
        // std::cout << "new_tau(" << j << ")=" << new_tau[j] << '\n';
      }

      //      double sum_tau_rho = 0;
      max_tau_error = 0;

      // max over the implicit cells of |1 - new_tau * new_rho|: measures how far the flux-based
      // tau is from the inverse of the volume-based density (plain serial loop).
      // NOTE(review): no reduction other than this local loop is performed here -- confirm this
      // is sufficient for the way the mesh is stored.
      for (CellId j = 0; j < mesh->numberOfCells(); ++j) {
        // const auto& cell_nodes = cell_to_node_matrix[j];
        // V_j^n+1/tau_j^n+1 - M_j
        //        sum_tau_rho += std::abs(new_tau[j] - 1. / new_rho[j]);
        if (m_is_implicit_cell[j]) {
          max_tau_error = std::max(max_tau_error, std::abs(1 - new_tau[j] * new_rho[j]));
        }
      }
      // std::cout << "sum_tau_rho  =" << sum_tau_rho << '\n';
      // std::cout << "max_tau_error=" << max_tau_error << '\n';

      // In dimension 1 none of the two branches below is compiled: m_Djr keeps the copy of Cjr_n.
      if constexpr (Dimension == 2) {
        // implicit cells: average of old and new Cjr; other cells: old Cjr
        NodeValuePerCell<Rd> Djr{m_mesh.connectivity()};

        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
            const auto& cell_nodes = cell_to_node_matrix[cell_id];
            if (m_is_implicit_cell[cell_id]) {
              for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
                Djr(cell_id, i_node) = 0.5 * (Cjr_n(cell_id, i_node) + new_Cjr(cell_id, i_node));
              }
            } else {
              for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
                Djr(cell_id, i_node) = Cjr_n(cell_id, i_node);
              }
            }
          });

        m_Djr = Djr;
      } else if constexpr (Dimension == 3) {
        // Per-face, per-node correction built from cross products of the node displacements
        // (xr_np1 - xr_n) with differences of edge-like vectors of the face at both times.
        NodeValuePerFace<Rd> Nlr_delta(m_mesh.connectivity());
        const auto& face_to_node_matrix = m_mesh.connectivity().faceToNodeMatrix();

        // NOTE(review): the old positions are read from m_mesh here while `mesh` is used
        // elsewhere in this function; presumably both refer to the same mesh -- verify.
        const auto& xr_n   = m_mesh.xr();
        const auto& xr_np1 = new_xr;

        parallel_for(
          m_mesh.numberOfFaces(), PUGS_LAMBDA(FaceId l) {
            const auto& face_nodes = face_to_node_matrix[l];
            const size_t nb_nodes  = face_nodes.size();
            // for each face node r: position of the next node minus position of the previous one
            std::vector<Rd> dxr_n(nb_nodes);
            std::vector<Rd> dxr_np1(nb_nodes);
            for (size_t r = 0; r < nb_nodes; ++r) {
              dxr_n[r]   = xr_n[face_nodes[(r + 1) % nb_nodes]] - xr_n[face_nodes[(r + nb_nodes - 1) % nb_nodes]];
              dxr_np1[r] = xr_np1[face_nodes[(r + 1) % nb_nodes]] - xr_np1[face_nodes[(r + nb_nodes - 1) % nb_nodes]];
            }
            const double inv_12_nb_nodes = 1. / (12. * nb_nodes);
            for (size_t r = 0; r < nb_nodes; ++r) {
              Rd Nr = zero;
              for (size_t s = 0; s < nb_nodes; ++s) {
                Nr -= crossProduct((1. / 6) * (2 * (dxr_np1[r] - dxr_n[r]) - (dxr_np1[s] - dxr_n[s])),
                                   xr_np1[face_nodes[s]] - xr_n[face_nodes[s]]);
              }
              Nr *= inv_12_nb_nodes;
              Nlr_delta(l, r) = Nr;
            }
          });

        const auto& cell_to_face_matrix   = m_mesh.connectivity().cellToFaceMatrix();
        const auto& cell_face_is_reversed = m_mesh.connectivity().cellFaceIsReversed();

        NodeValuePerCell<Rd> Djr{m_mesh.connectivity()};
        Djr.fill(zero);

        // gather the face corrections on the cell nodes, with a minus sign for reversed faces
        parallel_for(
          m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
            const auto& cell_nodes = cell_to_node_matrix[j];

            const auto& cell_faces       = cell_to_face_matrix[j];
            const auto& face_is_reversed = cell_face_is_reversed.itemArray(j);

            for (size_t L = 0; L < cell_faces.size(); ++L) {
              const FaceId& l        = cell_faces[L];
              const auto& face_nodes = face_to_node_matrix[l];

              // local index of a node within the cell (linear search);
              // returns the largest size_t value if the node is not found
              auto local_node_number_in_cell = [&](NodeId node_number) {
                for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
                  if (node_number == cell_nodes[i_node]) {
                    return i_node;
                  }
                }
                return std::numeric_limits<size_t>::max();
              };

              if (face_is_reversed[L]) {
                for (size_t rl = 0; rl < face_nodes.size(); ++rl) {
                  const size_t R = local_node_number_in_cell(face_nodes[rl]);
                  Djr(j, R) -= Nlr_delta(l, rl);
                }
              } else {
                for (size_t rl = 0; rl < face_nodes.size(); ++rl) {
                  const size_t R = local_node_number_in_cell(face_nodes[rl]);
                  Djr(j, R) += Nlr_delta(l, rl);
                }
              }
            }
          });

        // add the average of old and new Cjr to the correction
        parallel_for(
          Djr.numberOfValues(), PUGS_LAMBDA(const size_t i) { Djr[i] += 0.5 * (Cjr_n[i] + new_Cjr[i]); });

        // non-implicit cells keep the old Cjr
        parallel_for(
          m_mesh.numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
            if (not m_is_implicit_cell[cell_id]) {
              const auto& cell_nodes = cell_to_node_matrix[cell_id];
              for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
                Djr(cell_id, i_node) = Cjr_n(cell_id, i_node);
              }
            }
          });

        m_Djr = Djr;

        // diagnostic only: max over implicit cells of |new_Vj - Vj - dt * sum_r ur . Djr|
        double max_err = 0;

        for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
          if (m_is_implicit_cell[cell_id]) {
            double delta_V        = new_Vj[cell_id] - Vj[cell_id];
            const auto& node_list = cell_to_node_matrix[cell_id];
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId node_id = node_list[i_node];
              delta_V -= dt * dot(ur[node_id], Djr[cell_id][i_node]);
            }
            max_err = std::max(max_err, std::abs(delta_V));
          }
        }

        std::cout << rang::fgB::yellow << "Max volume err = " << max_err << rang::fg::reset << '\n';
      }

      if constexpr (Dimension > 1) {
        std::cout << rang::fgB::magenta << "number_iter_Djr=" << number_iter << " max_tau_error=" << max_tau_error
                  << rang::fg::reset << '\n';
      }

      ++count_Djr;

      // std::cout << "new rho=" << new_rho << '\n';
      // iterate only in dimension 2 and 3, until the tau error is small enough or 100 iterations
    } while ((Dimension > 1) and (max_tau_error > 1e-4) and (number_iter < 100));
    // std::cout << "prochaine boucle" << '\n';
    implicit_t.pause();
    // std::cout << "getA=" << get_A_t << "(" << count_getA << ")"
    //           << " getF=" << getF_t << "(" << count_getF << ")"
    //           << " getGradF=" << getGradF_t << "(" << count_getGradF << ")"
    //           << " HJ_A=" << HJ_A_t << '\n';
    // std::cout << "solver=" << solver_t << " total= " << implicit_t << '\n';

    ++count_timestep;

    // results of the last iteration, attached to the mesh built in that iteration
    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_rho}),
            std::make_shared<DiscreteFunctionVariant>(DiscreteVectorFunction{new_mesh, new_u}),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_E})};
  }

  // Variant-based entry point of the time step.
  //
  // Checks that rho, u and E share a common mesh and are P0 functions, then
  // unwraps every argument and forwards to the typed apply() overload.
  // Throws NormalError when one of the two checks fails.
  std::tuple<std::shared_ptr<const MeshVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt,
        const std::shared_ptr<const DiscreteFunctionVariant>& rho,
        const std::shared_ptr<const DiscreteFunctionVariant>& u,
        const std::shared_ptr<const DiscreteFunctionVariant>& E,
        const std::shared_ptr<const DiscreteFunctionVariant>& c,
        const std::shared_ptr<const DiscreteFunctionVariant>& p,
        const std::shared_ptr<const DiscreteFunctionVariant>& pi,
        const std::shared_ptr<const DiscreteFunctionVariant>& gamma,
        const std::shared_ptr<const DiscreteFunctionVariant>& Cv,
        const std::shared_ptr<const DiscreteFunctionVariant>& entropy)
  {
    // the conserved variables must all be defined on one mesh
    const std::shared_ptr common_mesh_v = getCommonMesh({rho, u, E});
    if (common_mesh_v == nullptr) {
      throw NormalError("discrete functions are not defined on the same mesh");
    }

    // ... and must be piecewise constant
    const bool are_all_P0 = checkDiscretizationType({rho, u, E}, DiscreteFunctionType::P0);
    if (not are_all_P0) {
      throw NormalError("acoustic solver expects P0 functions");
    }

    return this->apply(dt,                                     //
                       common_mesh_v->get<MeshType>(),         //
                       rho->get<DiscreteScalarFunction>(),     //
                       u->get<DiscreteVectorFunction>(),       //
                       E->get<DiscreteScalarFunction>(),       //
                       c->get<DiscreteScalarFunction>(),       //
                       p->get<DiscreteScalarFunction>(),       //
                       pi->get<DiscreteScalarFunction>(),      //
                       gamma->get<DiscreteScalarFunction>(),   //
                       Cv->get<DiscreteScalarFunction>(),      //
                       entropy->get<DiscreteScalarFunction>());
  }

  // Variant-based constructor (explicit zones given as zone descriptors).
  //
  // Extracts the typed mesh and the typed discrete functions from their
  // variants and delegates to the typed constructor taking an
  // `explicit_zone_list`. No validation is done here: each get<>() is applied
  // directly to the given variant.
  ImplicitAcousticSolver(const SolverType solver_type,
                         const std::shared_ptr<const MeshVariant>& mesh_v,
                         const std::shared_ptr<const DiscreteFunctionVariant>& rho,
                         const std::shared_ptr<const DiscreteFunctionVariant>& c,
                         const std::shared_ptr<const DiscreteFunctionVariant>& u,
                         const std::shared_ptr<const DiscreteFunctionVariant>& p,
                         const std::shared_ptr<const DiscreteFunctionVariant>& pi,
                         const std::shared_ptr<const DiscreteFunctionVariant>& gamma,
                         const std::shared_ptr<const DiscreteFunctionVariant>& Cv,
                         const std::shared_ptr<const DiscreteFunctionVariant>& entropy,
                         const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                         const std::vector<std::shared_ptr<const IZoneDescriptor>>& explicit_zone_list,
                         const double& dt)
    : ImplicitAcousticSolver(solver_type,
                             mesh_v->get<MeshType>(),
                             rho->get<DiscreteScalarFunction>(),
                             c->get<DiscreteScalarFunction>(),
                             u->get<DiscreteFunctionP0<const Rd>>(),
                             p->get<DiscreteScalarFunction>(),
                             pi->get<DiscreteScalarFunction>(),
                             gamma->get<DiscreteScalarFunction>(),
                             Cv->get<DiscreteScalarFunction>(),
                             entropy->get<DiscreteScalarFunction>(),
                             bc_descriptor_list,
                             explicit_zone_list,
                             dt)
  {}

  // Variant-based constructor (explicit zones given as an indicator function).
  //
  // Extracts the typed mesh and the typed discrete functions (including
  // `chi_explicit`) from their variants and delegates to the typed constructor
  // taking a `chi_explicit` function. No validation is done here: each get<>()
  // is applied directly to the given variant.
  ImplicitAcousticSolver(const SolverType solver_type,
                         const std::shared_ptr<const MeshVariant>& mesh_v,
                         const std::shared_ptr<const DiscreteFunctionVariant>& rho,
                         const std::shared_ptr<const DiscreteFunctionVariant>& c,
                         const std::shared_ptr<const DiscreteFunctionVariant>& u,
                         const std::shared_ptr<const DiscreteFunctionVariant>& p,
                         const std::shared_ptr<const DiscreteFunctionVariant>& pi,
                         const std::shared_ptr<const DiscreteFunctionVariant>& gamma,
                         const std::shared_ptr<const DiscreteFunctionVariant>& Cv,
                         const std::shared_ptr<const DiscreteFunctionVariant>& entropy,
                         const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                         const std::shared_ptr<const DiscreteFunctionVariant>& chi_explicit,
                         const double& dt)
    : ImplicitAcousticSolver(solver_type,
                             mesh_v->get<MeshType>(),
                             rho->get<DiscreteScalarFunction>(),
                             c->get<DiscreteScalarFunction>(),
                             u->get<DiscreteFunctionP0<const Rd>>(),
                             p->get<DiscreteScalarFunction>(),
                             pi->get<DiscreteScalarFunction>(),
                             gamma->get<DiscreteScalarFunction>(),
                             Cv->get<DiscreteScalarFunction>(),
                             entropy->get<DiscreteScalarFunction>(),
                             bc_descriptor_list,
                             chi_explicit->get<DiscreteScalarFunction>(),
                             dt)
  {}

  // Default construction, move construction and destruction are explicitly
  // defaulted (no copy operation is declared here).
  ImplicitAcousticSolver()                         = default;
  ImplicitAcousticSolver(ImplicitAcousticSolver&&) = default;
  ~ImplicitAcousticSolver()                        = default;
};

// Pressure boundary condition: holds a list of faces together with a list of
// scalar values. Both arrays are stored as given and exposed read-only.
template <MeshConcept MeshType>
class ImplicitAcousticSolverHandler::ImplicitAcousticSolver<MeshType>::PressureBoundaryCondition
{
 public:
  PressureBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
    : m_value_list{value_list}, m_face_list{face_list}
  {}

  ~PressureBoundaryCondition() = default;

  // Scalar values given at construction.
  auto
  valueList() const -> const Array<const double>&
  {
    return m_value_list;
  }

  // Faces given at construction.
  auto
  faceList() const -> const Array<const FaceId>&
  {
    return m_face_list;
  }

 private:
  // kept in this order: it is the order of the constructor's initializer list
  const Array<const double> m_value_list;
  const Array<const FaceId> m_face_list;
};

// Pressure boundary condition specialized for Mesh<1>: the "face" list is
// stored as a list of NodeId. Both arrays are stored as given and exposed
// read-only.
template <>
class ImplicitAcousticSolverHandler::ImplicitAcousticSolver<Mesh<1>>::PressureBoundaryCondition
{
 public:
  PressureBoundaryCondition(const Array<const NodeId>& face_list, const Array<const double>& value_list)
    : m_value_list{value_list}, m_face_list{face_list}
  {}

  ~PressureBoundaryCondition() = default;

  // Scalar values given at construction.
  auto
  valueList() const -> const Array<const double>&
  {
    return m_value_list;
  }

  // Node ids given at construction (they play the role of faces here).
  auto
  faceList() const -> const Array<const NodeId>&
  {
    return m_face_list;
  }

 private:
  // kept in this order: it is the order of the constructor's initializer list
  const Array<const double> m_value_list;
  const Array<const NodeId> m_face_list;
};

// Velocity boundary condition: holds a list of nodes together with a list of
// vector values of size Dimension. Both arrays are stored as given and
// exposed read-only.
template <MeshConcept MeshType>
class ImplicitAcousticSolverHandler::ImplicitAcousticSolver<MeshType>::VelocityBoundaryCondition
{
 public:
  VelocityBoundaryCondition(const Array<const NodeId>& node_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list{value_list}, m_node_list{node_list}
  {}

  ~VelocityBoundaryCondition() = default;

  // Vector values given at construction.
  auto
  valueList() const -> const Array<const TinyVector<Dimension>>&
  {
    return m_value_list;
  }

  // Nodes given at construction.
  auto
  nodeList() const -> const Array<const NodeId>&
  {
    return m_node_list;
  }

 private:
  // kept in this order: it is the order of the constructor's initializer list
  const Array<const TinyVector<Dimension>> m_value_list;
  const Array<const NodeId> m_node_list;
};

// Symmetry boundary condition: thin wrapper around a MeshFlatNodeBoundary that
// exposes its outgoing normal and its node list.
template <MeshConcept MeshType>
class ImplicitAcousticSolverHandler::ImplicitAcousticSolver<MeshType>::SymmetryBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

  SymmetryBoundaryCondition(const MeshFlatNodeBoundary<MeshType>& mesh_flat_node_boundary)
    : m_mesh_flat_node_boundary(mesh_flat_node_boundary)
  {}

  ~SymmetryBoundaryCondition() = default;

  // Nodes of the wrapped flat boundary.
  auto
  nodeList() const -> const Array<const NodeId>&
  {
    return m_mesh_flat_node_boundary.nodeList();
  }

  // Number of nodes of the wrapped flat boundary.
  auto
  numberOfNodes() const -> size_t
  {
    return m_mesh_flat_node_boundary.nodeList().size();
  }

  // Outgoing normal of the wrapped flat boundary.
  auto
  outgoingNormal() const -> const Rd&
  {
    return m_mesh_flat_node_boundary.outgoingNormal();
  }

 private:
  const MeshFlatNodeBoundary<MeshType> m_mesh_flat_node_boundary;
};

// Axis boundary condition: thin wrapper around a MeshLineNodeBoundary that
// exposes its direction and its node list.
template <MeshConcept MeshType>
class ImplicitAcousticSolverHandler::ImplicitAcousticSolver<MeshType>::AxisBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

  AxisBoundaryCondition(const MeshLineNodeBoundary<MeshType>& mesh_line_node_boundary)
    : m_mesh_line_node_boundary(mesh_line_node_boundary)
  {}

  ~AxisBoundaryCondition() = default;

  // Nodes of the wrapped line boundary.
  auto
  nodeList() const -> const Array<const NodeId>&
  {
    return m_mesh_line_node_boundary.nodeList();
  }

  // Number of nodes of the wrapped line boundary.
  auto
  numberOfNodes() const -> size_t
  {
    return m_mesh_line_node_boundary.nodeList().size();
  }

  // Direction of the wrapped line boundary.
  auto
  direction() const -> const Rd&
  {
    return m_mesh_line_node_boundary.direction();
  }

 private:
  const MeshLineNodeBoundary<MeshType> m_mesh_line_node_boundary;
};

// Axis boundary condition specialized for Mesh<1>: an empty class that carries
// no data and exposes no accessor (only default construction/destruction).
template <>
class ImplicitAcousticSolverHandler::ImplicitAcousticSolver<Mesh<1>>::AxisBoundaryCondition
{
 public:
  AxisBoundaryCondition()  = default;
  ~AxisBoundaryCondition() = default;
};

// Builds the implicit acoustic solver matching the concrete type of the mesh
// shared by the given discrete functions (explicit zones given as zone
// descriptors).
//
// Throws NormalError when
//  - rho, c, u, p, pi, gamma and Cv are not defined on one common mesh,
//  - one of rho, c, u, p, pi, gamma, Cv or entropy is not a P0 function,
//  - the mesh is not a polygonal mesh.
ImplicitAcousticSolverHandler::ImplicitAcousticSolverHandler(
  const SolverType solver_type,
  const std::shared_ptr<const DiscreteFunctionVariant>& rho,
  const std::shared_ptr<const DiscreteFunctionVariant>& c,
  const std::shared_ptr<const DiscreteFunctionVariant>& u,
  const std::shared_ptr<const DiscreteFunctionVariant>& p,
  const std::shared_ptr<const DiscreteFunctionVariant>& pi,
  const std::shared_ptr<const DiscreteFunctionVariant>& gamma,
  const std::shared_ptr<const DiscreteFunctionVariant>& Cv,
  const std::shared_ptr<const DiscreteFunctionVariant>& entropy,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const std::vector<std::shared_ptr<const IZoneDescriptor>>& explicit_zone_list,
  const double& dt)
{
  // NOTE(review): entropy is not part of the common-mesh check -- confirm whether this is intended.
  std::shared_ptr mesh_v = getCommonMesh({rho, c, u, p, pi, gamma, Cv});
  if (not mesh_v) {
    throw NormalError("discrete functions are not defined on the same mesh");
  }

  // The solver constructor extracts every one of these functions with get<>() on its variant.
  // Check them all up front so that a wrong discretization is reported by the error below
  // rather than by a failing variant access.
  if (not checkDiscretizationType({rho, c, u, p, pi, gamma, Cv, entropy}, DiscreteFunctionType::P0)) {
    throw NormalError("acoustic solver expects P0 functions");
  }

  std::visit(
    [&](auto&& mesh) {
      using MeshType = mesh_type_t<decltype(mesh)>;
      if constexpr (is_polygonal_mesh_v<MeshType>) {
        m_implicit_acoustic_solver =
          std::make_unique<ImplicitAcousticSolver<MeshType>>(solver_type, mesh_v, rho, c, u, p, pi, gamma, Cv, entropy,
                                                             bc_descriptor_list, explicit_zone_list, dt);
      } else {
        throw NormalError("unexpected mesh type");
      }
    },
    mesh_v->variant());
}

// Builds the implicit acoustic solver matching the concrete type of the mesh
// shared by the given discrete functions (explicit zones given as the
// indicator function `chi_explicit`).
//
// Throws NormalError when
//  - rho, c, u, p, pi, gamma, Cv and chi_explicit are not defined on one common mesh,
//  - one of rho, c, u, p, pi, gamma, Cv, entropy or chi_explicit is not a P0 function,
//  - the mesh is not a polygonal mesh.
ImplicitAcousticSolverHandler::ImplicitAcousticSolverHandler(
  const SolverType solver_type,
  const std::shared_ptr<const DiscreteFunctionVariant>& rho,
  const std::shared_ptr<const DiscreteFunctionVariant>& c,
  const std::shared_ptr<const DiscreteFunctionVariant>& u,
  const std::shared_ptr<const DiscreteFunctionVariant>& p,
  const std::shared_ptr<const DiscreteFunctionVariant>& pi,
  const std::shared_ptr<const DiscreteFunctionVariant>& gamma,
  const std::shared_ptr<const DiscreteFunctionVariant>& Cv,
  const std::shared_ptr<const DiscreteFunctionVariant>& entropy,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const std::shared_ptr<const DiscreteFunctionVariant>& chi_explicit,
  const double& dt)
{
  // NOTE(review): entropy is not part of the common-mesh check -- confirm whether this is intended.
  std::shared_ptr mesh_v = getCommonMesh({rho, c, u, p, pi, gamma, Cv, chi_explicit});
  if (not mesh_v) {
    throw NormalError("discrete functions are not defined on the same mesh");
  }

  // The solver constructor extracts every one of these functions (chi_explicit included) with
  // get<>() on its variant. Check them all up front so that a wrong discretization is reported
  // by the error below rather than by a failing variant access.
  if (not checkDiscretizationType({rho, c, u, p, pi, gamma, Cv, entropy, chi_explicit}, DiscreteFunctionType::P0)) {
    throw NormalError("acoustic solver expects P0 functions");
  }

  std::visit(
    [&](auto&& mesh) {
      using MeshType = mesh_type_t<decltype(mesh)>;
      if constexpr (is_polygonal_mesh_v<MeshType>) {
        m_implicit_acoustic_solver =
          std::make_unique<ImplicitAcousticSolver<MeshType>>(solver_type, mesh_v, rho, c, u, p, pi, gamma, Cv, entropy,
                                                             bc_descriptor_list, chi_explicit, dt);
      } else {
        throw NormalError("unexpected mesh type");
      }
    },
    mesh_v->variant());
}
