#include <scheme/RelaxedImplicitAcousticSolver.hpp>

#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/MeshFaceBoundary.hpp>
#include <mesh/MeshFlatNodeBoundary.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <mesh/MeshTraits.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <variant>
#include <vector>

template <MeshConcept MeshType>
class RelaxedImplicitAcousticSolverHandler::RelaxedImplicitAcousticSolver final
  : public RelaxedImplicitAcousticSolverHandler::IRelaxedImplicitAcousticSolver
{
 private:
  // Space dimension, taken from the mesh type this solver is instantiated for.
  constexpr static size_t Dimension = MeshType::Dimension;

  // Dimension-sized TinyMatrix / TinyVector aliases.
  using Rdxd = TinyMatrix<Dimension>;
  using Rd   = TinyVector<Dimension>;

  using MeshDataType = MeshData<MeshType>;

  // Read-only P0 discrete functions: scalar-valued and Rd-valued.
  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  // Boundary condition kinds (only declared here; defined elsewhere).
  class PressureBoundaryCondition;
  class SymmetryBoundaryCondition;
  class VelocityBoundaryCondition;

  // One boundary condition of any supported kind; the assembly routines
  // dispatch on the held alternative with std::visit.
  using BoundaryCondition =
    std::variant<PressureBoundaryCondition, SymmetryBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  // Boundary conditions visited when assembling A, grad J and the Hessian of J.
  BoundaryConditionList m_boundary_condition_list;
  // Mesh the solver works on (held by reference).
  const MeshType& m_mesh;

  // NOTE(review): presumably the stacked unknown vector laid out as in _getU — confirm.
  Vector<double> m_U;
  // NOTE(review): presumably the node-wise values produced by _getRhoCr — confirm.
  NodeValue<const double> m_rhocr;

  /**
   * Computes, for every node, the arithmetic mean of rho[j] * c[j] over the
   * cells connected to that node.
   *
   * @param rho cell-wise scalar field (first factor of the product)
   * @param c   cell-wise scalar field (second factor of the product)
   * @return node-wise averaged product
   */
  NodeValue<const double>
  _getRhoCr(const DiscreteScalarFunction& rho, const DiscreteScalarFunction& c) const
  {
    const auto& cells_of_node = m_mesh.connectivity().nodeToCellMatrix();

    NodeValue<double> node_rhoc{m_mesh.connectivity()};

    parallel_for(
      m_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
        const auto& adjacent_cells = cells_of_node[node_id];
        const size_t cell_count    = adjacent_cells.size();

        double accumulated = 0;
        for (size_t i_cell = 0; i_cell < cell_count; ++i_cell) {
          const CellId cell_id = adjacent_cells[i_cell];
          accumulated += rho[cell_id] * c[cell_id];
        }

        // Scale by the reciprocal of the cell count (kept as a multiplication
        // so the rounding is the same as a sum-times-inverse evaluation).
        node_rhoc[node_id] = accumulated * (1. / cell_count);
      });

    return node_rhoc;
  }

  /**
   * Maps a (block, cell) pair to its position in the stacked vectors and
   * matrices of size 2 * numberOfCells().
   *
   * @param k block number (callers in this class pass 0 or 1)
   * @param j cell identifier
   * @return k * numberOfCells() + j
   */
  size_t
  transferFunction(int k, CellId j) const
  {
    // Block layout: all entries of block 0 come first, then all entries of
    // block 1 (an interleaved layout would read 2 * j + k instead).
    const auto block_start = k * m_mesh.numberOfCells();
    return block_start + j;
  }

  /**
   * Assembles the sparse coupling matrix A of size 2 * numberOfCells().
   *
   * Only the off-diagonal blocks are filled: entries couple a block-0 index
   * with a block-1 index (see transferFunction). For every cell j except the
   * last one, with j1 the cell found through the second node of j:
   *  - A(0:j , 1:j1) = A(1:j , 0:j1) = -1
   *  - A(0:j1, 1:j ) = A(1:j1, 0:j ) = +1
   * Velocity and pressure boundary conditions then add a +/-1 pair coupling
   * the two unknowns of the cell attached to each boundary node; the sign
   * depends on the sign of the first component of Cjr and is opposite for the
   * two kinds of condition.
   *
   * @return the matrix descriptor
   * @throws UnexpectedError for any other kind of boundary condition
   */
  CRSMatrixDescriptor<double>
  _getA() const
  {
    const auto system_size = 2 * m_mesh.numberOfCells();

    Array<int> entries_per_row{system_size};
    entries_per_row.fill(2);
    CRSMatrixDescriptor A{system_size, entries_per_row};

    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    // Cell reached through the second node of `cell_id` (second cell of that node).
    const auto next_cell = [&](const CellId cell_id) -> CellId {
      const NodeId second_node = cell_to_node_matrix[cell_id][1];
      return node_to_cell_matrix[second_node][1];
    };

    const auto loop_end = m_mesh.numberOfCells() - 1;

    // Rows of the current cell, columns of its neighbour: -1 entries.
    for (CellId cell_id = 0; cell_id < loop_end; ++cell_id) {
      const CellId neighbour = next_cell(cell_id);

      const int blk0_row = transferFunction(0, cell_id);
      const int blk1_row = transferFunction(1, cell_id);
      const int blk0_col = transferFunction(0, neighbour);
      const int blk1_col = transferFunction(1, neighbour);

      A(blk0_row, blk1_col) = -1;
      A(blk1_row, blk0_col) = -1;
    }

    // Rows of the neighbour, columns of the current cell: +1 entries.
    for (CellId cell_id = 0; cell_id < loop_end; ++cell_id) {
      const CellId neighbour = next_cell(cell_id);

      const int blk0_col = transferFunction(0, cell_id);
      const int blk1_col = transferFunction(1, cell_id);
      const int blk0_row = transferFunction(0, neighbour);
      const int blk1_row = transferFunction(1, neighbour);

      A(blk0_row, blk1_col) = 1;
      A(blk1_row, blk0_col) = 1;
    }

    // Writes the boundary pair for every node of `boundary_nodes`.
    // `positive_side_value` is the value stored in A(block 0, block 1) when
    // the first component of Cjr is positive; the opposite value is used
    // otherwise, and A(block 1, block 0) always receives the negated entry.
    const auto set_boundary_pair = [&](const auto& boundary_nodes, const double positive_side_value) {
      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

      const NodeValuePerCell<const Rd>& Cjr = mesh_data.Cjr();

      const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

      for (size_t i_node = 0; i_node < boundary_nodes.size(); ++i_node) {
        const NodeId& node_id     = boundary_nodes[i_node];
        const CellId cell_id      = node_to_cell_matrix[node_id][0];
        const auto& local_numbers = node_local_numbers_in_their_cells.itemArray(node_id);

        const int blk0_index = transferFunction(0, cell_id);
        const int blk1_index = transferFunction(1, cell_id);

        const bool positive_side = Cjr(cell_id, local_numbers[0])[0] > 0;
        const double coupling    = positive_side ? positive_side_value : -positive_side_value;

        A(blk0_index, blk1_index) = coupling;
        A(blk1_index, blk0_index) = -coupling;
      }
    };

    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, BCType>) {
            set_boundary_pair(bc.nodeList(), 1.);
          } else if constexpr (std::is_same_v<PressureBoundaryCondition, BCType>) {
            set_boundary_pair(bc.faceList(), -1.);
          } else {
            throw UnexpectedError("boundary condition not handled");
          }
        },
        boundary_condition);
    }

    return A;
  }

  /**
   * Packs the cell values of p and the first component of u into one stacked
   * vector of size 2 * numberOfCells().
   *
   * @param p cell-wise scalar field, stored in block 0
   * @param u cell-wise vector field; u[j][0] is stored in block 1
   * @return the stacked vector, indexed through transferFunction
   */
  Vector<double>
  _getU(const DiscreteScalarFunction& p, const DiscreteVectorFunction& u)
  {
    Vector<double> stacked{2 * m_mesh.numberOfCells()};

    // Both blocks are written in a single sweep: the two target indices of a
    // cell never collide with those of another cell.
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
        stacked[transferFunction(0, cell_id)] = p[cell_id];
        stacked[transferFunction(1, cell_id)] = u[cell_id][0];
      });

    return stacked;
  }

  /**
   * Assembles the stacked vector grad_J (size 2 * numberOfCells()) evaluated
   * at the iterate Uk.
   *
   * Entries transferFunction(0, j) are taken from grad_J1 and entries
   * transferFunction(1, j) from grad_J2. In both blocks:
   *  - cells 1 .. numberOfCells()-2 are filled by a serial loop with a
   *    two-sided stencil; neighbours are looked up through cell_node[0] /
   *    cell_node[1] and node_cell_left[0] / node_cell_right[1];
   *  - the cell attached to each velocity or pressure boundary node is then
   *    overwritten with a one-sided stencil plus a boundary term; which
   *    neighbour is kept depends on the sign of the first component of Cjr.
   *
   * NOTE(review): the CellValue buffers are not initialised here, so cells 0
   * and numberOfCells()-1 only hold defined values if a velocity or pressure
   * boundary condition covers them (symmetry conditions are skipped without
   * any action in this routine) — confirm this is always the case.
   * NOTE(review): the neighbour lookups and the use of component [0] only look
   * like they assume a 1D mesh with cells ordered along the axis — confirm.
   *
   * @param Un     stacked vector the iterate is compared to (terms Uk - Un)
   * @param Uk     stacked vector at which the gradient is evaluated
   * @param rhocr  node-wise values used as stencil weights
   * @param rho    cell-wise field; with the cell volume Vj it gives Mj = rho * Vj
   * @param c      cell-wise field used with rho in the boundary terms
   * @param p      cell-wise field entering the lambda-weighted term of block 0
   * @param a      scalar appearing as 1 / (a * a) in block 0
   * @param lambda scalar weight of the (Uk - p) term
   * @param dt     scalar appearing as 1 / dt in the (Uk - Un) terms
   * @return the assembled stacked vector
   */
  Vector<double>
  _getGradJ(const Vector<double>& Un,
            const Vector<double>& Uk,
            const NodeValue<const double>& rhocr,
            const DiscreteScalarFunction& rho,
            const DiscreteScalarFunction& c,
            const DiscreteScalarFunction& p,
            const double a,
            const double lambda,
            const double dt) const
  {
    // Mj[j] = rho[j] * Vj[j] for every cell.
    const CellValue<const double> Mj = [&]() {
      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

      const CellValue<const double>& Vj = mesh_data.Vj();
      CellValue<double> computed_Mj(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { computed_Mj[j] = rho[j] * Vj[j]; });
      return computed_Mj;
    }();

    Vector<double> grad_J{2 * m_mesh.numberOfCells()};

    // Block 0 of the result (indices transferFunction(0, .)).
    CellValue<double> grad_J1 = [&]() {
      const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();
      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
      CellValue<double> computed_gradJ1(m_mesh.connectivity());
      // Two-sided stencil; the first and last cells are skipped here.
      for (CellId j = 1; j < m_mesh.numberOfCells() - 1; ++j) {
        const auto& cell_node       = cell_to_node_matrix[j];
        NodeId r_right              = cell_node[1];
        NodeId r_left               = cell_node[0];
        const auto& node_cell_right = node_to_cell_matrix[r_right];
        const auto& node_cell_left  = node_to_cell_matrix[r_left];
        CellId j1                   = node_cell_right[1];
        CellId j_1                  = node_cell_left[0];

        size_t q   = transferFunction(0, j);
        size_t q_1 = transferFunction(0, j_1);
        size_t q1  = transferFunction(0, j1);

        computed_gradJ1[j] = (2 * Mj[j] / (dt * a * a)) * (Uk[q] - Un[q]) + (1. / rhocr[r_right]) * (Uk[q] - Uk[q1]) +
                             (1. / rhocr[r_left]) * (Uk[q] - Uk[q_1]) +
                             (2 * lambda / (a * a)) * rho[j] * (Uk[q] - p[j]);
      }

      // Overwrite the cell attached to each boundary node with a one-sided
      // stencil plus the boundary term of the matching condition.
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();

              MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

              const NodeValuePerCell<const Rd>& Cjr = mesh_data.Cjr();

              const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id                      = node_list[i_node];
                const auto& node_cell                      = node_to_cell_matrix[node_id];
                CellId j                                   = node_cell[0];
                const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

                // Positive first component of Cjr: keep only the neighbour
                // reached through the first node of the cell.
                if (Cjr(j, node_local_number_in_its_cells[0])[0] > 0) {
                  const auto& cell_node      = cell_to_node_matrix[j];
                  NodeId r_left              = cell_node[0];
                  const auto& node_cell_left = node_to_cell_matrix[r_left];
                  CellId j_1                 = node_cell_left[0];
                  size_t q                   = transferFunction(0, j);
                  size_t q_1                 = transferFunction(0, j_1);

                  computed_gradJ1[j] = (2 * Mj[j] / (dt * a * a)) * (Uk[q] - Un[q]) +
                                       (1. / rhocr[r_left]) * (Uk[q] - Uk[q_1]) +
                                       (2 * lambda * rho[j] / (a * a)) * (Uk[q] - p[j]) + 2 * value_list[i_node][0];

                } else {
                  // Otherwise keep only the neighbour reached through the
                  // second node; the boundary value enters with opposite sign.
                  const auto& cell_node       = cell_to_node_matrix[j];
                  NodeId r_right              = cell_node[1];
                  const auto& node_cell_right = node_to_cell_matrix[r_right];
                  CellId j1                   = node_cell_right[1];
                  size_t q                    = transferFunction(0, j);
                  size_t q1                   = transferFunction(0, j1);

                  computed_gradJ1[j] = (2 * Mj[j] / (dt * a * a)) * (Uk[q] - Un[q]) +
                                       (1. / rhocr[r_right]) * (Uk[q] - Uk[q1]) +
                                       (2 * lambda * rho[j] / (a * a)) * (Uk[q] - p[j]) - 2 * value_list[i_node][0];
                }
              }
            } else if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
              // The items returned by faceList() are used as node identifiers below.
              const auto& node_list   = bc.faceList();
              const auto& value_list  = bc.valueList();
              MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

              const NodeValuePerCell<const Rd>& Cjr = mesh_data.Cjr();

              const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id                      = node_list[i_node];
                const auto& node_cell                      = node_to_cell_matrix[node_id];
                CellId j                                   = node_cell[0];
                const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

                // Same side selection as above; the boundary term has the
                // same form and sign on both sides.
                if (Cjr(j, node_local_number_in_its_cells[0])[0] > 0) {
                  const auto& cell_node      = cell_to_node_matrix[j];
                  NodeId r_left              = cell_node[0];
                  const auto& node_cell_left = node_to_cell_matrix[r_left];
                  CellId j_1                 = node_cell_left[0];
                  size_t q                   = transferFunction(0, j);
                  size_t q_1                 = transferFunction(0, j_1);

                  computed_gradJ1[j] = (2 * Mj[j] / (dt * a * a)) * (Uk[q] - Un[q]) +
                                       (1. / rhocr[r_left]) * (Uk[q] - Uk[q_1]) +
                                       (2 * lambda * rho[j] / (a * a)) * (Uk[q] - p[j]) +
                                       (2. / (rho[j] * c[j])) * (Uk[q] - value_list[i_node]);

                } else {
                  const auto& cell_node       = cell_to_node_matrix[j];
                  NodeId r_right              = cell_node[1];
                  const auto& node_cell_right = node_to_cell_matrix[r_right];
                  CellId j1                   = node_cell_right[1];
                  size_t q                    = transferFunction(0, j);
                  size_t q1                   = transferFunction(0, j1);

                  computed_gradJ1[j] = (2 * Mj[j] / (dt * a * a)) * (Uk[q] - Un[q]) +
                                       (1. / rhocr[r_right]) * (Uk[q] - Uk[q1]) +
                                       (2 * lambda * rho[j] / (a * a)) * (Uk[q] - p[j]) +
                                       (2. / (rho[j] * c[j])) * (Uk[q] - value_list[i_node]);
                }
              }
            }
          },
          boundary_condition);
      }

      return computed_gradJ1;
    }();

    // Block 1 of the result (indices transferFunction(1, .)).
    CellValue<double> grad_J2 = [&]() {
      const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();
      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
      CellValue<double> computed_gradJ2(m_mesh.connectivity());
      // Two-sided stencil; the first and last cells are skipped here.
      for (CellId j = 1; j < m_mesh.numberOfCells() - 1; ++j) {
        const auto& cell_node       = cell_to_node_matrix[j];
        NodeId r_right              = cell_node[1];
        NodeId r_left               = cell_node[0];
        const auto& node_cell_right = node_to_cell_matrix[r_right];
        const auto& node_cell_left  = node_to_cell_matrix[r_left];
        CellId j1                   = node_cell_right[1];
        CellId j_1                  = node_cell_left[0];

        size_t k   = transferFunction(1, j);
        size_t k_1 = transferFunction(1, j_1);
        size_t k1  = transferFunction(1, j1);

        computed_gradJ2[j] =
          (2 * Mj[j] / dt) * (Uk[k] - Un[k]) + rhocr[r_right] * (Uk[k] - Uk[k1]) + rhocr[r_left] * (Uk[k] - Uk[k_1]);
      }

      // Overwrite the cell attached to each boundary node with a one-sided
      // stencil plus the boundary term of the matching condition.
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();

              MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

              const NodeValuePerCell<const Rd>& Cjr = mesh_data.Cjr();

              const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id                      = node_list[i_node];
                const auto& node_cell                      = node_to_cell_matrix[node_id];
                CellId j                                   = node_cell[0];
                const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

                // Both branches add 2 * rho * c * (Uk - boundary value); they
                // differ only in which neighbour is kept.
                if (Cjr(j, node_local_number_in_its_cells[0])[0] > 0) {
                  const auto& cell_node      = cell_to_node_matrix[j];
                  NodeId r_left              = cell_node[0];
                  const auto& node_cell_left = node_to_cell_matrix[r_left];
                  CellId j_1                 = node_cell_left[0];
                  size_t k                   = transferFunction(1, j);
                  size_t k_1                 = transferFunction(1, j_1);

                  computed_gradJ2[j] = (2 * Mj[j] / dt) * (Uk[k] - Un[k]) + rhocr[r_left] * (Uk[k] - Uk[k_1]) -
                                       2 * rho[j] * c[j] * (value_list[i_node][0] - Uk[k]);

                } else {
                  const auto& cell_node       = cell_to_node_matrix[j];
                  NodeId r_right              = cell_node[1];
                  const auto& node_cell_right = node_to_cell_matrix[r_right];
                  CellId j1                   = node_cell_right[1];
                  size_t k                    = transferFunction(1, j);
                  size_t k1                   = transferFunction(1, j1);

                  computed_gradJ2[j] = (2 * Mj[j] / dt) * (Uk[k] - Un[k]) + rhocr[r_right] * (Uk[k] - Uk[k1]) +
                                       2 * rho[j] * c[j] * (Uk[k] - value_list[i_node][0]);
                }
              }
            } else if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
              // The items returned by faceList() are used as node identifiers below.
              const auto& node_list  = bc.faceList();
              const auto& value_list = bc.valueList();

              MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

              const NodeValuePerCell<const Rd>& Cjr = mesh_data.Cjr();

              const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id                      = node_list[i_node];
                const auto& node_cell                      = node_to_cell_matrix[node_id];
                CellId j                                   = node_cell[0];
                const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);

                // The boundary value enters as +2 * value on one side and
                // -2 * value on the other.
                if (Cjr(j, node_local_number_in_its_cells[0])[0] > 0) {
                  const auto& cell_node      = cell_to_node_matrix[j];
                  NodeId r_left              = cell_node[0];
                  const auto& node_cell_left = node_to_cell_matrix[r_left];
                  CellId j_1                 = node_cell_left[0];
                  size_t k                   = transferFunction(1, j);
                  size_t k_1                 = transferFunction(1, j_1);

                  computed_gradJ2[j] =
                    (2 * Mj[j] / dt) * (Uk[k] - Un[k]) + rhocr[r_left] * (Uk[k] - Uk[k_1]) + 2 * value_list[i_node];

                } else {
                  const auto& cell_node       = cell_to_node_matrix[j];
                  NodeId r_right              = cell_node[1];
                  const auto& node_cell_right = node_to_cell_matrix[r_right];
                  CellId j1                   = node_cell_right[1];
                  size_t k                    = transferFunction(1, j);
                  size_t k1                   = transferFunction(1, j1);

                  computed_gradJ2[j] =
                    (2 * Mj[j] / dt) * (Uk[k] - Un[k]) + rhocr[r_right] * (Uk[k] - Uk[k1]) - 2 * value_list[i_node];
                }
              }
            }
          },
          boundary_condition);
      }

      return computed_gradJ2;
    }();

    // Scatter both per-cell blocks into the stacked result.
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
        size_t i  = transferFunction(0, j);
        size_t l  = transferFunction(1, j);
        grad_J[i] = grad_J1[j];
        grad_J[l] = grad_J2[j];
      });

    return grad_J;
  }

  /**
   * Evaluates F(Uk) = grad_J(Uk) - A * Uk.
   *
   * @param A      coupling matrix descriptor (see _getA)
   * @param Un     stacked vector forwarded to _getGradJ
   * @param Uk     stacked vector at which F is evaluated
   * @param rhocr  node-wise values forwarded to _getGradJ
   * @param rho    cell-wise field forwarded to _getGradJ
   * @param c      cell-wise field forwarded to _getGradJ
   * @param p      cell-wise field forwarded to _getGradJ
   * @param a      scalar forwarded to _getGradJ
   * @param lambda scalar forwarded to _getGradJ
   * @param dt     scalar forwarded to _getGradJ
   * @return the stacked vector grad_J(Uk) - A * Uk
   */
  Vector<double>
  _getF(const CRSMatrixDescriptor<double>& A,
        const Vector<double>& Un,
        const Vector<double>& Uk,
        const NodeValue<const double>& rhocr,
        const DiscreteScalarFunction& rho,
        const DiscreteScalarFunction& c,
        const DiscreteScalarFunction& p,
        const double a,
        const double lambda,
        const double dt) const
  {
    Vector<double> objective_gradient = this->_getGradJ(Un, Uk, rhocr, rho, c, p, a, lambda, dt);

    // The descriptor has to be turned into a CRS matrix before it can be
    // applied to a vector.
    auto coupling_matrix            = A.getCRSMatrix();
    Vector<double> coupling_product = coupling_matrix * Uk;

    return objective_gradient - coupling_product;
  }

  /**
   * Assembles the sparse matrix Hess_J of size 2 * numberOfCells().
   *
   * Only the two diagonal blocks are filled (block-0 rows with block-0
   * columns, block-1 rows with block-1 columns; see transferFunction):
   *  - cells 1 .. numberOfCells()-2 receive a diagonal entry and two
   *    off-diagonal entries towards the cells found through cell_node[0] and
   *    cell_node[1];
   *  - cell 0 and cell numberOfCells()-1 receive a single off-diagonal entry
   *    each, per block;
   *  - the diagonal entries of the cell attached to each velocity or pressure
   *    boundary node are then set, the side being chosen from the sign of the
   *    first component of Cjr. Other boundary conditions are skipped.
   *
   * NOTE(review): non_zeros is filled with 2 while interior rows receive three
   * entries — presumably CRSMatrixDescriptor accepts rows that exceed this
   * value; confirm.
   * NOTE(review): treating cells 0 and numberOfCells()-1 as the two end cells
   * looks like a 1D, ordered-mesh assumption — confirm.
   *
   * @param rho    cell-wise field; with the cell volume Vj it gives Mj = rho * Vj
   * @param dt     scalar appearing as 1 / dt in the diagonal entries
   * @param c      cell-wise field used with rho in the boundary diagonal terms
   * @param a      scalar appearing as 1 / (a * a) in block 0
   * @param lambda scalar weight of the 2 * lambda * rho / (a * a) term
   * @param rhocr  node-wise values used as stencil weights
   * @return the matrix descriptor
   */
  CRSMatrixDescriptor<double>
  _getHessianJ(const DiscreteScalarFunction& rho,
               const double dt,
               const DiscreteScalarFunction& c,
               double a,
               double lambda,
               const NodeValue<const double>& rhocr)
  {
    // Mj[j] = rho[j] * Vj[j] for every cell.
    const CellValue<const double> Mj = [&]() {
      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

      const CellValue<const double>& Vj = mesh_data.Vj();
      CellValue<double> computed_Mj(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { computed_Mj[j] = rho[j] * Vj[j]; });
      return computed_Mj;
    }();

    Array<int> non_zeros{2 * m_mesh.numberOfCells()};
    non_zeros.fill(2);
    CRSMatrixDescriptor Hess_J{2 * m_mesh.numberOfCells(), non_zeros};
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();
    // Rows of all cells except the first and the last: diagonal plus both
    // off-diagonal entries, in each of the two blocks.
    for (CellId j = 1; j < m_mesh.numberOfCells() - 1; ++j) {
      const auto& cell_node = cell_to_node_matrix[j];
      NodeId r_right        = cell_node[1];
      NodeId r_left         = cell_node[0];

      // first left block
      int row = transferFunction(0, j);

      const auto& node_cell_right = node_to_cell_matrix[r_right];
      CellId j1                   = node_cell_right[1];
      int row1                    = transferFunction(0, j1);

      const auto& node_cell_left = node_to_cell_matrix[r_left];
      CellId j_1                 = node_cell_left[0];
      int row_1                  = transferFunction(0, j_1);

      Hess_J(row, row) =
        (2 * Mj[j] / (dt * a * a)) + 1. / rhocr[r_left] + 1. / rhocr[r_right] + (2 * lambda * rho[j] / (a * a));

      Hess_J(row, row1) = -1. / rhocr[r_right];

      Hess_J(row, row_1) = -1. / rhocr[r_left];

      // second right block
      int col   = transferFunction(1, j);
      int col1  = transferFunction(1, j1);
      int col_1 = transferFunction(1, j_1);

      Hess_J(col, col)   = (2 * Mj[j] / dt) + rhocr[r_left] + rhocr[r_right];
      Hess_J(col, col1)  = -rhocr[r_right];
      Hess_J(col, col_1) = -rhocr[r_left];
    }

    // Cell 0 and cell numberOfCells()-1: only the single off-diagonal entry is
    // written here; their diagonal entries come from the boundary conditions
    // handled further down.
    CellId j              = 0;
    const auto& cell_node = cell_to_node_matrix[j];
    NodeId r_right        = cell_node[1];

    CellId k               = m_mesh.numberOfCells() - 1;
    const auto& cell_node2 = cell_to_node_matrix[k];
    NodeId r_left          = cell_node2[0];

    // first line left block
    int row_l0                  = transferFunction(0, j);
    const auto& node_cell_right = node_to_cell_matrix[r_right];
    CellId j1                   = node_cell_right[1];
    int row_l01                 = transferFunction(0, j1);

    Hess_J(row_l0, row_l01) = -1. / rhocr[r_right];

    // last line left block
    int row_lN                 = transferFunction(0, k);
    const auto& node_cell_left = node_to_cell_matrix[r_left];
    CellId k_1                 = node_cell_left[0];
    int row_lN_1               = transferFunction(0, k_1);

    Hess_J(row_lN, row_lN_1) = -1. / rhocr[r_left];

    // first line right block
    int row_r0  = transferFunction(1, j);
    int row_r01 = transferFunction(1, j1);

    Hess_J(row_r0, row_r01) = -rhocr[r_right];

    // last line right block
    int row_rN  = transferFunction(1, k);
    int row_rN1 = transferFunction(1, k_1);

    Hess_J(row_rN, row_rN1) = -rhocr[r_left];

    // Diagonal entries of the cells attached to boundary nodes.
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list = bc.nodeList();

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

            const NodeValuePerCell<const Rd>& Cjr = mesh_data.Cjr();

            const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id                      = node_list[i_node];
              const auto& node_cell                      = node_to_cell_matrix[node_id];
              CellId j0                                  = node_cell[0];
              const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
              int row_r                                  = transferFunction(1, j0);
              int row_l                                  = transferFunction(0, j0);

              // Velocity condition: the extra 2 * rho * c term goes to the
              // block-1 diagonal; only one rhocr weight is kept, taken at the
              // first or second node of the cell depending on the sign of Cjr.
              if (Cjr(j0, node_local_number_in_its_cells[0])[0] > 0) {
                const auto& j_nodes = cell_to_node_matrix[j0];
                NodeId r_left0      = j_nodes[0];

                // first left block
                Hess_J(row_l, row_l) =
                  (2 * Mj[j0] / (dt * a * a)) + (2 * lambda * rho[j0] / (a * a)) + 1. / rhocr[r_left0];

                // second right block
                Hess_J(row_r, row_r) = (2 * Mj[j0] / dt) + rhocr[r_left0] + 2 * rho[j0] * c[j0];

              } else {
                const auto& j_nodes = cell_to_node_matrix[j0];
                NodeId r_right0     = j_nodes[1];

                // first left block
                Hess_J(row_l, row_l) =
                  (2 * Mj[j0] / (dt * a * a)) + (2 * lambda * rho[j0] / (a * a)) + 1. / rhocr[r_right0];

                // second right block
                Hess_J(row_r, row_r) = (2 * Mj[j0] / dt) + rhocr[r_right0] + 2 * rho[j0] * c[j0];
              }
            }
          } else if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
            // The items returned by faceList() are used as node identifiers below.
            const auto& node_list = bc.faceList();

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(m_mesh);

            const NodeValuePerCell<const Rd>& Cjr = mesh_data.Cjr();

            const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id                      = node_list[i_node];
              const auto& node_cell                      = node_to_cell_matrix[node_id];
              CellId j0                                  = node_cell[0];
              const auto& node_local_number_in_its_cells = node_local_numbers_in_their_cells.itemArray(node_id);
              int row_r                                  = transferFunction(1, j0);
              int row_l                                  = transferFunction(0, j0);

              // Pressure condition: the extra 2 / (rho * c) term goes to the
              // block-0 diagonal instead.
              if (Cjr(j0, node_local_number_in_its_cells[0])[0] > 0) {
                const auto& cell_node0 = cell_to_node_matrix[j0];
                NodeId r_left0         = cell_node0[0];

                // first left block
                Hess_J(row_l, row_l) = (2 * Mj[j0] / (dt * a * a)) + (2 * lambda * rho[j0] / (a * a)) +
                                       1. / rhocr[r_left0] + 2. / (rho[j0] * c[j0]);

                // second right block
                Hess_J(row_r, row_r) = (2 * Mj[j0] / dt) + rhocr[r_left0];
              } else {
                const auto& cell_node0 = cell_to_node_matrix[j0];
                NodeId r_right0        = cell_node0[1];

                // first left block
                Hess_J(row_l, row_l) = (2 * Mj[j0] / (dt * a * a)) + (2 * lambda * rho[j0] / (a * a)) +
                                       1. / rhocr[r_right0] + 2. / (rho[j0] * c[j0]);

                // second right block
                Hess_J(row_r, row_r) = (2 * Mj[j0] / dt) + rhocr[r_right0];
              }
            }
          }
        },
        boundary_condition);
    }
    return Hess_J;
  }

  /**
   * Builds the matrix Hess_J - A, i.e. the gradient of F = grad_J - A * U
   * with respect to the stacked unknown.
   *
   * @param A      coupling matrix descriptor (see _getA)
   * @param rhocr  node-wise values forwarded to _getHessianJ
   * @param rho    cell-wise field forwarded to _getHessianJ
   * @param c      cell-wise field forwarded to _getHessianJ
   * @param a      scalar forwarded to _getHessianJ
   * @param lambda scalar forwarded to _getHessianJ
   * @param dt     scalar forwarded to _getHessianJ
   * @return the CRS matrix Hess_J - A
   */
  CRSMatrix<double>
  _getGradientF(const CRSMatrixDescriptor<double>& A,
                const NodeValue<const double>& rhocr,
                const DiscreteScalarFunction& rho,
                const DiscreteScalarFunction& c,
                const double a,
                const double lambda,
                const double dt)
  {
    const CRSMatrixDescriptor<double> hessian_descriptor = this->_getHessianJ(rho, dt, c, a, lambda, rhocr);

    // Both operands are converted from descriptors to CRS matrices before
    // being subtracted.
    auto hessian_matrix  = hessian_descriptor.getCRSMatrix();
    auto coupling_matrix = A.getCRSMatrix();

    return hessian_matrix - coupling_matrix;
  }

  /**
   * Translates boundary condition descriptors into the solver's own boundary
   * condition objects.
   *
   * Accepted descriptors are symmetry conditions and Dirichlet conditions
   * named "velocity" or "pressure". Boundary values are interpolated from the
   * descriptor's right-hand-side function: at the boundary nodes for velocity
   * (and for pressure when Dimension == 1), at the boundary faces (using
   * mesh_data.xl()) for pressure otherwise.
   *
   * @param mesh               mesh on which the boundaries are looked up
   * @param bc_descriptor_list descriptors to translate, processed in order
   * @return the boundary conditions, in the order of the descriptors
   * @throws NormalError if a descriptor is of any other kind
   */
  BoundaryConditionList
  _getBCList(const std::shared_ptr<const MeshType>& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    BoundaryConditionList boundary_conditions;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      // Reports a descriptor that this solver cannot use.
      const auto reject_descriptor = [&]() {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for acoustic solver";
        throw NormalError(error_msg.str());
      };

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::symmetry: {
        const auto& symmetry_descriptor = dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);

        boundary_conditions.push_back(
          SymmetryBoundaryCondition{getMeshFlatNodeBoundary(*mesh, symmetry_descriptor.boundaryDescriptor())});
        break;
      }
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const auto& dirichlet_descriptor = dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);

        if (dirichlet_descriptor.name() == "velocity") {
          MeshNodeBoundary node_boundary = getMeshNodeBoundary(*mesh, dirichlet_descriptor.boundaryDescriptor());

          // Velocity values are evaluated at the boundary node positions.
          Array<const Rd> velocity_values =
            InterpolateItemValue<Rd(Rd)>::template interpolate<ItemType::node>(dirichlet_descriptor.rhsSymbolId(),
                                                                              mesh->xr(), node_boundary.nodeList());

          boundary_conditions.push_back(VelocityBoundaryCondition{node_boundary.nodeList(), velocity_values});
        } else if (dirichlet_descriptor.name() == "pressure") {
          const FunctionSymbolId pressure_id = dirichlet_descriptor.rhsSymbolId();

          if constexpr (Dimension == 1) {
            // In 1D the pressure is evaluated at the boundary nodes.
            MeshNodeBoundary node_boundary = getMeshNodeBoundary(*mesh, bc_descriptor->boundaryDescriptor());

            Array<const double> pressure_values =
              InterpolateItemValue<double(Rd)>::template interpolate<ItemType::node>(pressure_id, mesh->xr(),
                                                                                     node_boundary.nodeList());

            boundary_conditions.emplace_back(PressureBoundaryCondition{node_boundary.nodeList(), pressure_values});
          } else {
            // Otherwise it is evaluated at the positions given by mesh_data.xl()
            // for the boundary faces.
            MeshFaceBoundary face_boundary = getMeshFaceBoundary(*mesh, bc_descriptor->boundaryDescriptor());

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

            Array<const double> pressure_values =
              InterpolateItemValue<double(Rd)>::template interpolate<ItemType::face>(pressure_id, mesh_data.xl(),
                                                                                     face_boundary.faceList());

            boundary_conditions.emplace_back(PressureBoundaryCondition{face_boundary.faceList(), pressure_values});
          }
        } else {
          reject_descriptor();
        }
        break;
      }
      default: {
        reject_descriptor();
      }
      }
    }

    return boundary_conditions;
  }

  /**
   * Builds the solver and computes the implicit state used by apply().
   *
   * After the boundary condition list and the mesh reference are set, the
   * constructor iterates on Uk, starting from Un = _getU(p, u): at each step
   * the linear system gradient_f * sol = f is solved and Uk is replaced by
   * Uk - sol. The iterations stop as soon as the infinity norm of sol is not
   * larger than 1e-14 times the infinity norm of Un, or after 10000
   * iterations (no error is raised in that case). The last iterate is stored
   * in m_U and the node values returned by _getRhoCr in m_rhocr.
   *
   * @param p_mesh mesh of the computation (m_mesh keeps a reference to *p_mesh)
   * @param rho density
   * @param c sound speed
   * @param u velocity
   * @param p pressure
   * @param bc_descriptor_list boundary condition descriptors, see _getBCList
   * @param dt time step
   * @throws NormalError if a boundary condition descriptor is not supported
   */
  RelaxedImplicitAcousticSolver(
    const std::shared_ptr<const MeshType>& p_mesh,
    const DiscreteScalarFunction& rho,
    const DiscreteScalarFunction& c,
    const DiscreteVectorFunction& u,
    const DiscreteScalarFunction& p,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
    const double& dt)
    : m_boundary_condition_list{this->_getBCList(p_mesh, bc_descriptor_list)}, m_mesh{*p_mesh}, m_U{0}
  {
    // stopping criteria of the iterations
    constexpr double relative_tolerance = 1e-14;
    constexpr int max_iterations        = 10000;

    NodeValue<const double> rhocr = this->_getRhoCr(rho, c);

    // largest value of rho*c over the cells
    double max_rhoc = 0;
    for (CellId j = 0; j < m_mesh.numberOfCells(); ++j) {
      const double rhoc_j = rho[j] * c[j];
      if (max_rhoc < rhoc_j) {
        max_rhoc = rhoc_j;
      }
    }

    double lambda = 1000;
    double a      = 1.1 * max_rhoc;

    const CRSMatrixDescriptor<double>& A = this->_getA();

    Vector<double> Un = this->_getU(p, u);
    Vector<double> Uk = copy(Un);

    // infinity norm of a vector
    const auto infinity_norm = [](const Vector<double>& v) {
      Array<double> abs_values{v.size()};
      parallel_for(
        v.size(), PUGS_LAMBDA(size_t i) { abs_values[i] = std::abs(v[i]); });
      Array<const double> const_abs_values = abs_values;
      return max(const_abs_values);
    };

    const double norm_inf_Un = infinity_norm(Un);

    // None of the arguments changes during the iterations: the matrix is
    // assembled once instead of once per iteration.
    CRSMatrix<double> gradient_f = this->_getGradientF(A, rhocr, rho, c, a, lambda, dt);

    int nb_iter = 0;
    double norm_inf_sol;

    do {
      std::cout << "iteration=" << nb_iter << '\n';
      nb_iter++;

      Vector<double> f = this->_getF(A, Un, Uk, rhocr, rho, c, p, a, lambda, dt);

      // The unknown is zeroed so that the linear solver never reads
      // uninitialized values (e.g. if it uses it as an initial guess).
      Vector<double> sol{Un.size()};
      sol = zero;

      LinearSolver solver;
      solver.solveLocalSystem(gradient_f, sol, f);

      Vector<double> U_next = Uk - sol;

      norm_inf_sol = infinity_norm(sol);

      Uk = U_next;

      std::cout << "ratio" << norm_inf_sol / norm_inf_Un << "\n";

    } while ((norm_inf_sol > relative_tolerance * norm_inf_Un) and (nb_iter < max_iterations));

    m_U     = Uk;
    m_rhocr = rhocr;
  }

 public:
  /**
   * Advances the mesh and the cell unknowns by one time step, using the
   * state m_U and the node values m_rhocr computed at construction.
   *
   * Steps:
   *  1. node velocities ur are assembled from m_U, m_rhocr and Cjr, then
   *     their first component is overwritten on velocity and pressure
   *     boundaries (symmetry boundaries are left untouched here);
   *  2. node fluxes Fjr are computed for each (cell, node) pair;
   *  3. nodes are moved by dt * ur and a new mesh is built on the same
   *     connectivity;
   *  4. u and E are updated with the summed fluxes divided by the cell mass
   *     rho * Vj, and rho is rescaled by the volume ratio Vj / new_Vj.
   *
   * NOTE(review): transferFunction(0, j) and transferFunction(1, j) select
   * two entries of m_U for cell j — presumably the pressure-like and the
   * velocity-like unknowns; confirm against _getU.
   *
   * @param dt time step
   * @param mesh mesh at the beginning of the time step
   * @param rho density
   * @param u velocity
   * @param E total energy
   * @return the new mesh followed by the new rho, u and E
   */
  std::tuple<std::shared_ptr<const MeshVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt,
        const std::shared_ptr<const MeshType>& mesh,
        const DiscreteScalarFunction& rho,
        const DiscreteVectorFunction& u,
        const DiscreteScalarFunction& E) const
  {
    MeshDataType& mesh_data              = MeshDataManager::instance().getMeshData(*mesh);
    const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();

    const auto& node_to_cell_matrix               = m_mesh.connectivity().nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells = m_mesh.connectivity().nodeLocalNumbersInTheirCells();
    const auto& cell_to_node_matrix               = m_mesh.connectivity().cellToNodeMatrix();

    // ---- node velocities ----
    const NodeValue<Rd> node_velocity(m_mesh.connectivity());
    parallel_for(
      m_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId r) {
        const auto& cells_of_node        = node_to_cell_matrix[r];
        const auto& local_number_in_cell = node_local_numbers_in_their_cells.itemArray(r);

        Rd velocity_sum = zero;
        for (size_t i_cell = 0; i_cell < cells_of_node.size(); ++i_cell) {
          const CellId J       = cells_of_node[i_cell];
          const unsigned int R = local_number_in_cell[i_cell];
          velocity_sum +=
            0.5 * Rd{m_U[transferFunction(1, J)]} - (0.5 / m_rhocr[r]) * (-m_U[transferFunction(0, J)] * Cjr(J, R));
        }
        node_velocity[r] = velocity_sum;
      });

    // boundary conditions: only the first velocity component is imposed
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, BCType>) {
            const auto& boundary_nodes  = bc.nodeList();
            const auto& imposed_velocity = bc.valueList();

            for (size_t i_node = 0; i_node < boundary_nodes.size(); ++i_node) {
              node_velocity[boundary_nodes[i_node]][0] = imposed_velocity[i_node][0];
            }
          } else if constexpr (std::is_same_v<PressureBoundaryCondition, BCType>) {
            const auto& boundary_nodes   = bc.faceList();
            const auto& imposed_pressure = bc.valueList();

            for (size_t i_node = 0; i_node < boundary_nodes.size(); ++i_node) {
              const NodeId& node_id            = boundary_nodes[i_node];
              const CellId j                   = node_to_cell_matrix[node_id][0];
              const auto& local_number_in_cell = node_local_numbers_in_their_cells.itemArray(node_id);

              node_velocity[node_id][0] =
                m_U[transferFunction(1, j)] -
                Cjr(j, local_number_in_cell[0])[0] *
                  ((imposed_pressure[i_node] + m_U[transferFunction(0, j)]) / m_rhocr[node_id]);
            }
          }
        },
        boundary_condition);
    }

    const NodeValue<const Rd> ur = node_velocity;

    std::cout << "u_j+1/2" << '\n';

    // ---- node fluxes ----
    const NodeValuePerCell<Rd> Fjr(m_mesh.connectivity());
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          const NodeId r = cell_nodes[R];
          Fjr(j, R) =
            m_rhocr[r] * (Rd{m_U[transferFunction(1, j)]} - ur[r]) + (m_U[transferFunction(0, j)] * Cjr(j, R));
        }
      });

    std::cout << "p_j+1/2" << '\n';

    // ---- mesh displacement ----
    NodeValue<Rd> new_xr = copy(mesh->xr());
    parallel_for(
      mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { new_xr[r] += dt * ur[r]; });

    std::shared_ptr<const MeshType> new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_xr);

    // ---- momentum and total energy ----
    CellValue<const double> Vj = mesh_data.Vj();

    CellValue<Rd> new_u     = copy(u.cellValues());
    CellValue<double> new_E = copy(E.cellValues());

    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];

        Rd momentum_fluxes   = zero;
        double energy_fluxes = 0;
        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          const NodeId r = cell_nodes[R];
          momentum_fluxes += Fjr(j, R);
          energy_fluxes += dot(Fjr(j, R), ur[r]);
        }

        const double dt_over_Mj = dt / (rho[j] * Vj[j]);
        new_u[j] -= dt_over_Mj * momentum_fluxes;
        new_E[j] -= dt_over_Mj * energy_fluxes;
      });

    // ---- density: the cell mass is kept, the volume changes ----
    CellValue<const double> new_Vj = MeshDataManager::instance().getMeshData(*new_mesh).Vj();

    CellValue<double> new_rho = copy(rho.cellValues());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { new_rho[j] *= Vj[j] / new_Vj[j]; });

    std::cout << "prochaine boucle" << '\n';

    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_rho}),
            std::make_shared<DiscreteFunctionVariant>(DiscreteVectorFunction{new_mesh, new_u}),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_E})};
  }

  /**
   * Variant-based entry point: checks the discrete functions and forwards
   * to the typed apply().
   *
   * @param dt time step
   * @param rho density
   * @param u velocity
   * @param E total energy
   * @return the new mesh followed by the new rho, u and E
   * @throws NormalError if the functions do not share a common mesh or are
   *         not all P0
   */
  std::tuple<std::shared_ptr<const MeshVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt,
        const std::shared_ptr<const DiscreteFunctionVariant>& rho,
        const std::shared_ptr<const DiscreteFunctionVariant>& u,
        const std::shared_ptr<const DiscreteFunctionVariant>& E) const
  {
    std::shared_ptr mesh_v = getCommonMesh({rho, u, E});

    const bool has_common_mesh = static_cast<bool>(mesh_v);
    if (not has_common_mesh) {
      throw NormalError("discrete functions are not defined on the same mesh");
    }

    const bool all_P0 = checkDiscretizationType({rho, u, E}, DiscreteFunctionType::P0);
    if (not all_P0) {
      throw NormalError("acoustic solver expects P0 functions");
    }

    const auto& rho_h = rho->get<DiscreteScalarFunction>();
    const auto& u_h   = u->get<DiscreteVectorFunction>();
    const auto& E_h   = E->get<DiscreteScalarFunction>();

    return this->apply(dt, mesh_v->get<MeshType>(), rho_h, u_h, E_h);
  }

  /**
   * Variant-based constructor: extracts the typed mesh and discrete
   * functions and delegates to the typed constructor.
   *
   * The mesh is retrieved as MeshType (the type expected by the delegated
   * constructor and used by apply()) rather than Mesh<Dimension>.
   *
   * @param mesh_v mesh of the computation
   * @param rho density
   * @param c sound speed
   * @param u velocity
   * @param p pressure
   * @param bc_descriptor_list boundary condition descriptors
   * @param dt time step
   */
  RelaxedImplicitAcousticSolver(
    const std::shared_ptr<const MeshVariant>& mesh_v,
    const std::shared_ptr<const DiscreteFunctionVariant>& rho,
    const std::shared_ptr<const DiscreteFunctionVariant>& c,
    const std::shared_ptr<const DiscreteFunctionVariant>& u,
    const std::shared_ptr<const DiscreteFunctionVariant>& p,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
    const double& dt)
    : RelaxedImplicitAcousticSolver(mesh_v->get<MeshType>(),
                                    rho->get<DiscreteScalarFunction>(),
                                    c->get<DiscreteScalarFunction>(),
                                    u->get<DiscreteVectorFunction>(),
                                    p->get<DiscreteScalarFunction>(),
                                    bc_descriptor_list,
                                    dt)
  {}

  // m_mesh is a reference member without default initializer, so the class
  // cannot be default constructed: a defaulted default constructor would be
  // implicitly deleted, this is now stated explicitly.
  RelaxedImplicitAcousticSolver()                                = delete;
  RelaxedImplicitAcousticSolver(RelaxedImplicitAcousticSolver&&) = default;
  ~RelaxedImplicitAcousticSolver()                               = default;
};

/**
 * Pressure boundary condition: a list of boundary faces and the pressure
 * value attached to each of them (same ordering in both lists).
 */
template <MeshConcept MeshType>
class RelaxedImplicitAcousticSolverHandler::RelaxedImplicitAcousticSolver<MeshType>::PressureBoundaryCondition
{
 public:
  PressureBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
    : m_value_list(value_list), m_face_list(face_list)
  {}

  ~PressureBoundaryCondition() = default;

  /// Pressure values, one per entry of faceList()
  const Array<const double>&
  valueList() const
  {
    return m_value_list;
  }

  /// Faces on which the pressure is imposed
  const Array<const FaceId>&
  faceList() const
  {
    return m_face_list;
  }

 private:
  const Array<const double> m_value_list;
  const Array<const FaceId> m_face_list;
};

/**
 * Pressure boundary condition specialized for Mesh<1>: the list returned by
 * faceList() contains node ids, and each of them has a pressure value
 * (same ordering in both lists).
 */
template <>
class RelaxedImplicitAcousticSolverHandler::RelaxedImplicitAcousticSolver<Mesh<1>>::PressureBoundaryCondition
{
 public:
  PressureBoundaryCondition(const Array<const NodeId>& face_list, const Array<const double>& value_list)
    : m_value_list(value_list), m_face_list(face_list)
  {}

  ~PressureBoundaryCondition() = default;

  /// Pressure values, one per entry of faceList()
  const Array<const double>&
  valueList() const
  {
    return m_value_list;
  }

  /// Nodes on which the pressure is imposed
  const Array<const NodeId>&
  faceList() const
  {
    return m_face_list;
  }

 private:
  const Array<const double> m_value_list;
  const Array<const NodeId> m_face_list;
};

/**
 * Velocity boundary condition: a list of boundary nodes and the velocity
 * vector attached to each of them (same ordering in both lists).
 */
template <MeshConcept MeshType>
class RelaxedImplicitAcousticSolverHandler::RelaxedImplicitAcousticSolver<MeshType>::VelocityBoundaryCondition
{
 public:
  VelocityBoundaryCondition(const Array<const NodeId>& node_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list(value_list), m_node_list(node_list)
  {}

  ~VelocityBoundaryCondition() = default;

  /// Velocity values, one per entry of nodeList()
  const Array<const TinyVector<Dimension>>&
  valueList() const
  {
    return m_value_list;
  }

  /// Nodes on which the velocity is imposed
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

 private:
  const Array<const TinyVector<Dimension>> m_value_list;
  const Array<const NodeId> m_node_list;
};

/**
 * Symmetry boundary condition: thin wrapper around a flat node boundary,
 * giving access to its node list and outgoing normal.
 */
template <MeshConcept MeshType>
class RelaxedImplicitAcousticSolverHandler::RelaxedImplicitAcousticSolver<MeshType>::SymmetryBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

  SymmetryBoundaryCondition(const MeshFlatNodeBoundary<MeshType>& mesh_flat_node_boundary)
    : m_mesh_flat_node_boundary(mesh_flat_node_boundary)
  {}

  ~SymmetryBoundaryCondition() = default;

  /// Nodes of the flat boundary
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_flat_node_boundary.nodeList();
  }

  /// Number of nodes of the flat boundary
  size_t
  numberOfNodes() const
  {
    return this->nodeList().size();
  }

  /// Outgoing normal of the flat boundary
  const Rd&
  outgoingNormal() const
  {
    return m_mesh_flat_node_boundary.outgoingNormal();
  }

 private:
  const MeshFlatNodeBoundary<MeshType> m_mesh_flat_node_boundary;
};

/**
 * Checks the discrete functions and builds the solver matching the mesh
 * type.
 *
 * Only Mesh<1> is supported: any other mesh type raises an UnexpectedError.
 *
 * @param rho density
 * @param c sound speed
 * @param u velocity
 * @param p pressure
 * @param bc_descriptor_list boundary condition descriptors
 * @param dt time step
 * @throws NormalError if the functions do not share a common mesh or are
 *         not all P0
 * @throws UnexpectedError if the mesh is not a Mesh<1>
 */
RelaxedImplicitAcousticSolverHandler::RelaxedImplicitAcousticSolverHandler(
  const std::shared_ptr<const DiscreteFunctionVariant>& rho,
  const std::shared_ptr<const DiscreteFunctionVariant>& c,
  const std::shared_ptr<const DiscreteFunctionVariant>& u,
  const std::shared_ptr<const DiscreteFunctionVariant>& p,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const double& dt)
{
  std::shared_ptr mesh_v = getCommonMesh({rho, c, u, p});

  const bool has_common_mesh = static_cast<bool>(mesh_v);
  if (not has_common_mesh) {
    throw NormalError("discrete functions are not defined on the same mesh");
  }

  const bool all_P0 = checkDiscretizationType({rho, c, u, p}, DiscreteFunctionType::P0);
  if (not all_P0) {
    throw NormalError("acoustic solver expects P0 functions");
  }

  // builds the solver for the concrete mesh type held by the variant
  const auto build_solver = [&](auto&& mesh) {
    using MeshType = mesh_type_t<decltype(mesh)>;
    if constexpr (not std::is_same_v<MeshType, Mesh<1>>) {
      throw UnexpectedError("invalid mesh dimension");
    } else {
      m_implicit_acoustic_solver =
        std::make_unique<RelaxedImplicitAcousticSolver<MeshType>>(mesh_v, rho, c, u, p, bc_descriptor_list, dt);
    }
  };

  std::visit(build_solver, mesh_v->variant());
}
