#include <scheme/P1P0AnalyticVariationalSolver.hpp>

#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/ItemValueVariant.hpp>
#include <mesh/MeshFaceBoundary.hpp>
#include <mesh/MeshFlatFaceBoundary.hpp>
#include <mesh/MeshFlatNodeBoundary.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <mesh/SubItemValuePerItemVariant.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/ExternalBoundaryConditionDescriptor.hpp>
#include <scheme/FixedBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/Vector.hpp>

#include <variant>
#include <vector>

/// Variational P1/P0 Lagrangian solver (one implementation per mesh type).
///
/// Nodal velocities (P1 unknowns, Dimension per node, stored as U[Dimension * node_id + i])
/// are obtained from a global linear system assembled face by face. Cell density, velocity
/// and energy (P0 values) are then updated from face fluxes built with these nodal
/// velocities, and the nodes are moved with them.
/// P1P0AnalyticVariationalSolverHandler only instantiates this class for 2D polygonal meshes.
template <MeshConcept MeshTypeT>
class P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver final
  : public P1P0AnalyticVariationalSolverHandler::IVariationalSolver
{
 private:
  using MeshType     = MeshTypeT;
  using MeshDataType = MeshData<MeshType>;

  static constexpr size_t Dimension = MeshType::Dimension;

  using Rdxd = TinyMatrix<Dimension>;
  using Rd   = TinyVector<Dimension>;

  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  class FixedBoundaryCondition;
  class PressureBoundaryCondition;
  class SymmetryBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition = std::
    variant<FixedBoundaryCondition, PressureBoundaryCondition, SymmetryBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  /// Builds the solver-side boundary conditions from their descriptors.
  ///
  /// Accepted descriptors:
  ///  - symmetry: flat node and flat face boundaries;
  ///  - fixed: node boundary;
  ///  - dirichlet named "velocity": the user function is interpolated at the boundary nodes (xr);
  ///  - dirichlet named "pressure": interpolated at the boundary nodes in 1D, and at the
  ///    mesh_data.xl() points of the boundary faces otherwise.
  /// @throws NormalError for any other descriptor
  BoundaryConditionList
  _getBCList(const MeshType& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list) const
  {
    BoundaryConditionList bc_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::symmetry: {
        bc_list.emplace_back(
          SymmetryBoundaryCondition(getMeshFlatNodeBoundary(mesh, bc_descriptor->boundaryDescriptor()),
                                    getMeshFlatFaceBoundary(mesh, bc_descriptor->boundaryDescriptor())));
        break;
      }
      case IBoundaryConditionDescriptor::Type::fixed: {
        bc_list.emplace_back(FixedBoundaryCondition(getMeshNodeBoundary(mesh, bc_descriptor->boundaryDescriptor())));
        break;
      }
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "velocity") {
          MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          Array<const Rd> value_list =
            InterpolateItemValue<Rd(Rd)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                               mesh.xr(),
                                                                               mesh_node_boundary.nodeList());

          bc_list.emplace_back(VelocityBoundaryCondition{mesh_node_boundary, value_list});
        } else if (dirichlet_bc_descriptor.name() == "pressure") {
          const FunctionSymbolId pressure_id = dirichlet_bc_descriptor.rhsSymbolId();

          if constexpr (Dimension == 1) {
            MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(mesh, bc_descriptor->boundaryDescriptor());

            Array<const double> node_values =
              InterpolateItemValue<double(Rd)>::template interpolate<ItemType::node>(pressure_id, mesh.xr(),
                                                                                     mesh_node_boundary.nodeList());

            bc_list.emplace_back(PressureBoundaryCondition{mesh_node_boundary, node_values});
          } else {
            MeshFaceBoundary mesh_face_boundary = getMeshFaceBoundary(mesh, bc_descriptor->boundaryDescriptor());

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(mesh);
            Array<const double> face_values =
              InterpolateItemValue<double(Rd)>::template interpolate<ItemType::face>(pressure_id, mesh_data.xl(),
                                                                                     mesh_face_boundary.faceList());
            bc_list.emplace_back(PressureBoundaryCondition{mesh_face_boundary, face_values});
          }

        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for acoustic solver";
        throw NormalError(error_msg.str());
      }
    }

    return bc_list;
  }

  // Adds the imposed boundary pressure forces to the right-hand side b.
  void _applyPressureBC(const BoundaryConditionList& bc_list, const MeshType& mesh, Vector<double>& b) const;
  // Modifies the rows of A and b associated with symmetry-boundary nodes.
  void _applySymmetryBC(const BoundaryConditionList& bc_list,
                        const MeshType& mesh,
                        CRSMatrixDescriptor<double>& A,
                        Vector<double>& b) const;
  // Imposes velocity / fixed nodal values in A and b, by elimination or by penalty.
  void _applyVelocityBC(const BoundaryConditionList& bc_list,
                        const MeshType& mesh,
                        const P1P0AnalyticVariationalSolverHandler::VelocityBCTreatment& velocity_bc_treatment,
                        CRSMatrixDescriptor<double>& A,
                        Vector<double>& b) const;

  // Applies every boundary condition to the assembled system, in this order:
  // pressure, symmetry, then velocity/fixed.
  void
  _applyBoundaryConditions(const BoundaryConditionList& bc_list,
                           const MeshType& mesh,
                           const P1P0AnalyticVariationalSolverHandler::VelocityBCTreatment& velocity_bc_treatment,
                           CRSMatrixDescriptor<double>& A,
                           Vector<double>& b) const
  {
    this->_applyPressureBC(bc_list, mesh, b);
    this->_applySymmetryBC(bc_list, mesh, A, b);
    this->_applyVelocityBC(bc_list, mesh, velocity_bc_treatment, A, b);
  }

  // The _forcebc* helpers overwrite computed nodal velocities on the boundaries.
  // NodeValue is taken by value; apply() relies on the writes reaching its own u_star,
  // presumably because NodeValue copies share storage (confirm).
  void _forcebcSymmetryBC(const BoundaryConditionList& bc_list, NodeValue<Rd> u_star) const;

  void _forcebcVelocityBC(const BoundaryConditionList& bc_list, NodeValue<Rd> u_star) const;

  void
  _forcebcBoundaryConditions(const BoundaryConditionList& bc_list, NodeValue<Rd> u_star) const
  {
    this->_forcebcSymmetryBC(bc_list, u_star);
    this->_forcebcVelocityBC(bc_list, u_star);
  }

 public:
  /// Performs one time step of size dt.
  ///
  /// @param dt                    time step
  /// @param velocity_bc_treatment how velocity/fixed conditions enter the nodal system
  /// @param rho_v, u_v, E_v       cell density, velocity and energy at the current time
  /// @param c_v                   cell coefficient such that rho*c weights the normal velocity
  ///                              component (presumably the sound speed)
  /// @param a_v                   cell coefficient weighting the tangential velocity component
  /// @param p_v                   cell pressure
  /// @param bc_descriptor_list    boundary condition descriptors
  /// @return the moved mesh and the updated density, velocity and energy defined on it
  /// @throws NormalError if the discrete functions do not share a mesh or a boundary
  ///         condition is not supported
  std::tuple<std::shared_ptr<const MeshVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt,
        const VelocityBCTreatment& velocity_bc_treatment,
        const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& u_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& E_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& c_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& a_v,
        const std::shared_ptr<const DiscreteFunctionVariant>& p_v,
        const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list) const
  {
    // getCommonMesh returns a null pointer when the functions live on different meshes;
    // dereferencing it unchecked would crash.
    std::shared_ptr mesh_v = getCommonMesh({c_v, a_v, rho_v, u_v, E_v, p_v});
    if (not mesh_v) {
      throw NormalError("discrete functions are not defined on the same mesh");
    }
    std::shared_ptr mesh = mesh_v->get<MeshType>();

    DiscreteScalarFunction rho = rho_v->get<DiscreteScalarFunction>();
    DiscreteVectorFunction u   = u_v->get<DiscreteVectorFunction>();
    DiscreteScalarFunction c   = c_v->get<DiscreteScalarFunction>();
    DiscreteScalarFunction a   = a_v->get<DiscreteScalarFunction>();
    DiscreteScalarFunction p   = p_v->get<DiscreteScalarFunction>();
    DiscreteScalarFunction E   = E_v->get<DiscreteScalarFunction>();

    DiscreteScalarFunction rhoc = rho * c;

    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    auto ll = mesh_data.ll();   // face lengths
    auto nl = mesh_data.nl();   // face normals (fetched once, used in every loop below)

    auto node_to_face_matrix        = mesh->connectivity().nodeToFaceMatrix();
    auto face_to_node_matrix        = mesh->connectivity().faceToNodeMatrix();
    auto face_to_cell_matrix        = mesh->connectivity().faceToCellMatrix();
    const auto& cell_to_face_matrix = mesh->connectivity().cellToFaceMatrix();

    const auto& face_local_numbers_in_their_cells = mesh->connectivity().faceLocalNumbersInTheirCells();
    const auto& cell_face_is_reversed             = mesh->connectivity().cellFaceIsReversed();

    const Rdxd I = identity;

    const size_t number_of_dofs = Dimension * mesh->numberOfNodes();

    // ---- Matrix A ---------------------------------------------------------------------------
    // Row-size estimate: node r couples with itself and, through each incident face, with the
    // other node(s) of that face. Every component row of the node gets the same estimate.
    Array<int> non_zero(number_of_dofs);
    for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
      const int row_non_zero = static_cast<int>(Dimension * (node_to_face_matrix[node_id].size() + 1));
      for (size_t i = 0; i < Dimension; ++i) {
        non_zero[Dimension * node_id + i] = row_non_zero;
      }
    }

    CRSMatrixDescriptor A_descriptor(number_of_dofs, number_of_dofs, non_zero);

    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      const double face_length = ll[face_id];
      auto cell_list           = face_to_cell_matrix[face_id];

      // Face coefficients: sums over the cells adjacent to the face.
      double Z  = 0;
      double Za = 0;
      for (size_t i_cell = 0; i_cell < cell_list.size(); ++i_cell) {
        const CellId cell_id = cell_list[i_cell];
        Z += rhoc[cell_id];
        Za += a[cell_id];
      }
      const Rd n      = nl[face_id];
      const Rdxd nxn  = tensorProduct(n, n);
      const Rdxd Al   = Z * nxn + Za * (I - nxn);
      auto node_list  = face_to_node_matrix[face_id];

      for (size_t n0 = 0; n0 < node_list.size(); ++n0) {
        const NodeId r = node_list[n0];
        for (size_t n1 = 0; n1 < node_list.size(); ++n1) {
          const NodeId s = node_list[n1];

          // |l|/3 and |l|/6 are the integrals of products of P1 hat functions on a segment
          Rdxd Al_rs = Al;
          Al_rs *= (r == s) ? face_length / 3 : face_length / 6;

          for (size_t i = 0; i < Dimension; ++i) {
            for (size_t j = 0; j < Dimension; ++j) {
              A_descriptor(Dimension * r + i, Dimension * s + j) += Al_rs(i, j);
            }
          }
        }
      }
    }

    // ---- Right-hand side b ------------------------------------------------------------------
    Vector<double> b(number_of_dofs);
    b.fill(0);

    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      const double face_length = ll[face_id];
      auto cell_list           = face_to_cell_matrix[face_id];
      auto node_list           = face_to_node_matrix[face_id];

      for (size_t i_cell = 0; i_cell < cell_list.size(); ++i_cell) {
        const CellId face_cell_id   = cell_list[i_cell];
        const size_t i_face_in_cell = face_local_numbers_in_their_cells[face_id][i_cell];

        // njl: face normal oriented relative to cell j (sign given by cellFaceIsReversed)
        const double sign    = cell_face_is_reversed(face_cell_id, i_face_in_cell) ? -1 : 1;
        const Rd njl         = sign * nl[face_id];
        const Rdxd njl_x_njl = tensorProduct(njl, njl);
        const Rdxd Ajl       = rhoc[face_cell_id] * njl_x_njl + a[face_cell_id] * (I - njl_x_njl);

        // The contribution is the same for every node of the face, so it is computed once.
        Rd bl_r = p[face_cell_id] * njl + Ajl * u[face_cell_id];
        bl_r *= face_length / 2;

        for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
          const NodeId r = node_list[i_node];
          for (size_t i = 0; i < Dimension; ++i) {
            b[Dimension * r + i] += bl_r[i];
          }
        }
      }
    }

    // ---- Boundary conditions and solve ------------------------------------------------------
    const BoundaryConditionList bc_list = this->_getBCList(*mesh, bc_descriptor_list);
    this->_applyBoundaryConditions(bc_list, *mesh, velocity_bc_treatment, A_descriptor, b);

    CRSMatrix A = A_descriptor.getCRSMatrix();

    Vector<double> U{b.size()};
    U = zero;

    LinearSolver solver;
    solver.solveLocalSystem(A, U, b);

    NodeValue<Rd> u_star{mesh->connectivity()};
    for (NodeId node_id = 0; node_id < mesh->numberOfNodes(); ++node_id) {
      for (size_t i = 0; i < Dimension; ++i) {
        u_star[node_id][i] = U[Dimension * node_id + i];
      }
    }
    // Re-impose the boundary velocities on the computed solution.
    this->_forcebcBoundaryConditions(bc_list, u_star);

    // ---- Cell updates -----------------------------------------------------------------------
    CellValue<double> new_rho = copy(rho.cellValues());
    CellValue<Rd> new_u       = copy(u.cellValues());
    CellValue<double> new_E   = copy(E.cellValues());

    CellValue<const double> Vj = mesh_data.Vj();

    for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
      Rd momentum_fluxes   = zero;
      double energy_fluxes = 0;

      auto face_list = cell_to_face_matrix[cell_id];
      for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        const FaceId face_id     = face_list[i_face];
        const double face_length = ll[face_id];

        const double sign    = cell_face_is_reversed(cell_id, i_face) ? -1 : 1;
        const Rd njl         = sign * nl[face_id];
        const Rdxd njl_x_njl = tensorProduct(njl, njl);
        const Rdxd Ajl       = rhoc[cell_id] * njl_x_njl + a[cell_id] * (I - njl_x_njl);

        auto node_list = face_to_node_matrix[face_id];

        Rd sum_us = zero;
        for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
          sum_us += u_star[node_list[i_node]];
        }

        momentum_fluxes += face_length * (Ajl * (u[cell_id] - 0.5 * sum_us));

        energy_fluxes += 0.5 * face_length * dot(p[cell_id] * njl + Ajl * u[cell_id], sum_us);

        for (size_t i_node0 = 0; i_node0 < node_list.size(); ++i_node0) {
          const Rd u0 = u_star[node_list[i_node0]];

          for (size_t i_node1 = 0; i_node1 < node_list.size(); ++i_node1) {
            const Rd u1 = u_star[node_list[i_node1]];

            const double delta = (i_node0 == i_node1);
            energy_fluxes -= (1 + delta) * face_length / 6 * dot(Ajl * u0, u1);
          }
        }
      }
      // Mj = rho_j * V_j is the cell mass
      const double dt_over_Mj = dt / (rho[cell_id] * Vj[cell_id]);
      new_u[cell_id] -= dt_over_Mj * momentum_fluxes;
      new_E[cell_id] -= dt_over_Mj * energy_fluxes;
    }

    // ---- Mesh motion and density rescaling --------------------------------------------------
    NodeValue<Rd> new_xr = copy(mesh->xr());
    parallel_for(
      mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { new_xr[r] += dt * u_star[r]; });

    std::shared_ptr<const MeshType> new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_xr);

    CellValue<const double> new_Vj = MeshDataManager::instance().getMeshData(*new_mesh).Vj();

    // rho_new = rho * V_old / V_new keeps rho * V unchanged
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { new_rho[j] *= Vj[j] / new_Vj[j]; });

    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction(new_mesh, new_rho)),
            std::make_shared<DiscreteFunctionVariant>(DiscreteVectorFunction(new_mesh, new_u)),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction(new_mesh, new_E))};
  }

  P1P0AnalyticVariationalSolver()                                = default;
  P1P0AnalyticVariationalSolver(P1P0AnalyticVariationalSolver&&) = default;
  ~P1P0AnalyticVariationalSolver()                               = default;
};

template <MeshConcept MeshType>
void
P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::_applyPressureBC(
  const BoundaryConditionList& bc_list,
  const MeshType& mesh,
  Vector<double>& b) const
{
  // Pressure conditions only change the right-hand side. For every boundary face and every
  // cell adjacent to it, the imposed force p_ext * n_jl * |l| / 2 is subtracted from the rows
  // of each node of the face. Nothing is done in 1D.
  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using BCType = std::decay_t<decltype(bc)>;
        if constexpr (std::is_same_v<PressureBoundaryCondition, BCType> and (Dimension > 1)) {
          const auto& connectivity = mesh.connectivity();
          MeshDataType& mesh_data  = MeshDataManager::instance().getMeshData(mesh);

          auto face_lengths = mesh_data.ll();
          auto face_normals = mesh_data.nl();

          const auto& face_to_cell      = connectivity.faceToCellMatrix();
          const auto& face_to_node      = connectivity.faceToNodeMatrix();
          const auto& local_face_number = connectivity.faceLocalNumbersInTheirCells();
          const auto& is_reversed       = connectivity.cellFaceIsReversed();

          const Array<const FaceId>& boundary_faces = bc.faceList();
          const Array<const double>& p_ext          = bc.valueList();

          for (size_t k = 0; k < boundary_faces.size(); ++k) {
            const FaceId face_id = boundary_faces[k];
            auto adjacent_cells  = face_to_cell[face_id];
            auto face_nodes      = face_to_node[face_id];

            for (size_t cell_index = 0; cell_index < adjacent_cells.size(); ++cell_index) {
              const CellId cell_id        = adjacent_cells[cell_index];
              const size_t i_face_in_cell = local_face_number[face_id][cell_index];
              const double sign           = is_reversed(cell_id, i_face_in_cell) ? -1 : 1;

              const TinyVector<Dimension> n_jl = sign * face_normals[face_id];

              // Identical for every node of the face
              TinyVector<Dimension> force = p_ext[k] * n_jl;
              force *= face_lengths[face_id] / 2;

              for (size_t node_index = 0; node_index < face_nodes.size(); ++node_index) {
                const NodeId node_id = face_nodes[node_index];
                for (size_t d = 0; d < Dimension; ++d) {
                  b[Dimension * node_id + d] -= force[d];
                }
              }
            }
          }
        }
      },
      boundary_condition);
  }
}

template <MeshConcept MeshType>
void
P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::_forcebcSymmetryBC(
  const BoundaryConditionList& bc_list,
  NodeValue<Rd> u_star) const
{
  // On each symmetry boundary, every nodal velocity is replaced by its projection
  // (I - n (x) n) u, which removes its component along the boundary normal n.
  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using BCType = std::decay_t<decltype(bc)>;
        if constexpr (std::is_same_v<SymmetryBoundaryCondition, BCType>) {
          const Rd& normal = bc.outgoingNormal();

          const Rdxd Id        = identity;
          const Rdxd projector = Id - tensorProduct(normal, normal);

          const Array<const NodeId>& nodes = bc.nodeList();
          for (size_t k = 0; k < nodes.size(); ++k) {
            u_star[nodes[k]] = projector * u_star[nodes[k]];
          }
        }
      },
      boundary_condition);
  }
}

template <MeshConcept MeshType>
void
P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::_forcebcVelocityBC(
  const BoundaryConditionList& bc_list,
  NodeValue<Rd> u_star) const
{
  // Overwrites the computed nodal velocities on velocity boundaries (imposed values) and on
  // fixed boundaries (zero velocity).
  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using T = std::decay_t<decltype(bc)>;

        if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
          const Array<const NodeId>& node_list = bc.nodeList();
          const Array<const Rd>& value_list    = bc.valueList();

          for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
            u_star[node_list[i_node]] = value_list[i_node];
          }

        } else if constexpr (std::is_same_v<FixedBoundaryCondition, T>) {
          const Array<const NodeId>& node_list = bc.nodeList();

          // One whole-vector assignment per node; no per-component loop is needed.
          for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
            u_star[node_list[i_node]] = zero;
          }
        }
      },
      boundary_condition);
  }
}

template <MeshConcept MeshType>
void
P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::_applySymmetryBC(
  const BoundaryConditionList& bc_list,
  const MeshType& mesh,
  CRSMatrixDescriptor<double>& A_descriptor,
  Vector<double>& b) const
{
  // Copies the Dimension x Dimension block (row node, column node) of A into a dense matrix.
  auto read_block = [&A_descriptor](const NodeId row_node, const NodeId col_node) {
    Rdxd block;
    for (size_t i = 0; i < Dimension; ++i) {
      for (size_t j = 0; j < Dimension; ++j) {
        block(i, j) = A_descriptor(Dimension * row_node + i, Dimension * col_node + j);
      }
    }
    return block;
  };

  // Writes a dense Dimension x Dimension matrix back into block (row node, column node) of A.
  auto write_block = [&A_descriptor](const NodeId row_node, const NodeId col_node, const Rdxd& block) {
    for (size_t i = 0; i < Dimension; ++i) {
      for (size_t j = 0; j < Dimension; ++j) {
        A_descriptor(Dimension * row_node + i, Dimension * col_node + j) = block(i, j);
      }
    }
  };

  for (const auto& boundary_condition : bc_list) {
    std::visit(
      [&](auto&& bc) {
        using BCType = std::decay_t<decltype(bc)>;
        if constexpr (std::is_same_v<SymmetryBoundaryCondition, BCType>) {
          const Rd& normal = bc.outgoingNormal();

          // Sn = I - 2 n (x) n is the reflection across the symmetry plane
          const Rdxd Id   = identity;
          const Rdxd Sn   = Id - 2 * tensorProduct(normal, normal);
          const Rdxd IpSn = Id + Sn;

          const Array<const NodeId>& symmetry_nodes = bc.nodeList();

          // Diagonal blocks: A_rr <- A_rr + Sn A_rr Sn, and b_r <- b_r + Sn b_r
          for (size_t k = 0; k < symmetry_nodes.size(); ++k) {
            const NodeId r = symmetry_nodes[k];

            const Rdxd Arr = read_block(r, r);
            write_block(r, r, Arr + Sn * Arr * Sn);

            Rd br;
            for (size_t i = 0; i < Dimension; ++i) {
              br[i] = b[Dimension * r + i];
            }
            const Rd symmetrized_br = br + Sn * br;
            for (size_t i = 0; i < Dimension; ++i) {
              b[Dimension * r + i] = symmetrized_br[i];
            }
          }

          // Coupling blocks with face neighbours: A_rs <- (I + Sn) A_rs, for s != r
          auto node_to_face_matrix = mesh.connectivity().nodeToFaceMatrix();
          auto face_to_node_matrix = mesh.connectivity().faceToNodeMatrix();

          for (size_t k = 0; k < symmetry_nodes.size(); ++k) {
            const NodeId r         = symmetry_nodes[k];
            auto faces_around_node = node_to_face_matrix[r];

            for (size_t l = 0; l < faces_around_node.size(); ++l) {
              const FaceId face_id = faces_around_node[l];
              auto face_nodes      = face_to_node_matrix[face_id];

              for (size_t m = 0; m < face_nodes.size(); ++m) {
                const NodeId s = face_nodes[m];
                if (s == r) {
                  continue;
                }
                write_block(r, s, IpSn * read_block(r, s));
              }
            }
          }
        }
      },
      boundary_condition);
  }
}

/// Imposes the nodal values of velocity and fixed boundary conditions in the system A U = b.
///
/// - elimination: the known value u_r of each constrained node r is first moved to the
///   right-hand side of its face neighbours (b_s -= A_sr u_r). The equations of r are then
///   replaced by u_r = value: the diagonal block becomes I, the coupling blocks A_rs and A_sr
///   are zeroed, and b_r is set to the value. Fixed nodes use the value 0.
/// - penalty: a large weight is added to every diagonal entry of the constrained node, and the
///   weight times the imposed value is added to b. Fixed nodes only get the diagonal weight.
///   This treatment handles every velocity component (any Dimension).
template <MeshConcept MeshType>
void
P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::_applyVelocityBC(
  const BoundaryConditionList& bc_list,
  const MeshType& mesh,
  const P1P0AnalyticVariationalSolverHandler::VelocityBCTreatment& velocity_bc_treatment,
  CRSMatrixDescriptor<double>& A_descriptor,
  Vector<double>& b) const
{
  switch (velocity_bc_treatment) {
  case P1P0AnalyticVariationalSolverHandler::VelocityBCTreatment::elimination: {
    // Connectivity does not depend on the boundary node: fetch it once
    auto face_to_node_matrix = mesh.connectivity().faceToNodeMatrix();
    auto node_to_face_matrix = mesh.connectivity().nodeToFaceMatrix();

    // Moves the known contribution A_sr * u_r to the right-hand side of every node s != r
    // sharing a face with r.
    auto lift_known_value = [&](const NodeId node_r_id, const Rd& u_r) {
      auto face_list = node_to_face_matrix[node_r_id];
      for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        const FaceId face_id = face_list[i_face];
        auto face_node_list  = face_to_node_matrix[face_id];

        for (size_t i_s_node = 0; i_s_node < face_node_list.size(); ++i_s_node) {
          const NodeId node_s_id = face_node_list[i_s_node];
          if (node_s_id == node_r_id) {
            continue;
          }
          Rdxd Asr;
          for (size_t i = 0; i < Dimension; ++i) {
            for (size_t j = 0; j < Dimension; ++j) {
              Asr(i, j) = A_descriptor(Dimension * node_s_id + i, Dimension * node_r_id + j);
            }
          }
          const Rd AsrU0r = Asr * u_r;
          for (size_t i = 0; i < Dimension; ++i) {
            b[Dimension * node_s_id + i] -= AsrU0r[i];
          }
        }
      }
    };

    // Replaces the equations of node r by u_r = value: identity diagonal block, b_r = value,
    // and zero coupling blocks (row r and column r) with all face neighbours.
    auto impose_value = [&](const NodeId node_r_id, const Rd& value) {
      constexpr Rdxd I = identity;
      for (size_t i = 0; i < Dimension; ++i) {
        b[Dimension * node_r_id + i] = value[i];
        for (size_t j = 0; j < Dimension; ++j) {
          A_descriptor(Dimension * node_r_id + i, Dimension * node_r_id + j) = I(i, j);
        }
      }

      auto face_list = node_to_face_matrix[node_r_id];
      for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        const FaceId face_id = face_list[i_face];
        auto face_node_list  = face_to_node_matrix[face_id];

        for (size_t i_s_node = 0; i_s_node < face_node_list.size(); ++i_s_node) {
          const NodeId node_s_id = face_node_list[i_s_node];
          if (node_s_id == node_r_id) {
            continue;
          }
          for (size_t i = 0; i < Dimension; ++i) {
            for (size_t j = 0; j < Dimension; ++j) {
              A_descriptor(Dimension * node_r_id + i, Dimension * node_s_id + j) = 0;
              A_descriptor(Dimension * node_s_id + i, Dimension * node_r_id + j) = 0;
            }
          }
        }
      }
    };

    for (const auto& boundary_condition : bc_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const Array<const NodeId>& node_list = bc.nodeList();
            const Array<const Rd>& value_list    = bc.valueList();

            // All lifts of this condition are done before any coupling block is zeroed.
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              lift_known_value(node_list[i_node], value_list[i_node]);
            }
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              impose_value(node_list[i_node], value_list[i_node]);
            }

          } else if constexpr (std::is_same_v<FixedBoundaryCondition, T>) {
            const Array<const NodeId>& node_list = bc.nodeList();

            // A zero value needs no lifting.
            const Rd zero_velocity = zero;
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              impose_value(node_list[i_node], zero_velocity);
            }
          }
        },
        boundary_condition);
    }
    break;
  }
  case P1P0AnalyticVariationalSolverHandler::VelocityBCTreatment::penalty: {
    // The weight is meant to dominate the assembled diagonal so that u_r is close to the
    // imposed value.
    constexpr double penalty_weight = 1.e30;

    for (const auto& boundary_condition : bc_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const Array<const NodeId>& node_list = bc.nodeList();
            const Array<const Rd>& value_list    = bc.valueList();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId node_id = node_list[i_node];
              for (size_t i = 0; i < Dimension; ++i) {
                A_descriptor(Dimension * node_id + i, Dimension * node_id + i) += penalty_weight;
                b[Dimension * node_id + i] += penalty_weight * value_list[i_node][i];
              }
            }

          } else if constexpr (std::is_same_v<FixedBoundaryCondition, T>) {
            const Array<const NodeId>& node_list = bc.nodeList();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId node_id = node_list[i_node];
              for (size_t i = 0; i < Dimension; ++i) {
                A_descriptor(Dimension * node_id + i, Dimension * node_id + i) += penalty_weight;
              }
            }
          }
        },
        boundary_condition);
    }
    break;
  }
  }
}

/// Boundary whose nodes are given a zero velocity by the solver.
template <MeshConcept MeshType>
class P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::FixedBoundaryCondition
{
 public:
  FixedBoundaryCondition(const MeshNodeBoundary& mesh_node_boundary) : m_mesh_node_boundary{mesh_node_boundary} {}

  ~FixedBoundaryCondition() = default;

  /// Nodes of the boundary
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_node_boundary.nodeList();
  }

 private:
  const MeshNodeBoundary m_mesh_node_boundary;
};

/// Boundary carrying imposed pressure values, one per boundary face.
///
/// NOTE(review): only constructible from a MeshFaceBoundary; the 1D branch of _getBCList
/// passes a MeshNodeBoundary, which is fine only as long as the solver is never instantiated
/// in 1D (the handler currently restricts it to 2D) — confirm if 1D support is added.
template <MeshConcept MeshType>
class P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::PressureBoundaryCondition
{
 public:
  PressureBoundaryCondition(const MeshFaceBoundary& mesh_face_boundary, const Array<const double>& value_list)
    : m_mesh_face_boundary{mesh_face_boundary}, m_value_list{value_list}
  {}

  ~PressureBoundaryCondition() = default;

  /// Faces of the boundary
  const Array<const FaceId>&
  faceList() const
  {
    return m_mesh_face_boundary.faceList();
  }

  /// Imposed pressure, indexed like faceList()
  const Array<const double>&
  valueList() const
  {
    return m_value_list;
  }

 private:
  const MeshFaceBoundary m_mesh_face_boundary;
  const Array<const double> m_value_list;
};

/// Boundary carrying imposed nodal velocities.
template <MeshConcept MeshType>
class P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::VelocityBoundaryCondition
{
 public:
  VelocityBoundaryCondition(const MeshNodeBoundary& mesh_node_boundary,
                            const Array<const TinyVector<Dimension>>& value_list)
    : m_mesh_node_boundary{mesh_node_boundary}, m_value_list{value_list}
  {}

  ~VelocityBoundaryCondition() = default;

  /// Nodes of the boundary
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_node_boundary.nodeList();
  }

  /// Imposed velocity, indexed like nodeList()
  const Array<const TinyVector<Dimension>>&
  valueList() const
  {
    return m_value_list;
  }

 private:
  const MeshNodeBoundary m_mesh_node_boundary;
  const Array<const TinyVector<Dimension>> m_value_list;
};

/// Flat boundary treated as a symmetry plane; exposes its nodes, faces and outgoing normal.
template <MeshConcept MeshType>
class P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolver<MeshType>::SymmetryBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

  SymmetryBoundaryCondition(const MeshFlatNodeBoundary<MeshType>& mesh_flat_node_boundary,
                            const MeshFlatFaceBoundary<MeshType>& mesh_flat_face_boundary)
    : m_mesh_flat_node_boundary(mesh_flat_node_boundary), m_mesh_flat_face_boundary(mesh_flat_face_boundary)
  {}

  ~SymmetryBoundaryCondition() = default;

  /// Outgoing normal of the flat node boundary
  const Rd&
  outgoingNormal() const
  {
    return m_mesh_flat_node_boundary.outgoingNormal();
  }

  /// Nodes of the boundary
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_flat_node_boundary.nodeList();
  }

  /// Number of boundary nodes
  size_t
  numberOfNodes() const
  {
    return this->nodeList().size();
  }

  /// Faces of the boundary
  const Array<const FaceId>&
  faceList() const
  {
    return m_mesh_flat_face_boundary.faceList();
  }

 private:
  const MeshFlatNodeBoundary<MeshType> m_mesh_flat_node_boundary;
  const MeshFlatFaceBoundary<MeshType> m_mesh_flat_face_boundary;
};

P1P0AnalyticVariationalSolverHandler::P1P0AnalyticVariationalSolverHandler(
  const std::shared_ptr<const MeshVariant>& i_mesh)
{
  // Builds the solver for the concrete mesh type; only 2D polygonal meshes are supported.
  std::visit(
    [this](auto&& mesh) {
      using MeshType = mesh_type_t<decltype(mesh)>;
      if constexpr (not(is_polygonal_mesh_v<MeshType> and (MeshType::Dimension == 2))) {
        throw NormalError("unexpected mesh type");
      } else {
        m_acoustic_solver = std::make_unique<P1P0AnalyticVariationalSolver<MeshType>>();
      }
    },
    i_mesh->variant());
}
