#include <scheme/VariationalSolver.hpp>

#include <language/utils/InterpolateItemArray.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/ItemValueVariant.hpp>
#include <mesh/MeshFaceBoundary.hpp>
#include <mesh/MeshFlatFaceBoundary.hpp>
#include <mesh/MeshFlatNodeBoundary.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <mesh/SubItemValuePerItemVariant.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionDPk.hpp>
#include <scheme/DiscreteFunctionDPkVariant.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/ExternalBoundaryConditionDescriptor.hpp>
#include <scheme/FixedBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/PolynomialReconstruction.hpp>
#include <scheme/PolynomialReconstructionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <analysis/GaussQuadratureDescriptor.hpp>
#include <analysis/QuadratureManager.hpp>
#include <geometry/LineCubicTransformation.hpp>
#include <geometry/LineParabolicTransformation.hpp>

#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/Vector.hpp>

#include <functional>
#include <variant>
#include <vector>

#warning REMOVE WHEN FLUXES ARE PASSED TO THE LANGUAGE
#include <analysis/GaussLegendreQuadratureDescriptor.hpp>
#include <language/utils/IntegrateOnCells.hpp>

// Gives access to the geometric transformation associated with the faces of a
// mesh, for a given mesh type and mesh degree.
//
// The primary template is empty: only the specializations defined in this file
// provide a constructor taking the mesh and a `getTransformation(FaceId)`
// member.
template <MeshConcept MeshType, size_t mesh_degree>
struct LineTransformationAccessor
{
};

// Specialization for 2D `Mesh<2>` with degree 1: each face is described by the
// `LineTransformation<2>` built from the positions of its two nodes.
template <>
struct LineTransformationAccessor<Mesh<2>, 1>
{
  // Node positions of the mesh
  const NodeValue<const TinyVector<2>> m_xr;

  // Face -> node connectivity
  const ItemToItemMatrix<ItemType::face, ItemType::node> m_face_to_node_matrix;

  LineTransformationAccessor(const Mesh<2>& mesh)
    : m_xr{mesh.xr()}, m_face_to_node_matrix(mesh.connectivity().faceToNodeMatrix())
  {}

  // Returns the transformation built from the two nodes of face `face_id`
  auto
  getTransformation(const FaceId face_id) const
  {
    auto face_nodes = m_face_to_node_matrix[face_id];

    const auto& x_first = m_xr[face_nodes[0]];
    const auto& x_last  = m_xr[face_nodes[1]];

    return LineTransformation<2>{x_first, x_last};
  }
};

// Specialization for 2D `PolynomialMesh<2>`: the kind of transformation that is
// returned is selected at compile time from `mesh_degree` (1, 2 or 3). Any
// other degree is rejected by a static assertion when `getTransformation` is
// instantiated.
template <size_t mesh_degree>
struct LineTransformationAccessor<PolynomialMesh<2>, mesh_degree>
{
  // Node positions of the mesh
  const NodeValue<const TinyVector<2>> m_xr;
  // Per-face arrays of additional points read from `mesh.xl()`
  const FaceArray<const TinyVector<2>> m_xl;

  // Face -> node connectivity
  const ItemToItemMatrix<ItemType::face, ItemType::node> m_face_to_node_matrix;

  LineTransformationAccessor(const PolynomialMesh<2>& mesh)
    : m_xr{mesh.xr()}, m_xl{mesh.xl()}, m_face_to_node_matrix(mesh.connectivity().faceToNodeMatrix())
  {}

  // Returns the transformation of face `face_id`: built from its two nodes
  // (degree 1), its two nodes and one face point (degree 2) or its two nodes
  // and two face points (degree 3).
  auto
  getTransformation(const FaceId face_id) const
  {
    auto face_nodes = m_face_to_node_matrix[face_id];

    const auto& x_first = m_xr[face_nodes[0]];
    const auto& x_last  = m_xr[face_nodes[1]];

    if constexpr (mesh_degree == 1) {
      return LineTransformation<2>{x_first, x_last};
    } else if constexpr (mesh_degree == 2) {
      return LineParabolicTransformation<2>{x_first, m_xl[face_id][0], x_last};
    } else if constexpr (mesh_degree == 3) {
      return LineCubicTransformation<2>{x_first, m_xl[face_id][0], m_xl[face_id][1], x_last};
    } else {
      // the condition is always false in this branch: it only depends on the
      // template parameter so that the assertion fires at instantiation
      static_assert(mesh_degree == 1, "mesh degree is not supported");
    }
  }
};

template <size_t mesh_degree, MeshConcept MeshType>
double
variational_acoustic_dt(const DiscreteFunctionP0<const double>& c, const std::shared_ptr<const MeshType>& p_mesh)
{
  const auto& mesh = *p_mesh;
  const auto Vj    = MeshDataManager::instance().getMeshData(mesh).Vj();

  auto cell_to_face_matrix = mesh.connectivity().cellToFaceMatrix();

  auto qf = QuadratureManager::instance().getLineFormula(GaussQuadratureDescriptor(2));

  CellValue<double> Sj{mesh.connectivity()};
  parallel_for(
    mesh.numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
      LineTransformationAccessor<MeshType, mesh_degree> transform(*p_mesh);

      double sum     = 0;
      auto face_list = cell_to_face_matrix[cell_id];
      for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        const FaceId face_id = face_list[i_face];

        auto t = transform.getTransformation(face_id);

        for (size_t iq = 0; iq < qf.numberOfPoints(); ++iq) {
          sum += qf.weight(iq) * t.velocityNorm(qf.point(iq));
        }
      }
      Sj[cell_id] = sum;
    });

  CellValue<double> local_dt{mesh.connectivity()};
  parallel_for(
    mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { local_dt[j] = 2 * Vj[j] / (Sj[j] * c[j]); });

  return min(local_dt);
}

// Computes the acoustic time step associated with the P0 discrete function
// stored in `c_v`.
//
// The mesh variant of the function is visited and the computation is forwarded
// to the templated overload:
//  - polygonal meshes: handled in dimension 2 only (degree 1 faces),
//  - polynomial meshes: the run-time mesh degree (1, 2 or 3) selects the
//    compile-time degree.
//
// Throws NotImplementedError for polygonal meshes of dimension != 2 and for
// polynomial meshes whose degree is not 1, 2 or 3. Throws NormalError for any
// other mesh type.
double
variational_acoustic_dt(const std::shared_ptr<const DiscreteFunctionVariant>& c_v)
{
  const auto& sound_speed = c_v->get<DiscreteFunctionP0<const double>>();

  auto compute_dt = [&](auto&& p_mesh) -> double {
    const auto& mesh = *p_mesh;
    using MeshType   = mesh_type_t<decltype(mesh)>;

    if constexpr (is_polygonal_mesh_v<MeshType>) {
      if constexpr (MeshType::Dimension != 2) {
        throw NotImplementedError("not implemented in dimension d != 2");
      } else {
        return variational_acoustic_dt<1>(sound_speed, p_mesh);
      }
    } else if constexpr (is_polynomial_mesh_v<MeshType>) {
      // map the run-time mesh degree to the compile-time template argument
      const auto mesh_degree = mesh.degree();
      if (mesh_degree == 1) {
        return variational_acoustic_dt<1>(sound_speed, p_mesh);
      }
      if (mesh_degree == 2) {
        return variational_acoustic_dt<2>(sound_speed, p_mesh);
      }
      if (mesh_degree == 3) {
        return variational_acoustic_dt<3>(sound_speed, p_mesh);
      }
      throw NotImplementedError("not implemented for mesh degree > 3");
    } else {
      throw NormalError("unexpected mesh type");
    }
  };

  return std::visit(compute_dt, sound_speed.meshVariant()->variant());
}

template <MeshConcept MeshTypeT>
class VariationalSolverHandler::VariationalSolver final : public VariationalSolverHandler::IVariationalSolver
{
 private:
  using MeshType     = MeshTypeT;
  using MeshDataType = MeshData<MeshType>;

  constexpr static size_t Dimension = MeshType::Dimension;

  using Rdxd = TinyMatrix<Dimension>;
  using Rd   = TinyVector<Dimension>;

  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  class FixedBoundaryCondition;
  class PressureBoundaryCondition;
  class SymmetryBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition = std::
    variant<FixedBoundaryCondition, PressureBoundaryCondition, SymmetryBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  mutable FaceArray<const size_t> m_face_indices;
  mutable NodeValue<const size_t> m_node_index;

  const size_t m_quadrature_degree;

  // Tells whether `boundary_descriptor` matches one of the referenced face lists
  // of the mesh connectivity.
  //
  // \return true if a matching face list is found, false otherwise
  //
  // Throws NormalError if the matching face list is not of boundary type (inner
  // faces cannot define a mesh boundary).
  bool
  hasFaceBoundary(const MeshType& mesh, const IBoundaryDescriptor& boundary_descriptor) const
  {
    for (size_t i_ref_face_list = 0;
         i_ref_face_list < mesh.connectivity().template numberOfRefItemList<ItemType::face>(); ++i_ref_face_list) {
      const auto& ref_face_list = mesh.connectivity().template refItemList<ItemType::face>(i_ref_face_list);
      const RefId& ref          = ref_face_list.refId();

      if (ref == boundary_descriptor) {
        if (ref_face_list.type() != RefItemListBase::Type::boundary) {
          std::ostringstream ost;
          ost << "invalid boundary " << rang::fgB::yellow << boundary_descriptor << rang::style::reset
              << ": inner faces cannot be used to define mesh boundaries";
          throw NormalError(ost.str());
        }

        return true;
      }
    }
    return false;
  }

  // Builds the list of boundary conditions handled by the solver from their
  // descriptors.
  //
  // Supported descriptors:
  //  - symmetry: built from the flat node and flat face boundaries,
  //  - fixed: built from the node boundary, and from the face boundary when the
  //    descriptor also matches a boundary face list,
  //  - dirichlet named "velocity": the right-hand-side function is interpolated
  //    at boundary nodes, and at the face points given by `mesh.xl()` when the
  //    descriptor also matches a boundary face list,
  //  - dirichlet named "pressure": the right-hand-side function is evaluated at
  //    the quadrature points (Gauss formula of degree `m_quadrature_degree`) of
  //    each boundary face.
  //
  // \param mesh               mesh on which boundary conditions are defined
  // \param bc_descriptor_list descriptors of the boundary conditions
  // \return the list of boundary conditions, in the order of the descriptors
  //
  // Throws NormalError for any other descriptor (and through hasFaceBoundary if
  // inner faces are used to define a boundary).
  BoundaryConditionList
  _getBCList(const MeshType& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list) const
  {
    BoundaryConditionList bc_list;
    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::symmetry: {
        bc_list.emplace_back(
          SymmetryBoundaryCondition(getMeshFlatNodeBoundary(mesh, bc_descriptor->boundaryDescriptor()),
                                    getMeshFlatFaceBoundary(mesh, bc_descriptor->boundaryDescriptor())));
        break;
      }
      case IBoundaryConditionDescriptor::Type::fixed: {
        if (hasFaceBoundary(mesh, bc_descriptor->boundaryDescriptor())) {
          bc_list.emplace_back(FixedBoundaryCondition(getMeshNodeBoundary(mesh, bc_descriptor->boundaryDescriptor()),
                                                      getMeshFaceBoundary(mesh, bc_descriptor->boundaryDescriptor())));
        } else {
          bc_list.emplace_back(FixedBoundaryCondition(getMeshNodeBoundary(mesh, bc_descriptor->boundaryDescriptor())));
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "velocity") {
          MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          Array<const Rd> node_value_list =
            InterpolateItemValue<Rd(Rd)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                               mesh.xr(),
                                                                               mesh_node_boundary.nodeList());

          if (hasFaceBoundary(mesh, bc_descriptor->boundaryDescriptor())) {
            MeshFaceBoundary mesh_face_boundary = getMeshFaceBoundary(mesh, bc_descriptor->boundaryDescriptor());

            Table<const Rd> face_inner_node_value_list =
              InterpolateItemArray<Rd(Rd)>::template interpolate<ItemType::face>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                                 mesh.xl(),
                                                                                 mesh_face_boundary.faceList());

            bc_list.emplace_back(VelocityBoundaryCondition{mesh_node_boundary, node_value_list,   //
                                                           mesh_face_boundary, face_inner_node_value_list});
          } else {
            bc_list.emplace_back(VelocityBoundaryCondition{mesh_node_boundary, node_value_list});
          }
        } else if (dirichlet_bc_descriptor.name() == "pressure") {
          const FunctionSymbolId pressure_id = dirichlet_bc_descriptor.rhsSymbolId();

          MeshFaceBoundary mesh_face_boundary = getMeshFaceBoundary(mesh, bc_descriptor->boundaryDescriptor());
          auto face_list                      = mesh_face_boundary.faceList();

          auto face_to_node_matrix = mesh.connectivity().faceToNodeMatrix();

          const QuadratureFormula<1> qf =
            QuadratureManager::instance().getLineFormula(GaussQuadratureDescriptor(m_quadrature_degree));

          Table<TinyVector<Dimension>> quadrature_points(mesh_face_boundary.faceList().size(), qf.numberOfPoints());

          const auto xr = mesh.xr();
          const auto xl = mesh.xl();

          for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
            const FaceId face_id = face_list[i_face];

            const Rd& x0 = xr[face_to_node_matrix[face_id][0]];
            const Rd& x2 = xr[face_to_node_matrix[face_id][1]];

            // Faces of a degree 1 mesh carry no inner point in `xl` (it holds
            // `degree - 1` points per face): reading `xl[face_id][0]` would be
            // out of bounds. The chord midpoint is used instead, which is the
            // geometry used by _computeFluxes for such meshes.
            // NOTE(review): for degree > 2 the first inner point is still used
            // as mid-point of a parabola (unchanged behavior) -- confirm that
            // this is the intended approximation.
            const Rd x1 = [&]() -> Rd {
              if (mesh.degree() < 2) {
                return 0.5 * (x0 + x2);
              } else {
                return xl[face_id][0];
              }
            }();

            const LineParabolicTransformation<Dimension> t(x0, x1, x2);

            auto face_quadrature_points = quadrature_points[i_face];
            for (size_t i_xi = 0; i_xi < qf.numberOfPoints(); ++i_xi) {
              face_quadrature_points[i_xi] = t(qf.point(i_xi));
            }
          }

          Table<double> pressure_at_quadrature_points(mesh_face_boundary.faceList().size(), qf.numberOfPoints());
          EvaluateAtPoints<double(const Rd)>::evaluateTo(pressure_id, quadrature_points, pressure_at_quadrature_points);

          bc_list.emplace_back(PressureBoundaryCondition{mesh_face_boundary, pressure_at_quadrature_points});

        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for acoustic solver";
        throw NormalError(error_msg.str());
      }
    }

    return bc_list;
  }

  void _applyPressureBC(const BoundaryConditionList& bc_list, const MeshType& mesh, Vector<double>& b) const;
  void _applySymmetryBC(const BoundaryConditionList& bc_list,
                        const MeshType& mesh,
                        CRSMatrixDescriptor<double>& A,
                        Vector<double>& b) const;
  void _applyVelocityBC(const BoundaryConditionList& bc_list,
                        const MeshType& mesh,
                        const VariationalSolverHandler::VelocityBCTreatment& velocity_bc_treatment,
                        CRSMatrixDescriptor<double>& A,
                        Vector<double>& b) const;

  // Applies all boundary conditions of `bc_list` to the linear system (A, b).
  //
  // Pressure conditions are applied first (they only modify the right-hand
  // side), then symmetry conditions, then velocity conditions using the
  // requested `velocity_bc_treatment`.
  void
  _applyBoundaryConditions(const BoundaryConditionList& bc_list,
                           const MeshType& mesh,
                           const VariationalSolverHandler::VelocityBCTreatment& velocity_bc_treatment,
                           CRSMatrixDescriptor<double>& A,
                           Vector<double>& b) const
  {
    // right-hand side only
    _applyPressureBC(bc_list, mesh, b);

    // matrix and right-hand side
    _applySymmetryBC(bc_list, mesh, A, b);
    _applyVelocityBC(bc_list, mesh, velocity_bc_treatment, A, b);
  }

  void _forcebcSymmetryBC(const BoundaryConditionList& bc_list, const MeshType& mesh, Vector<double>& U) const;

  void _forcebcVelocityBC(const BoundaryConditionList& bc_list, const MeshType& mesh, Vector<double>& U) const;

  void
  _forcebcBoundaryConditions(const BoundaryConditionList& bc_list, const MeshType& mesh, Vector<double>& U) const
  {
    this->_forcebcSymmetryBC(bc_list, mesh, U);
    this->_forcebcVelocityBC(bc_list, mesh, U);
  }

  // Computes node/face velocities and cell fluxes for one evaluation of the
  // scheme.
  //
  // Steps performed:
  //  1. rho, rho*u and rho*E are reconstructed as polynomials of degree `degree`
  //     (symmetry boundaries are passed to the reconstruction),
  //  2. the unknowns are numbered (mesh nodes first, then inner face points);
  //     the numbering is stored in the mutable members m_face_indices and
  //     m_node_index,
  //  3. the matrix A and right-hand side b are assembled face by face using a
  //     Gauss quadrature of degree m_quadrature_degree and P1 or P2 shape
  //     functions according to the mesh degree,
  //  4. boundary conditions are applied, the system is solved and boundary
  //     conditions are forced on the solution U,
  //  5. momentum and energy fluxes are accumulated for each cell; if
  //     `function_id` is set, its integral over each cell is subtracted from the
  //     energy flux,
  //  6. U is scattered to node velocities and inner face point velocities.
  //
  // \param degree                degree of the polynomial reconstruction
  // \param velocity_bc_treatment forwarded to velocity boundary conditions
  // \param rho_v,u_v,E_v         P0 density, velocity and total energy
  // \param rhoc, a               cell-wise coefficients used in the face matrices
  // \param bc_descriptor_list    boundary condition descriptors
  // \param function_id           optional source term function
  // \return (node velocities, inner face point velocities, momentum fluxes,
  //          energy fluxes)
  //
  // Throws NotImplementedError if the mesh degree is neither 1 nor 2.
  //
  // NOTE(review): the normal `Nl{Tl[1], -Tl[0]}` is built from two components,
  // so this code presumably expects Dimension == 2 -- confirm.
  std::tuple<NodeValue<const Rd>, FaceArray<const Rd>, CellValue<const Rd>, CellValue<const double>>
  _computeFluxes(const size_t& degree,
                 const VelocityBCTreatment& velocity_bc_treatment,
                 const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
                 const std::shared_ptr<const DiscreteFunctionVariant>& u_v,
                 const std::shared_ptr<const DiscreteFunctionVariant>& E_v,
                 const CellValue<const double>& rhoc,
                 const CellValue<const double>& a,
                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                 std::optional<FunctionSymbolId> function_id) const
  {
    std::shared_ptr mesh = getCommonMesh({rho_v, u_v, E_v})->get<MeshType>();

    auto xr = mesh->xr();

    DiscreteScalarFunction rho = rho_v->get<DiscreteScalarFunction>();
    DiscreteVectorFunction u   = u_v->get<DiscreteVectorFunction>();
    DiscreteScalarFunction E   = E_v->get<DiscreteScalarFunction>();

    std::vector<std::shared_ptr<const IBoundaryDescriptor>> symmetry_boundary_descriptor_list;

    for (auto&& bc_descriptor : bc_descriptor_list) {
      if (bc_descriptor->type() == IBoundaryConditionDescriptor::Type::symmetry) {
        symmetry_boundary_descriptor_list.push_back(bc_descriptor->boundaryDescriptor_shared());
      }
    }

    // Reconstruction of the conservative variables
    PolynomialReconstructionDescriptor reconstruction_descriptor(IntegrationMethodType::boundary, degree,
                                                                 symmetry_boundary_descriptor_list);
    auto reconstructions = PolynomialReconstruction{reconstruction_descriptor}.build(rho, rho * u, rho * E);

    DiscreteFunctionDPk rho_bar   = reconstructions[0]->template get<DiscreteFunctionDPk<Dimension, const double>>();
    DiscreteFunctionDPk rho_u_bar = reconstructions[1]->template get<DiscreteFunctionDPk<Dimension, const Rd>>();
    DiscreteFunctionDPk rho_E_bar = reconstructions[2]->template get<DiscreteFunctionDPk<Dimension, const double>>();

    // Reconstructed velocity at point x of a cell
    auto u_bar = [&rho_bar, &rho_u_bar](const CellId cell_id, const Rd& x) {
      return 1. / rho_bar[cell_id](x) * rho_u_bar[cell_id](x);
    };

    // Reconstructed pressure at point x of a cell: (gamma - 1) * rho * epsilon
    // with gamma hard-coded to 1.4
    auto p_bar = [&rho_bar, &rho_u_bar, &rho_E_bar](const CellId cell_id, const Rd& x) {
      const double tau         = 1. / rho_bar[cell_id](x);
      const Rd rho_u           = rho_u_bar[cell_id](x);
      const double rho_epsilon = (rho_E_bar[cell_id](x) - 0.5 * tau * dot(rho_u, rho_u));

      constexpr double gamma = 1.4;
      return (gamma - 1) * rho_epsilon;
    };

    auto node_to_face_matrix = mesh->connectivity().nodeToFaceMatrix();
    auto face_to_node_matrix = mesh->connectivity().faceToNodeMatrix();

    // Mid-point used by the parabolic description of a face whose end nodes are
    // x0 and x2: the inner face point for degree 2 meshes, the chord midpoint
    // otherwise.
    auto face_mid_point = [&](const FaceId face_id, const Rd& x0, const Rd& x2) -> Rd {
      if (mesh->degree() == 2) {
        return mesh->xl()[face_id][0];
      } else {
        return 0.5 * (x0 + x2);
      }
    };

    // Numbering of the unknowns attached to each face: entry 0 and entry
    // `degree` are the end nodes (numbered as the mesh nodes), inner entries are
    // numbered after all mesh nodes.
    m_face_indices = [&]() {
      FaceArray<size_t> face_indices{mesh->connectivity(), mesh->degree() + 1};
      {
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          auto face_nodes = face_to_node_matrix[face_id];

          face_indices[face_id][0]              = face_nodes[0];
          face_indices[face_id][mesh->degree()] = face_nodes[1];
        }
        size_t cpt = mesh->numberOfNodes();
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          for (size_t i_face_node = 1; i_face_node < mesh->degree(); ++i_face_node) {
            face_indices[face_id][i_face_node] = cpt++;
          }
        }
      }
      return face_indices;
    }();

    m_node_index = [&]() {
      NodeValue<size_t> node_index{mesh->connectivity()};
      {
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          auto face_nodes = face_to_node_matrix[face_id];

          node_index[face_nodes[0]] = m_face_indices[face_id][0];
          node_index[face_nodes[1]] = m_face_indices[face_id][mesh->degree()];
        }
      }
      return node_index;
    }();

    // Number of scalar unknowns: Dimension components per mesh node and per
    // inner face point
    const size_t system_size = Dimension * (mesh->numberOfNodes() + mesh->numberOfEdges() * (mesh->degree() - 1));

    Array<int> non_zero(system_size);

    parallel_for(
      mesh->numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) {
        const size_t node_idx      = m_node_index[node_id];
        const size_t node_non_zero = (node_to_face_matrix[node_id].size() * mesh->degree() + 1) * Dimension;

        for (size_t i = 0; i < Dimension; ++i) {
          non_zero[Dimension * node_idx + i] = node_non_zero;
        }
      });

    parallel_for(
      mesh->numberOfFaces(), PUGS_LAMBDA(const FaceId face_id) {
        for (size_t i = 1; i < mesh->degree(); ++i) {
          const size_t face_node_idx = m_face_indices[face_id][i];

          const size_t face_node_non_zero = Dimension * (mesh->degree() + 1);

          for (size_t i_component = 0; i_component < Dimension; ++i_component) {
            non_zero[Dimension * face_node_idx + i_component] = face_node_non_zero;
          }
        }
      });

    CRSMatrixDescriptor A_descriptor(system_size, system_size, non_zero);

    auto face_to_cell_matrix = mesh->connectivity().faceToCellMatrix();

    const Rdxd I = identity;

    using R1 = TinyVector<1>;

    // Shape functions on the reference segment [-1, 1]
    const std::vector<std::function<double(R1)>> w_hat_set_P1 = {[](R1 x) -> double { return 0.5 * (1 - x[0]); },
                                                                 [](R1 x) -> double { return 0.5 * (1 + x[0]); }};

    const std::vector<std::function<double(R1)>> w_hat_set_P2 = {[](R1 x) -> double { return 0.5 * x[0] * (x[0] - 1); },
                                                                 [](R1 x) -> double {
                                                                   return -(x[0] - 1) * (x[0] + 1);
                                                                 },
                                                                 [](R1 x) -> double {
                                                                   return 0.5 * x[0] * (x[0] + 1);
                                                                 }};

    std::vector<std::function<double(R1)>> w_hat_set;
    if (mesh->degree() == 1) {
      w_hat_set = w_hat_set_P1;
    } else if (mesh->degree() == 2) {
      w_hat_set = w_hat_set_P2;
    } else {
      throw NotImplementedError("degree > 2");
    }

    const QuadratureFormula<1> qf =
      QuadratureManager::instance().getLineFormula(GaussQuadratureDescriptor(m_quadrature_degree));

    // Matrix assembly
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      auto cell_list = face_to_cell_matrix[face_id];

      // coefficients are summed over the cells sharing the face
      double Z  = 0;
      double Za = 0;
      for (size_t i_cell = 0; i_cell < cell_list.size(); ++i_cell) {
        const CellId cell_id = cell_list[i_cell];
        Z += rhoc[cell_id];
        Za += a[cell_id];
      }

      const Rd& x0 = xr[face_to_node_matrix[face_id][0]];
      const Rd& x2 = xr[face_to_node_matrix[face_id][1]];
      const Rd x1  = face_mid_point(face_id, x0, x2);

      const LineParabolicTransformation<Dimension> t(x0, x1, x2);

      for (size_t n0 = 0; n0 < w_hat_set.size(); ++n0) {
        const size_t r = m_face_indices[face_id][n0];

        const auto& w_hat_n0 = w_hat_set[n0];

        for (size_t n1 = 0; n1 < w_hat_set.size(); ++n1) {
          const size_t s = m_face_indices[face_id][n1];

          const auto& w_hat_n1 = w_hat_set[n1];

          Rdxd Al_rs = zero;
          for (size_t i = 0; i < qf.numberOfPoints(); ++i) {
            Al_rs += qf.weight(i) * w_hat_n0(qf.point(i)) * w_hat_n1(qf.point(i))   //
                     * (Z * t.velocityNorm(qf.point(i)) * I +
                        (Za - Z) / t.velocityNorm(qf.point(i)) *
                          tensorProduct(t.velocity(qf.point(i)), t.velocity(qf.point(i))));
          }

          for (size_t i = 0; i < Dimension; ++i) {
            for (size_t j = 0; j < Dimension; ++j) {
              A_descriptor(Dimension * r + i, Dimension * s + j) += Al_rs(i, j);
            }
          }
        }
      }
    }

    const auto& face_local_numbers_in_their_cells = mesh->connectivity().faceLocalNumbersInTheirCells();
    const auto& face_cell_is_reversed             = mesh->connectivity().cellFaceIsReversed();

    // Right-hand side assembly
    Vector<double> b(system_size);
    b.fill(0);

    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      const Rd& x0 = xr[face_to_node_matrix[face_id][0]];
      const Rd& x2 = xr[face_to_node_matrix[face_id][1]];
      const Rd x1  = face_mid_point(face_id, x0, x2);

      const LineParabolicTransformation<Dimension> t(x0, x1, x2);
      auto cell_list = face_to_cell_matrix[face_id];

      for (size_t i_cell = 0; i_cell < cell_list.size(); ++i_cell) {
        const CellId face_cell_id = cell_list[i_cell];

        const size_t i_face_in_cell = face_local_numbers_in_their_cells[face_id][i_cell];

        // orientation of the face with respect to the cell
        const double sign = face_cell_is_reversed(face_cell_id, i_face_in_cell) ? -1 : 1;

        for (size_t n0 = 0; n0 < w_hat_set.size(); ++n0) {
          const size_t r = m_face_indices[face_id][n0];

          TinyVector<Dimension> bjl_r = zero;

          const auto& w_hat_n0 = w_hat_set[n0];

          for (size_t i = 0; i < qf.numberOfPoints(); ++i) {
            const TinyVector<Dimension> Tl = t.velocity(qf.point(i));
            const TinyVector<Dimension> Nl{Tl[1], -Tl[0]};

            bjl_r += qf.weight(i) * w_hat_n0(qf.point(i)) * p_bar(face_cell_id, t(qf.point(i))) * sign * Nl;

            bjl_r += qf.weight(i) * w_hat_n0(qf.point(i))   //
                     * (rhoc[face_cell_id] * t.velocityNorm(qf.point(i)) * I +
                        (a[face_cell_id] - rhoc[face_cell_id]) / t.velocityNorm(qf.point(i)) *
                          tensorProduct(t.velocity(qf.point(i)), t.velocity(qf.point(i)))) *
                     u_bar(face_cell_id, t(qf.point(i)));
          }

          for (size_t i = 0; i < Dimension; ++i) {
            b[Dimension * r + i] += bjl_r[i];
          }
        }
      }
    }

    const BoundaryConditionList bc_list = this->_getBCList(*mesh, bc_descriptor_list);

    this->_applyBoundaryConditions(bc_list, *mesh, velocity_bc_treatment, A_descriptor, b);

    CRSMatrix A = A_descriptor.getCRSMatrix();

    Vector<double> U{b.size()};
    U = zero;

    LinearSolver solver;
    solver.solveLocalSystem(A, U, b);
    this->_forcebcBoundaryConditions(bc_list, *mesh, U);

    const auto& cell_to_face_matrix = mesh->connectivity().cellToFaceMatrix();

    CellValue<Rd> momentum_fluxes{mesh->connectivity()};
    CellValue<double> energy_fluxes{mesh->connectivity()};

    // Cell fluxes
    for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
      Rd rho_u_j_fluxes     = zero;
      double rho_E_j_fluxes = 0;

      auto face_list = cell_to_face_matrix[cell_id];
      for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        const FaceId face_id = face_list[i_face];

        const Rd& x0 = xr[face_to_node_matrix[face_id][0]];
        const Rd& x2 = xr[face_to_node_matrix[face_id][1]];
        const Rd x1  = face_mid_point(face_id, x0, x2);

        const LineParabolicTransformation<Dimension> t(x0, x1, x2);

        const double sign = face_cell_is_reversed(cell_id, i_face) ? -1 : 1;

        TinyVector<Dimension> face_momentum_flux = zero;
        double face_energy_flux                  = 0;

        for (size_t iq = 0; iq < qf.numberOfPoints(); ++iq) {
          const TinyVector<Dimension> Tl = t.velocity(qf.point(iq));
          const TinyVector<Dimension> Nl{Tl[1], -Tl[0]};

          // velocity solution evaluated at the quadrature point
          TinyVector<Dimension> u_star_xi = zero;
          for (size_t n0 = 0; n0 < w_hat_set.size(); ++n0) {
            const size_t r       = m_face_indices[face_id][n0];
            const auto& w_hat_n0 = w_hat_set[n0];
            TinyVector<Dimension> ur;
            for (size_t i = 0; i < Dimension; ++i) {
              ur[i] = U[Dimension * r + i];
            }

            u_star_xi += w_hat_n0(qf.point(iq)) * ur;
          }

          face_momentum_flux += qf.weight(iq) * p_bar(cell_id, t(qf.point(iq))) * sign * Nl;
          face_momentum_flux += qf.weight(iq)   //
                                * (rhoc[cell_id] * t.velocityNorm(qf.point(iq)) * I +
                                   (a[cell_id] - rhoc[cell_id]) / t.velocityNorm(qf.point(iq)) *
                                     tensorProduct(t.velocity(qf.point(iq)), t.velocity(qf.point(iq)))) *
                                (u_bar(cell_id, t(qf.point(iq))) - u_star_xi);

          face_energy_flux += qf.weight(iq) * p_bar(cell_id, t(qf.point(iq))) * sign * dot(Nl, u_star_xi);
          face_energy_flux += dot(u_star_xi, qf.weight(iq)   //
                                               * (rhoc[cell_id] * t.velocityNorm(qf.point(iq)) * I +
                                                  (a[cell_id] - rhoc[cell_id]) / t.velocityNorm(qf.point(iq)) *
                                                    tensorProduct(t.velocity(qf.point(iq)), t.velocity(qf.point(iq)))) *
                                               (u_bar(cell_id, t(qf.point(iq))) - u_star_xi));
        }

        rho_u_j_fluxes += face_momentum_flux;
        rho_E_j_fluxes += face_energy_flux;
      }

      momentum_fluxes[cell_id] = rho_u_j_fluxes;
      energy_fluxes[cell_id]   = rho_E_j_fluxes;
    }

    // Optional source term: its integral over each cell is removed from the
    // energy flux
    if (function_id.has_value()) {
      GaussLegendreQuadratureDescriptor quadrature(7);
      CellValue<double> S(mesh->connectivity());

      IntegrateOnCells<double(Rd)>::integrateTo(function_id.value(), quadrature, *mesh, S);

      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        energy_fluxes[cell_id] -= S[cell_id];
      }
    }

    NodeValue<Rd> ur(mesh->connectivity());
    FaceArray<Rd> ul(mesh->connectivity(), mesh->degree() - 1);

    // Scatter the solution to node and inner face point velocities
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      auto face_nodes = face_to_node_matrix[face_id];

      for (size_t i = 0; i < Dimension; ++i) {
        ur[face_nodes[0]][i] = U[Dimension * m_face_indices[face_id][0] + i];
      }
      for (size_t i = 0; i < Dimension; ++i) {
        for (size_t il = 0; il < mesh->degree() - 1; ++il) {
          ul[face_id][il][i] = U[Dimension * m_face_indices[face_id][il + 1] + i];
        }
      }
      // The last entry of m_face_indices is attached to the second node of the
      // face (local index 1) whatever the mesh degree: using `degree - 1` as
      // local node index would overwrite the first node for degree 1 meshes.
      for (size_t i = 0; i < Dimension; ++i) {
        ur[face_nodes[1]][i] = U[Dimension * m_face_indices[face_id][mesh->degree()] + i];
      }
    }

    return {ur, ul, momentum_fluxes, energy_fluxes};
  }

  // Advances the mesh and the P0 unknowns over a time step `dt` using given
  // velocities and fluxes.
  //
  //  - velocity and total energy of each cell are decreased by
  //    dt / (rho * Vj) times the momentum and energy fluxes,
  //  - mesh nodes and inner face points are moved by dt times `ur` and `ul`,
  //  - density is multiplied by the ratio Vj / new Vj, where both are read from
  //    the mesh data of the old and new meshes.
  //
  // \param dt              time step
  // \param rho_v,u_v,E_v   P0 density, velocity and total energy
  // \param ur              node velocities
  // \param ul              inner face point velocities (`degree - 1` per face)
  // \param momentum_fluxes cell momentum fluxes
  // \param energy_fluxes   cell energy fluxes
  // \return (new mesh, new density, new velocity, new total energy), the
  //         discrete functions being defined on the new mesh
  std::tuple<std::shared_ptr<const MeshVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  _applyFluxes(const double& dt,
               const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
               const std::shared_ptr<const DiscreteFunctionVariant>& u_v,
               const std::shared_ptr<const DiscreteFunctionVariant>& E_v,
               const NodeValue<const Rd>& ur,
               const FaceArray<const Rd>& ul,
               const CellValue<const Rd>& momentum_fluxes,
               const CellValue<const double>& energy_fluxes) const
  {
    std::shared_ptr mesh = getCommonMesh({rho_v, u_v, E_v})->get<MeshType>();

    DiscreteScalarFunction rho = rho_v->get<DiscreteScalarFunction>();
    DiscreteVectorFunction u   = u_v->get<DiscreteVectorFunction>();
    DiscreteScalarFunction E   = E_v->get<DiscreteScalarFunction>();

    CellValue<double> new_rho = copy(rho.cellValues());
    CellValue<Rd> new_u       = copy(u.cellValues());
    CellValue<double> new_E   = copy(E.cellValues());

    MeshDataType& mesh_data    = MeshDataManager::instance().getMeshData(*mesh);
    CellValue<const double> Vj = mesh_data.Vj();

    // update of velocity and total energy
    for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
      const double dt_over_Mj = dt / (rho[cell_id] * Vj[cell_id]);
      new_u[cell_id] -= dt_over_Mj * momentum_fluxes[cell_id];
      new_E[cell_id] -= dt_over_Mj * energy_fluxes[cell_id];
    }

    // displacement of nodes and inner face points
    NodeValue<Rd> new_xr = copy(mesh->xr());
    FaceArray<Rd> new_xl = copy(mesh->xl());

    parallel_for(
      mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { new_xr[node_id] += dt * ur[node_id]; });

    parallel_for(
      mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
        for (size_t i = 0; i < mesh->degree() - 1; ++i) {
          new_xl[face_id][i] += dt * ul[face_id][i];
        }
      });

    std::shared_ptr<const MeshType> new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_xr, new_xl);

    CellValue<const double> new_Vj = MeshDataManager::instance().getMeshData(*new_mesh).Vj();

    // density follows the change of Vj
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { new_rho[j] *= Vj[j] / new_Vj[j]; });

    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction(new_mesh, new_rho)),
            std::make_shared<DiscreteFunctionVariant>(DiscreteVectorFunction(new_mesh, new_u)),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction(new_mesh, new_E))};
  }

 public:
  // Advances the solution by one time step of size dt.
  //
  // The fluxes are obtained from _computeFluxes and the update is performed by
  // _applyFluxes. The returned tuple holds, in this order, the updated mesh and
  // the updated density, velocity and total energy discrete functions.
  //
  // rho_v, u_v, E_v, c_v and a_v must share a common mesh of type MeshType.
  // c_v is only used to build the cell product rho * c, and a_v is read as a
  // P0 function; both are forwarded to _computeFluxes. degree,
  // velocity_bc_treatment, bc_descriptor_list and function_id are forwarded
  // unchanged to _computeFluxes.
  //
  // Only the final #else branch is compiled: one flux evaluation applied over
  // the whole dt. The two disabled branches keep three-stage Runge-Kutta
  // variants for reference.
  std::tuple<std::shared_ptr<const MeshVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const size_t& degree,
        const double& dt,
        const VelocityBCTreatment& velocity_bc_treatment,
        const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& u_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& E_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& c_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& a_v,
        const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
        std::optional<FunctionSymbolId> function_id) const
  {
    // NOTE(review): the result of getCommonMesh is dereferenced without a null
    // check -- confirm that callers guarantee a common mesh.
    // In the active (#else) branch this mesh is not used any further.
    std::shared_ptr mesh = getCommonMesh({rho_v, u_v, E_v, c_v, a_v})->get<MeshType>();

    const CellValue<const double> a = a_v->get<DiscreteFunctionP0<const double>>().cellValues();
    // Cell-wise product rho * c, evaluated once from the input state.
    const CellValue<const double> rhoc =
      (rho_v->get<DiscreteFunctionP0<const double>>() * c_v->get<DiscreteFunctionP0<const double>>()).cellValues();

#if 0
    // Heun's RK3 method
    // (disabled) Note that rhoc and a are those of the initial state at every stage.

    // Stage 1: fluxes of the initial state, applied over dt / 3
    auto [ur1, ul1, momentum_fluxes1, energy_fluxes1] =
      _computeFluxes(degree, velocity_bc_treatment, rho_v, u_v, E_v, rhoc, a, bc_descriptor_list, function_id);

    auto [mesh1, rho1, u1, E1] = _applyFluxes(dt / 3., rho_v, u_v, E_v, ur1, ul1, momentum_fluxes1, energy_fluxes1);

    // Stage 2: fluxes of the stage-1 state, applied to the initial state over 2 dt / 3
    auto [ur2, ul2, momentum_fluxes2, energy_fluxes2] =
      _computeFluxes(degree, velocity_bc_treatment, rho1, u1, E1, rhoc, a, bc_descriptor_list, function_id);

    auto [mesh2, rho2, u2, E2] =
      _applyFluxes(2. / 3. * dt, rho_v, u_v, E_v, ur2, ul2, momentum_fluxes2, energy_fluxes2);

    // Stage 3: fluxes of the stage-2 state
    auto [ur3, ul3, momentum_fluxes3, energy_fluxes3] =
      _computeFluxes(degree, velocity_bc_treatment, rho2, u2, E2, rhoc, a, bc_descriptor_list, function_id);

    NodeValue<Rd> ur{mesh->connectivity()};
    FaceArray<Rd> ul{mesh->connectivity(), ul1.sizeOfArrays()};
    CellValue<Rd> momentum_fluxes{mesh->connectivity()};
    CellValue<double> energy_fluxes{mesh->connectivity()};

    // Final combination: 1/4 of stage 1 plus 3/4 of stage 3
    for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
      ur[node_id] = 0.25 * ur1[node_id] + 0.75 * ur3[node_id];
    }
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      for (size_t i = 0; i < ul.sizeOfArrays(); ++i) {
        ul[face_id][i] = 0.25 * ul1[face_id][i] + 0.75 * ul3[face_id][i];
      }
    }
    for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
      momentum_fluxes[cell_id] = 0.25 * momentum_fluxes1[cell_id] + 0.75 * momentum_fluxes3[cell_id];
    }
    for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
      energy_fluxes[cell_id] = 0.25 * energy_fluxes1[cell_id] + 0.75 * energy_fluxes3[cell_id];
    }

    auto [mesh3, rho3, u3, E3] = _applyFluxes(dt, rho_v, u_v, E_v, ur, ul, momentum_fluxes, energy_fluxes);

    return {mesh3, rho3, u3, E3};

#elif 0
    // Ralston's RK3 method
    // (disabled) Note that rhoc and a are those of the initial state at every stage.

    // Stage 1: fluxes of the initial state, applied over dt / 2
    auto [ur1, ul1, momentum_fluxes1, energy_fluxes1] =
      _computeFluxes(degree, velocity_bc_treatment, rho_v, u_v, E_v, rhoc, a, bc_descriptor_list, function_id);

    auto [mesh1, rho1, u1, E1] = _applyFluxes(dt / 2., rho_v, u_v, E_v, ur1, ul1, momentum_fluxes1, energy_fluxes1);

    // Stage 2: fluxes of the stage-1 state, applied to the initial state over 3 dt / 4
    auto [ur2, ul2, momentum_fluxes2, energy_fluxes2] =
      _computeFluxes(degree, velocity_bc_treatment, rho1, u1, E1, rhoc, a, bc_descriptor_list, function_id);

    auto [mesh2, rho2, u2, E2] =
      _applyFluxes(3. / 4. * dt, rho_v, u_v, E_v, ur2, ul2, momentum_fluxes2, energy_fluxes2);

    // Stage 3: fluxes of the stage-2 state
    auto [ur3, ul3, momentum_fluxes3, energy_fluxes3] =
      _computeFluxes(degree, velocity_bc_treatment, rho2, u2, E2, rhoc, a, bc_descriptor_list, function_id);

    NodeValue<Rd> ur{mesh->connectivity()};
    FaceArray<Rd> ul{mesh->connectivity(), ul1.sizeOfArrays()};
    CellValue<Rd> momentum_fluxes{mesh->connectivity()};
    CellValue<double> energy_fluxes{mesh->connectivity()};

    // Final combination with weights 2/9, 1/3 and 4/9
    for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
      ur[node_id] = 2. / 9 * ur1[node_id] + 1. / 3 * ur2[node_id] + 4. / 9 * ur3[node_id];
    }
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      for (size_t i = 0; i < ul.sizeOfArrays(); ++i) {
        ul[face_id][i] = 2. / 9 * ul1[face_id][i] + 1. / 3 * ul2[face_id][i] + 4. / 9 * ul3[face_id][i];
      }
    }
    for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
      momentum_fluxes[cell_id] =
        2. / 9 * momentum_fluxes1[cell_id] + 1. / 3 * momentum_fluxes2[cell_id] + 4. / 9 * momentum_fluxes3[cell_id];
    }
    for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
      energy_fluxes[cell_id] =
        2. / 9 * energy_fluxes1[cell_id] + 1. / 3 * energy_fluxes2[cell_id] + 4. / 9 * energy_fluxes3[cell_id];
    }

    auto [mesh3, rho3, u3, E3] = _applyFluxes(dt, rho_v, u_v, E_v, ur, ul, momentum_fluxes, energy_fluxes);

    return {mesh3, rho3, u3, E3};

#else
    // Active path: a single flux evaluation on the input state, applied over dt.
    auto [ur1, ul1, momentum_fluxes1, energy_fluxes1] =
      _computeFluxes(degree, velocity_bc_treatment, rho_v, u_v, E_v, rhoc, a, bc_descriptor_list, function_id);

    return _applyFluxes(dt, rho_v, u_v, E_v, ur1, ul1, momentum_fluxes1, energy_fluxes1);
#endif
  }

  // Builds a solver using 8 as quadrature degree. This value is the one handed
  // to GaussQuadratureDescriptor when boundary integrals are evaluated (see
  // _applyPressureBC).
  VariationalSolver() : m_quadrature_degree(8) {}

  // Move construction and destruction are the compiler-generated ones.
  VariationalSolver(VariationalSolver&&) = default;
  ~VariationalSolver()                   = default;
};

// Accumulates into the right-hand side b the contribution of every
// PressureBoundaryCondition found in bc_list. Other kinds of boundary
// conditions are ignored, and nothing is done when Dimension <= 1.
//
// Each boundary face is mapped through a LineParabolicTransformation. For each
// shape function attached to the face, the product
//   quadrature weight * shape function * external pressure * orientation * rotated tangent
// is summed over the quadrature points and subtracted from the entries of b
// that belong to the matching degree of freedom.
//
// Throws NotImplementedError when mesh.degree() is neither 1 nor 2.
template <MeshConcept MeshType>
void
VariationalSolverHandler::VariationalSolver<MeshType>::_applyPressureBC(const BoundaryConditionList& bc_list,
                                                                        const MeshType& mesh,
                                                                        Vector<double>& b) const
{
  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using BCType = std::decay_t<decltype(bc)>;
        if constexpr (std::is_same_v<PressureBoundaryCondition, BCType> && (Dimension > 1)) {
          using R1            = TinyVector<1>;
          using ShapeFunction = std::function<double(R1)>;

          const QuadratureFormula<1> quadrature =
            QuadratureManager::instance().getLineFormula(GaussQuadratureDescriptor(m_quadrature_degree));

          // Lagrange shape functions on the reference segment [-1, 1], listed
          // in the order of the face degrees of freedom
          const std::vector<ShapeFunction> shape_functions = [&]() -> std::vector<ShapeFunction> {
            if (mesh.degree() == 1) {
              return {[](R1 x) -> double { return 0.5 * (1 - x[0]); },
                      [](R1 x) -> double { return 0.5 * (1 + x[0]); }};
            }
            if (mesh.degree() == 2) {
              return {[](R1 x) -> double { return 0.5 * x[0] * (x[0] - 1); },
                      [](R1 x) -> double { return -(x[0] - 1) * (x[0] + 1); },
                      [](R1 x) -> double { return 0.5 * x[0] * (x[0] + 1); }};
            }
            throw NotImplementedError("degree > 2");
          }();

          const auto& face_to_cell_matrix               = mesh.connectivity().faceToCellMatrix();
          const auto& face_to_node_matrix               = mesh.connectivity().faceToNodeMatrix();
          const auto& face_local_numbers_in_their_cells = mesh.connectivity().faceLocalNumbersInTheirCells();
          const auto& face_cell_is_reversed             = mesh.connectivity().cellFaceIsReversed();

          const Array<const FaceId>& boundary_faces    = bc.faceList();
          const Table<const double>& external_pressure = bc.valueList();

          for (size_t i_face = 0; i_face < boundary_faces.size(); ++i_face) {
            const FaceId face_id = boundary_faces[i_face];

            const Rd& x_begin = mesh.xr()[face_to_node_matrix[face_id][0]];
            const Rd& x_end   = mesh.xr()[face_to_node_matrix[face_id][1]];

            // degree 2 faces carry their own middle point, straight faces use
            // the midpoint of their two vertices
            const Rd x_middle = (mesh.degree() == 2) ? mesh.xl()[face_id][0] : 0.5 * (x_begin + x_end);

            const LineParabolicTransformation<Dimension> transformation(x_begin, x_middle, x_end);

            auto face_cells = face_to_cell_matrix[face_id];
            for (size_t i_cell = 0; i_cell < face_cells.size(); ++i_cell) {
              const CellId cell_id       = face_cells[i_cell];
              const size_t local_face    = face_local_numbers_in_their_cells[face_id][i_cell];
              const double orientation   = face_cell_is_reversed(cell_id, local_face) ? -1 : 1;
              const size_t nb_face_shape = shape_functions.size();

              for (size_t i_shape = 0; i_shape < nb_face_shape; ++i_shape) {
                const ShapeFunction& shape = shape_functions[i_shape];

                TinyVector<Dimension> face_load = zero;
                for (size_t q = 0; q < quadrature.numberOfPoints(); ++q) {
                  const TinyVector<Dimension> tangent = transformation.velocity(quadrature.point(q));
                  // tangent rotated by a quarter turn (it is not normalized)
                  const TinyVector<Dimension> rotated_tangent{tangent[1], -tangent[0]};

                  const double scaling =
                    quadrature.weight(q) * shape(quadrature.point(q)) * external_pressure(i_face, q) * orientation;

                  face_load += scaling * rotated_tangent;
                }

                const size_t dof = m_face_indices[face_id][i_shape];
                for (size_t i = 0; i < Dimension; ++i) {
                  b[Dimension * dof + i] -= face_load[i];
                }
              }
            }
          }
        }
      },
      boundary_condition);
  }
}

// Enforces the symmetry boundary conditions of bc_list on the solution vector
// U: the velocity stored at each concerned degree of freedom is replaced by
// its image through P = I - n (x) n, n being the outgoing normal of the
// boundary condition.
//
// The vertices of the boundary are treated first, then the inner nodes of the
// boundary faces (local indices 1 .. degree-1 in m_face_indices). Boundary
// conditions of any other kind are left untouched.
template <MeshConcept MeshType>
void
VariationalSolverHandler::VariationalSolver<MeshType>::_forcebcSymmetryBC(const BoundaryConditionList& bc_list,
                                                                          const MeshType& mesh,
                                                                          Vector<double>& U) const
{
  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using BCType = std::decay_t<decltype(bc)>;
        if constexpr (std::is_same_v<SymmetryBoundaryCondition, BCType>) {
          const Rd& normal = bc.outgoingNormal();

          const Rdxd identity_matrix = identity;
          const Rdxd normal_x_normal = tensorProduct(normal, normal);
          const Rdxd projection      = identity_matrix - normal_x_normal;

          // Replaces the velocity of a degree of freedom by its projection
          const auto project_dof = [&](const size_t dof) {
            TinyVector<Dimension> velocity;
            for (size_t i = 0; i < Dimension; ++i) {
              velocity[i] = U[Dimension * dof + i];
            }

            velocity = projection * velocity;

            for (size_t i = 0; i < Dimension; ++i) {
              U[Dimension * dof + i] = velocity[i];
            }
          };

          // vertices of the boundary
          const Array<const NodeId>& boundary_nodes = bc.nodeList();
          for (size_t i_node = 0; i_node < boundary_nodes.size(); ++i_node) {
            project_dof(m_node_index[boundary_nodes[i_node]]);
          }

          // inner nodes of the boundary faces
          const size_t degree                       = mesh.degree();
          const Array<const FaceId>& boundary_faces = bc.faceList();
          for (size_t i_face = 0; i_face < boundary_faces.size(); ++i_face) {
            const FaceId face_id = boundary_faces[i_face];
            for (size_t i_inner = 1; i_inner < degree; ++i_inner) {
              project_dof(m_face_indices[face_id][i_inner]);
            }
          }
        }
      },
      boundary_condition);
  }
}

// Overwrites in the solution vector U the velocity of the degrees of freedom
// concerned by velocity-like boundary conditions:
//  - VelocityBoundaryCondition: the prescribed node values and face inner
//    node values are copied into U,
//  - FixedBoundaryCondition: the velocity is set to 0.
//
// For both kinds the vertices are treated first, then the inner nodes of the
// boundary faces (local indices 1 .. degree-1 in m_face_indices). Other kinds
// of boundary conditions are ignored.
template <MeshConcept MeshType>
void
VariationalSolverHandler::VariationalSolver<MeshType>::_forcebcVelocityBC(const BoundaryConditionList& bc_list,
                                                                          const MeshType& mesh,
                                                                          Vector<double>& U) const
{
  // Copies a prescribed velocity into the entries of a degree of freedom
  const auto assign_dof = [&](const size_t dof, const Rd& velocity) {
    for (size_t i = 0; i < Dimension; ++i) {
      U[Dimension * dof + i] = velocity[i];
    }
  };

  // Sets all the entries of a degree of freedom to 0
  const auto reset_dof = [&](const size_t dof) {
    for (size_t i = 0; i < Dimension; ++i) {
      U[Dimension * dof + i] = 0;
    }
  };

  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using BCType = std::decay_t<decltype(bc)>;

        if constexpr (std::is_same_v<VelocityBoundaryCondition, BCType>) {
          // vertices of the boundary
          const Array<const NodeId>& boundary_nodes = bc.nodeList();
          const Array<const Rd>& vertex_velocities  = bc.nodeValueList();
          for (size_t i_node = 0; i_node < boundary_nodes.size(); ++i_node) {
            assign_dof(m_node_index[boundary_nodes[i_node]], vertex_velocities[i_node]);
          }

          // inner nodes of the boundary faces: values are stored without the
          // face vertices, hence the shift of the local index
          const size_t degree                       = mesh.degree();
          const Array<const FaceId>& boundary_faces = bc.faceList();
          const Table<const Rd>& inner_velocities   = bc.faceInnerNodeValueList();
          for (size_t i_face = 0; i_face < boundary_faces.size(); ++i_face) {
            const FaceId face_id = boundary_faces[i_face];
            for (size_t i_inner = 1; i_inner < degree; ++i_inner) {
              assign_dof(m_face_indices[face_id][i_inner], inner_velocities[i_face][i_inner - 1]);
            }
          }
        } else if constexpr (std::is_same_v<FixedBoundaryCondition, BCType>) {
          // vertices of the boundary
          const Array<const NodeId>& boundary_nodes = bc.nodeList();
          for (size_t i_node = 0; i_node < boundary_nodes.size(); ++i_node) {
            reset_dof(m_node_index[boundary_nodes[i_node]]);
          }

          // inner nodes of the boundary faces
          const size_t degree                       = mesh.degree();
          const Array<const FaceId>& boundary_faces = bc.faceList();
          for (size_t i_face = 0; i_face < boundary_faces.size(); ++i_face) {
            const FaceId face_id = boundary_faces[i_face];
            for (size_t i_inner = 1; i_inner < degree; ++i_inner) {
              reset_dof(m_face_indices[face_id][i_inner]);
            }
          }
        }
      },
      boundary_condition);
  }
}

// Modifies the linear system (A_descriptor, b) to account for the symmetry
// boundary conditions of bc_list. Other kinds of boundary conditions are
// ignored.
//
// With n the outgoing normal of the boundary condition and Sn = I - 2 n (x) n,
// the following is done for each concerned degree of freedom r (boundary
// vertices, then inner nodes of the boundary faces):
//  - diagonal block: A_rr <- A_rr + Sn A_rr Sn,
//  - second member:  b_r  <- b_r + Sn b_r,
//  - off-diagonal blocks of row r: A_rs <- (I + Sn) A_rs, for the degrees of
//    freedom s != r that are reached through the faces: all faces connected to
//    the node for a vertex, the face itself for an inner node.
//
// All diagonal blocks and second members are updated before any off-diagonal
// block is touched.
template <MeshConcept MeshType>
void
VariationalSolverHandler::VariationalSolver<MeshType>::_applySymmetryBC(const BoundaryConditionList& bc_list,
                                                                        const MeshType& mesh,
                                                                        CRSMatrixDescriptor<double>& A_descriptor,
                                                                        Vector<double>& b) const
{
  // Copies the Dimension x Dimension block coupling the dofs `row` and `col`
  const auto read_block = [&](const size_t row, const size_t col) {
    TinyMatrix<Dimension> block;
    for (size_t i = 0; i < Dimension; ++i) {
      for (size_t j = 0; j < Dimension; ++j) {
        block(i, j) = A_descriptor(Dimension * row + i, Dimension * col + j);
      }
    }
    return block;
  };

  // Stores a Dimension x Dimension block at the position of the dofs `row` and `col`
  const auto write_block = [&](const size_t row, const size_t col, const TinyMatrix<Dimension>& block) {
    for (size_t i = 0; i < Dimension; ++i) {
      for (size_t j = 0; j < Dimension; ++j) {
        A_descriptor(Dimension * row + i, Dimension * col + j) = block(i, j);
      }
    }
  };

  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using T = std::decay_t<decltype(bc)>;
        if constexpr (std::is_same_v<SymmetryBoundaryCondition, T>) {
          const Rd& n = bc.outgoingNormal();

          const Rdxd I   = identity;
          const Rdxd nxn = tensorProduct(n, n);

          const Rdxd Sn   = I - 2 * nxn;
          const Rdxd IpSn = I + Sn;

          // Updates the diagonal block and the second member of the dof r
          const auto symmetrize_dof = [&](const size_t r) {
            const TinyMatrix<Dimension> Arr = read_block(r, r);
            write_block(r, r, Arr + Sn * Arr * Sn);

            TinyVector<Dimension> br;
            for (size_t i = 0; i < Dimension; ++i) {
              br[i] = b[Dimension * r + i];
            }

            br = br + Sn * br;

            for (size_t i = 0; i < Dimension; ++i) {
              b[Dimension * r + i] = br[i];
            }
          };

          // Updates the off-diagonal block (r, s); the diagonal is skipped
          const auto symmetrize_coupling = [&](const size_t r, const size_t s) {
            if (s != r) {
              write_block(r, s, IpSn * read_block(r, s));
            }
          };

          const size_t degree                  = mesh.degree();
          const Array<const NodeId>& node_list = bc.nodeList();
          const Array<const FaceId>& face_list = bc.faceList();

          // diagonal blocks and second member: vertices of the boundary
          for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
            symmetrize_dof(m_node_index[node_list[i_node]]);
          }

          // diagonal blocks and second member: inner nodes of the faces
          for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
            const FaceId face_id = face_list[i_face];
            for (size_t i_hat = 1; i_hat < degree; ++i_hat) {
              symmetrize_dof(m_face_indices[face_id][i_hat]);
            }
          }

          // off-diagonal blocks: rows of the boundary vertices, coupled with
          // the dofs of every face connected to the vertex
          auto node_to_face_matrix = mesh.connectivity().nodeToFaceMatrix();
          for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
            const NodeId node_id = node_list[i_node];
            const size_t r       = m_node_index[node_id];

            auto node_face_list = node_to_face_matrix[node_id];
            for (size_t i_face = 0; i_face < node_face_list.size(); ++i_face) {
              const FaceId face_id = node_face_list[i_face];
              for (size_t i_hat = 0; i_hat <= degree; ++i_hat) {
                symmetrize_coupling(r, m_face_indices[face_id][i_hat]);
              }
            }
          }

          // off-diagonal blocks: rows of the inner nodes of the faces,
          // coupled with the dofs of their own face
          for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
            const FaceId face_id = face_list[i_face];
            for (size_t i_hat = 1; i_hat < degree; ++i_hat) {
              const size_t r = m_face_indices[face_id][i_hat];
              for (size_t j_hat = 0; j_hat <= degree; ++j_hat) {
                symmetrize_coupling(r, m_face_indices[face_id][j_hat]);
              }
            }
          }
        }
      },
      boundary_condition);
  }
}

// Modifies the linear system (A_descriptor, b) so that its solution satisfies
// the VelocityBoundaryCondition and FixedBoundaryCondition entries of bc_list
// (a fixed boundary behaves as a prescribed velocity equal to 0). Other kinds
// of boundary conditions are ignored.
//
// Two treatments are available, selected by velocity_bc_treatment:
//  - elimination: for each constrained dof r, the known contribution
//    A_sr * u_r is first moved to the second member of the face-connected
//    dofs s (only needed for non zero prescribed values), then row r is
//    replaced by the identity with b_r = u_r, and the blocks A_rs and A_sr
//    are set to 0;
//  - penalty: a large weight is added to each diagonal entry of the
//    constrained dofs, and the weight times the prescribed value is added to
//    b (nothing is added to b for fixed boundaries).
//
// In both cases the boundary vertices are treated through bc.nodeList() and
// the inner nodes of the boundary faces through local indices 1 .. degree-1 of
// m_face_indices. The coupled dofs s are those of the faces connected to a
// vertex, or those of the face itself for an inner node.
template <MeshConcept MeshType>
void
VariationalSolverHandler::VariationalSolver<MeshType>::_applyVelocityBC(
  const BoundaryConditionList& bc_list,
  const MeshType& mesh,
  const VariationalSolverHandler::VelocityBCTreatment& velocity_bc_treatment,
  CRSMatrixDescriptor<double>& A_descriptor,
  Vector<double>& b) const
{
  const size_t degree = mesh.degree();

  switch (velocity_bc_treatment) {
  case VariationalSolverHandler::VelocityBCTreatment::elimination: {
    constexpr TinyMatrix<Dimension> I = identity;

    auto node_to_face_matrix = mesh.connectivity().nodeToFaceMatrix();

    // Moves the known product A_sr * u_r to the second member of the dof s
    const auto move_column_to_rhs = [&](const size_t s, const size_t r, const Rd& u_r) {
      TinyMatrix<Dimension> Asr;
      for (size_t i = 0; i < Dimension; ++i) {
        for (size_t j = 0; j < Dimension; ++j) {
          Asr(i, j) = A_descriptor(Dimension * s + i, Dimension * r + j);
        }
      }
      const TinyVector<Dimension> AsrU0r = Asr * u_r;

      for (size_t i = 0; i < Dimension; ++i) {
        b[Dimension * s + i] -= AsrU0r[i];
      }
    };

    // Replaces the diagonal block of the dof r by the identity and sets b_r = u_r
    const auto impose_dof = [&](const size_t r, const Rd& u_r) {
      for (size_t i = 0; i < Dimension; ++i) {
        b[Dimension * r + i] = u_r[i];
        for (size_t j = 0; j < Dimension; ++j) {
          A_descriptor(Dimension * r + i, Dimension * r + j) = I(i, j);
        }
      }
    };

    // Cancels the blocks coupling the dofs r and s (both directions)
    const auto decouple = [&](const size_t r, const size_t s) {
      for (size_t i = 0; i < Dimension; ++i) {
        for (size_t j = 0; j < Dimension; ++j) {
          A_descriptor(Dimension * r + i, Dimension * s + j) = 0;
          A_descriptor(Dimension * s + i, Dimension * r + j) = 0;
        }
      }
    };

    // Imposes the values on the boundary dofs. node_value(i_node) and
    // inner_value(i_face, i_inner) provide the prescribed velocities.
    const auto impose_on_boundary = [&](const Array<const NodeId>& node_list, const Array<const FaceId>& face_list,
                                        const auto& node_value, const auto& inner_value) {
      // vertices of the faces
      for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
        const NodeId node_id = node_list[i_node];
        const size_t r       = m_node_index[node_id];

        auto node_face_list = node_to_face_matrix[node_id];
        for (size_t i_face = 0; i_face < node_face_list.size(); ++i_face) {
          const FaceId face_id = node_face_list[i_face];

          for (size_t s_hat = 0; s_hat <= degree; ++s_hat) {
            const size_t s = m_face_indices[face_id][s_hat];

            if (r == s) {
              impose_dof(r, node_value(i_node));
            } else {
              decouple(r, s);
            }
          }
        }
      }

      // inner nodes of the faces
      for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        const FaceId face_id = face_list[i_face];

        for (size_t r_hat = 1; r_hat < degree; ++r_hat) {
          const size_t r = m_face_indices[face_id][r_hat];

          for (size_t s_hat = 0; s_hat <= degree; ++s_hat) {
            const size_t s = m_face_indices[face_id][s_hat];

            if (r == s) {
              // values of the face vertices are not stored in the table
              impose_dof(r, inner_value(i_face, r_hat - 1));
            } else {
              decouple(r, s);
            }
          }
        }
      }
    };

    for (const auto& boundary_condition : bc_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const Array<const NodeId>& node_list = bc.nodeList();
            const Array<const FaceId>& face_list = bc.faceList();

            const Array<const Rd>& node_value_list            = bc.nodeValueList();
            const Table<const Rd>& face_inner_node_value_list = bc.faceInnerNodeValueList();

            // The second member must be updated for all the dofs of this
            // boundary condition before any block is cancelled.

            // Update second member for inner dof: treats vertices of the faces
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId node_id = node_list[i_node];
              const size_t r       = m_node_index[node_id];

              auto node_face_list = node_to_face_matrix[node_id];
              for (size_t i_face = 0; i_face < node_face_list.size(); ++i_face) {
                const FaceId face_id = node_face_list[i_face];

                for (size_t s_hat = 0; s_hat <= degree; ++s_hat) {
                  const size_t s = m_face_indices[face_id][s_hat];

                  if (r != s) {
                    move_column_to_rhs(s, r, node_value_list[i_node]);
                  }
                }
              }
            }

            // Update second member for inner dof: treat inner nodes of the faces
            for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
              const FaceId face_id = face_list[i_face];

              for (size_t r_hat = 1; r_hat < degree; ++r_hat) {
                const size_t r = m_face_indices[face_id][r_hat];

                for (size_t s_hat = 0; s_hat <= degree; ++s_hat) {
                  const size_t s = m_face_indices[face_id][s_hat];

                  if (r != s) {
                    move_column_to_rhs(s, r, face_inner_node_value_list[i_face][r_hat - 1]);
                  }
                }
              }
            }

            // Update matrix and second member for boundary dof
            impose_on_boundary(
              node_list, face_list, [&](const size_t i_node) { return node_value_list[i_node]; },
              [&](const size_t i_face, const size_t i_inner) { return face_inner_node_value_list[i_face][i_inner]; });
          } else if constexpr (std::is_same_v<FixedBoundaryCondition, T>) {
            // The prescribed velocity is 0: no contribution has to be moved
            // to the second member
            const Rd zero_velocity = zero;

            impose_on_boundary(
              bc.nodeList(), bc.faceList(), [&](const size_t) { return zero_velocity; },
              [&](const size_t, const size_t) { return zero_velocity; });
          }
        },
        boundary_condition);
    }
    break;
  }
  case VariationalSolverHandler::VelocityBCTreatment::penalty: {
    // Weight added to the diagonal entries of the constrained dofs
    constexpr double penalty = 1.e30;

    // Adds the weight to every diagonal entry of the dof r
    const auto penalize_diagonal = [&](const size_t r) {
      for (size_t i = 0; i < Dimension; ++i) {
        A_descriptor(Dimension * r + i, Dimension * r + i) += penalty;
      }
    };

    // Adds the weighted prescribed velocity to the second member of the dof r
    const auto penalize_second_member = [&](const size_t r, const Rd& u_r) {
      for (size_t i = 0; i < Dimension; ++i) {
        b[Dimension * r + i] += penalty * u_r[i];
      }
    };

    // Penalizes the diagonal of all the dofs of a boundary: vertices then
    // inner nodes of the faces
    const auto penalize_boundary_diagonal = [&](const Array<const NodeId>& node_list,
                                                const Array<const FaceId>& face_list) {
      for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
        penalize_diagonal(m_node_index[node_list[i_node]]);
      }

      for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        const FaceId face_id = face_list[i_face];
        for (size_t r_hat = 1; r_hat < degree; ++r_hat) {
          penalize_diagonal(m_face_indices[face_id][r_hat]);
        }
      }
    };

    for (const auto& boundary_condition : bc_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const Array<const NodeId>& node_list = bc.nodeList();
            const Array<const FaceId>& face_list = bc.faceList();

            penalize_boundary_diagonal(node_list, face_list);

            const Array<const Rd>& node_value_list = bc.nodeValueList();
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              penalize_second_member(m_node_index[node_list[i_node]], node_value_list[i_node]);
            }

            const Table<const Rd>& face_inner_node_value_list = bc.faceInnerNodeValueList();
            for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
              const FaceId face_id = face_list[i_face];
              for (size_t r_hat = 1; r_hat < degree; ++r_hat) {
                const size_t value_index = r_hat - 1;   // nodal values on the face are not stored in this array

                penalize_second_member(m_face_indices[face_id][r_hat],
                                       face_inner_node_value_list[i_face][value_index]);
              }
            }
          } else if constexpr (std::is_same_v<FixedBoundaryCondition, T>) {
            // zero velocity: the second member is left unchanged
            penalize_boundary_diagonal(bc.nodeList(), bc.faceList());
          }
        },
        boundary_condition);
    }

    break;
  }
  }
}

// Boundary condition describing a fixed boundary: the solver imposes a zero
// velocity on its nodes and on the inner nodes of its faces.
//
// It stores a node boundary and, optionally, a face boundary. When it is built
// from a node boundary only, the face boundary is value-initialized.
template <MeshConcept MeshType>
class VariationalSolverHandler::VariationalSolver<MeshType>::FixedBoundaryCondition
{
 private:
  const MeshNodeBoundary m_mesh_node_boundary;
  const MeshFaceBoundary m_mesh_face_boundary;

 public:
  // Node boundary only
  FixedBoundaryCondition(const MeshNodeBoundary& mesh_node_boundary)
    : m_mesh_node_boundary{mesh_node_boundary}, m_mesh_face_boundary{}
  {}

  // Node boundary and face boundary
  FixedBoundaryCondition(const MeshNodeBoundary& mesh_node_boundary, const MeshFaceBoundary& mesh_face_boundary)
    : m_mesh_node_boundary{mesh_node_boundary}, m_mesh_face_boundary{mesh_face_boundary}
  {}

  ~FixedBoundaryCondition() = default;

  // Nodes of the stored node boundary
  auto
  nodeList() const -> const Array<const NodeId>&
  {
    return m_mesh_node_boundary.nodeList();
  }

  // Faces of the stored face boundary
  auto
  faceList() const -> const Array<const FaceId>&
  {
    return m_mesh_face_boundary.faceList();
  }
};

// Boundary-condition holder that associates a mesh face boundary with a table of
// double values. Both are copied at construction and exposed read-only.
// NOTE(review): the layout of the value table (what rows and columns index) is not
// visible here -- check where valueList() is consumed.
template <MeshConcept MeshType>
class VariationalSolverHandler::VariationalSolver<MeshType>::PressureBoundaryCondition
{
 private:
  const MeshFaceBoundary m_mesh_face_boundary;
  const Table<const double> m_value_list;

 public:
  // Stores the given face boundary together with its associated value table.
  PressureBoundaryCondition(const MeshFaceBoundary& mesh_face_boundary, const Table<const double>& value_list)
    : m_mesh_face_boundary{mesh_face_boundary}, m_value_list{value_list}
  {}

  ~PressureBoundaryCondition() = default;

  // Table of values handed over at construction.
  const Table<const double>&
  valueList() const
  {
    return m_value_list;
  }

  // Face list of the stored face boundary.
  const Array<const FaceId>&
  faceList() const
  {
    const Array<const FaceId>& boundary_face_list = m_mesh_face_boundary.faceList();
    return boundary_face_list;
  }
};

// Boundary-condition holder carrying vector values (TinyVector<Dimension>) attached
// to a node boundary and, optionally, to a face boundary. All constructor arguments
// are copied into const members and exposed read-only.
template <MeshConcept MeshType>
class VariationalSolverHandler::VariationalSolver<MeshType>::VelocityBoundaryCondition
{
 private:
  const MeshNodeBoundary m_mesh_node_boundary;
  const Array<const TinyVector<Dimension>> m_node_value_list;

  const MeshFaceBoundary m_mesh_face_boundary;
  const Table<const TinyVector<Dimension>> m_face_inner_node_value_list;

 public:
  // Full construction: node boundary with its values, plus face boundary with the
  // table of values for its inner nodes.
  VelocityBoundaryCondition(const MeshNodeBoundary& mesh_node_boundary,
                            const Array<const TinyVector<Dimension>>& node_value_list,
                            const MeshFaceBoundary& mesh_face_boundary,
                            const Table<const TinyVector<Dimension>>& face_inner_node_value_list)
    : m_mesh_node_boundary{mesh_node_boundary},
      m_node_value_list{node_value_list},
      m_mesh_face_boundary{mesh_face_boundary},
      m_face_inner_node_value_list{face_inner_node_value_list}
  {}

  // Node-only construction: the face boundary and the face inner node value table
  // are initialized with empty braces.
  VelocityBoundaryCondition(const MeshNodeBoundary& mesh_node_boundary,
                            const Array<const TinyVector<Dimension>>& node_value_list)
    : m_mesh_node_boundary{mesh_node_boundary},
      m_node_value_list{node_value_list},
      m_mesh_face_boundary{},
      m_face_inner_node_value_list{}
  {}

  ~VelocityBoundaryCondition() = default;

  // --- node data ---

  // Node list of the stored node boundary.
  const Array<const NodeId>&
  nodeList() const
  {
    const Array<const NodeId>& boundary_node_list = m_mesh_node_boundary.nodeList();
    return boundary_node_list;
  }

  // Values given at construction for the boundary nodes.
  const Array<const TinyVector<Dimension>>&
  nodeValueList() const
  {
    return m_node_value_list;
  }

  // --- face data ---

  // Face list of the stored face boundary.
  const Array<const FaceId>&
  faceList() const
  {
    const Array<const FaceId>& boundary_face_list = m_mesh_face_boundary.faceList();
    return boundary_face_list;
  }

  // Table of values given at construction for the inner nodes of the boundary faces.
  const Table<const TinyVector<Dimension>>&
  faceInnerNodeValueList() const
  {
    return m_face_inner_node_value_list;
  }
};

// Boundary-condition holder built on a flat node boundary and a flat face boundary
// of the mesh. It exposes the node and face lists and the outgoing normal reported
// by the flat node boundary.
template <MeshConcept MeshType>
class VariationalSolverHandler::VariationalSolver<MeshType>::SymmetryBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

 private:
  const MeshFlatNodeBoundary<MeshType> m_mesh_flat_node_boundary;
  const MeshFlatFaceBoundary<MeshType> m_mesh_flat_face_boundary;

 public:
  // Copies both flat boundary descriptions.
  SymmetryBoundaryCondition(const MeshFlatNodeBoundary<MeshType>& mesh_flat_node_boundary,
                            const MeshFlatFaceBoundary<MeshType>& mesh_flat_face_boundary)
    : m_mesh_flat_node_boundary(mesh_flat_node_boundary), m_mesh_flat_face_boundary(mesh_flat_face_boundary)
  {}

  ~SymmetryBoundaryCondition() = default;

  // Node list of the stored flat node boundary.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_flat_node_boundary.nodeList();
  }

  // Number of entries in nodeList().
  size_t
  numberOfNodes() const
  {
    return this->nodeList().size();
  }

  // Face list of the stored flat face boundary.
  const Array<const FaceId>&
  faceList() const
  {
    return m_mesh_flat_face_boundary.faceList();
  }

  // Outgoing normal as provided by the flat node boundary.
  const Rd&
  outgoingNormal() const
  {
    return m_mesh_flat_node_boundary.outgoingNormal();
  }
};

// Creates the concrete solver matching the mesh type held by the variant.
// Only polynomial meshes of dimension 2 are accepted; visiting any other mesh
// type throws NormalError("unexpected mesh type").
VariationalSolverHandler::VariationalSolverHandler(const std::shared_ptr<const MeshVariant>& i_mesh)
{
  auto create_solver = [&](auto&& mesh) {
    using MeshType = mesh_type_t<decltype(mesh)>;

    // Compile-time selection: the solver is only instantiated for supported mesh types.
    constexpr bool is_supported_mesh = (is_polynomial_mesh_v<MeshType>)and(MeshType::Dimension == 2);

    if constexpr (not is_supported_mesh) {
      throw NormalError("unexpected mesh type");
    } else {
      m_acoustic_solver = std::make_unique<VariationalSolver<MeshType>>();
    }
  };

  std::visit(create_solver, i_mesh->variant());
}