#include <language/algorithms/ElasticityDiamondAlgorithm.hpp>

#include <algebra/CRSMatrix.hpp>
#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LeastSquareSolver.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/SmallMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <algebra/Vector.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/DualConnectivityManager.hpp>
#include <mesh/DualMeshManager.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshFaceBoundary.hpp>
#include <mesh/PrimalToDiamondDualConnectivityDataMapper.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/NeumannBoundaryConditionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

template <size_t Dimension>
// Builds and solves the diamond-scheme linear system for an elasticity
// problem on the given primal mesh.
//
// Parameters:
//  - i_mesh: primal mesh; it is cast to Mesh<Connectivity<Dimension>>.
//  - bc_descriptor_list: boundary condition descriptors. Accepted are symmetry
//    descriptors and Dirichlet-type descriptors named "dirichlet" or
//    "normal_strain"; any other descriptor raises a NormalError. None of them
//    is supported in 1d (NotImplementedError).
//  - lambda_id, mu_id: coefficient functions (presumably the Lame
//    coefficients), evaluated at the diamond (dual) cell centers.
//  - f_id: source term, evaluated at the primal cell centers.
//  - U_id: reference solution; used for the printed error norms and to fill
//    interior face values after the solve.
//
// Unknowns are one Dimension-vector per primal cell plus one per boundary face
// carrying a boundary condition. Every boundary face must carry exactly one
// condition (NormalError otherwise). The system is solved with LinearSolver and
// residuals / error norms are written to std::cout. Only sequential runs are
// supported (NotImplementedError when parallel::size() > 1).
ElasticityDiamondScheme<Dimension>::ElasticityDiamondScheme(
  std::shared_ptr<const IMesh> i_mesh,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const FunctionSymbolId& lambda_id,
  const FunctionSymbolId& mu_id,
  const FunctionSymbolId& f_id,
  const FunctionSymbolId& U_id)
{
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  using BoundaryCondition =
    std::variant<DirichletBoundaryCondition, NormalStrainBoundaryCondition, SymmetryBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  BoundaryConditionList boundary_condition_list;

  std::cout << "number of bc descr = " << bc_descriptor_list.size() << '\n';
  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

  // Translate the descriptors into concrete boundary conditions.
  for (const auto& bc_descriptor : bc_descriptor_list) {
    bool is_valid_boundary_condition = true;

    switch (bc_descriptor->type()) {
    case IBoundaryConditionDescriptor::Type::symmetry: {
      const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor =
        dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
      if constexpr (Dimension > 1) {
        MeshFaceBoundary mesh_face_boundary = getMeshFaceBoundary(*mesh, sym_bc_descriptor.boundaryDescriptor());
        boundary_condition_list.push_back(SymmetryBoundaryCondition{mesh_face_boundary.faceList()});

      } else {
        throw NotImplementedError("Symmetry conditions are not supported in 1d");
      }

      break;
    }
    case IBoundaryConditionDescriptor::Type::dirichlet: {
      const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
        dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
      if (dirichlet_bc_descriptor.name() == "dirichlet") {
        if constexpr (Dimension > 1) {
          MeshFaceBoundary mesh_face_boundary =
            getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());
          MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

          const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();

          Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
            TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id, mesh_data.xl(),
                                                                          mesh_face_boundary.faceList());

          boundary_condition_list.push_back(DirichletBoundaryCondition{mesh_face_boundary.faceList(), value_list});
        } else {
          throw NotImplementedError("Dirichlet conditions are not supported in 1d");
        }
      } else if (dirichlet_bc_descriptor.name() == "normal_strain") {
        if constexpr (Dimension > 1) {
          MeshFaceBoundary mesh_face_boundary =
            getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());
          MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

          const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();

          Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
            TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id, mesh_data.xl(),
                                                                          mesh_face_boundary.faceList());
          boundary_condition_list.push_back(NormalStrainBoundaryCondition{mesh_face_boundary.faceList(), value_list});
        } else {
          throw NotImplementedError("Normal strain conditions are not supported in 1d");
        }

      } else {
        is_valid_boundary_condition = false;
      }
      break;
    }
    default: {
      is_valid_boundary_condition = false;
    }
    }
    if (not is_valid_boundary_condition) {
      std::ostringstream error_msg;
      error_msg << *bc_descriptor << " is an invalid boundary condition for elasticity equation";
      throw NormalError(error_msg.str());
    }
  }

  if constexpr (Dimension > 1) {
    // Cell unknowns come first: cell j owns dofs [j*Dimension, (j+1)*Dimension).
    const CellValue<const size_t> cell_dof_number = [&] {
      CellValue<size_t> compute_cell_dof_number{mesh->connectivity()};
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { compute_cell_dof_number[cell_id] = cell_id; });
      return compute_cell_dof_number;
    }();
    size_t number_of_dof = mesh->numberOfCells();

    // Faces listed in a boundary condition get an extra unknown, numbered after
    // the cells. Faces without a condition keep the sentinel value max().
    const FaceValue<const size_t> face_dof_number = [&] {
      FaceValue<size_t> compute_face_dof_number{mesh->connectivity()};
      compute_face_dof_number.fill(std::numeric_limits<size_t>::max());
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>) or
                          (std::is_same_v<T, SymmetryBoundaryCondition>) or
                          (std::is_same_v<T, DirichletBoundaryCondition>)) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id = face_list[i_face];
                if (compute_face_dof_number[face_id] != std::numeric_limits<size_t>::max()) {
                  std::ostringstream os;
                  os << "The face " << face_id << " is used at least twice for boundary conditions";
                  throw NormalError(os.str());
                } else {
                  compute_face_dof_number[face_id] = number_of_dof++;
                }
              }
            }
          },
          boundary_condition);
      }

      return compute_face_dof_number;
    }();

    const auto& primal_face_to_node_matrix             = mesh->connectivity().faceToNodeMatrix();
    const auto& face_to_cell_matrix                    = mesh->connectivity().faceToCellMatrix();
    const FaceValue<const bool> primal_face_is_neumann = [&] {
      FaceValue<bool> face_is_neumann{mesh->connectivity()};
      face_is_neumann.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>)) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id     = face_list[i_face];
                face_is_neumann[face_id] = true;
              }
            }
          },
          boundary_condition);
      }

      return face_is_neumann;
    }();

    const FaceValue<const bool> primal_face_is_symmetry = [&] {
      FaceValue<bool> face_is_symmetry{mesh->connectivity()};
      face_is_symmetry.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr ((std::is_same_v<T, SymmetryBoundaryCondition>)) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id      = face_list[i_face];
                face_is_symmetry[face_id] = true;
              }
            }
          },
          boundary_condition);
      }

      return face_is_symmetry;
    }();

    // A face with a single neighbouring cell is a boundary face. This is only
    // valid for sequential runs, hence the parallel guards below.
    NodeValue<bool> primal_node_is_on_boundary(mesh->connectivity());
    if (parallel::size() > 1) {
      throw NotImplementedError("Calculation of node_is_on_boundary is incorrect");
    }

    primal_node_is_on_boundary.fill(false);
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      if (face_to_cell_matrix[face_id].size() == 1) {
        for (size_t i_node = 0; i_node < primal_face_to_node_matrix[face_id].size(); ++i_node) {
          NodeId node_id                      = primal_face_to_node_matrix[face_id][i_node];
          primal_node_is_on_boundary[node_id] = true;
        }
      }
    }

    FaceValue<bool> primal_face_is_on_boundary(mesh->connectivity());
    if (parallel::size() > 1) {
      throw NotImplementedError("Calculation of face_is_on_boundary is incorrect");
    }

    primal_face_is_on_boundary.fill(false);
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      if (face_to_cell_matrix[face_id].size() == 1) {
        primal_face_is_on_boundary[face_id] = true;
      }
    }

    // Every boundary face is used as an unknown below (weights, assembly,
    // post-processing). A boundary face without a condition still has the
    // sentinel dof number and would produce out-of-range indices, so reject it.
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      if (primal_face_is_on_boundary[face_id] and
          (face_dof_number[face_id] == std::numeric_limits<size_t>::max())) {
        std::ostringstream os;
        os << "The boundary face " << face_id << " has no boundary condition";
        throw NormalError(os.str());
      }
    }

    FaceValue<bool> primal_face_is_dirichlet(mesh->connectivity());
    if (parallel::size() > 1) {
      throw NotImplementedError("Calculation of face_is_dirichlet is incorrect");
    }

    // Boundary faces that are neither normal-strain nor symmetry are treated as Dirichlet faces.
    primal_face_is_dirichlet.fill(false);
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      primal_face_is_dirichlet[face_id] = (primal_face_is_on_boundary[face_id] && (!primal_face_is_neumann[face_id]) &&
                                           (!primal_face_is_symmetry[face_id]));
    }
    MeshDataType& mesh_data                          = MeshDataManager::instance().getMeshData(*mesh);
    const NodeValue<const TinyVector<Dimension>>& xr = mesh->xr();

    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
    const auto& node_to_cell_matrix                  = mesh->connectivity().nodeToCellMatrix();
    const auto& node_to_face_matrix                  = mesh->connectivity().nodeToFaceMatrix();
    // Node interpolation weights: w_rj for neighbouring cells, w_rl for
    // neighbouring boundary faces (only set for boundary faces).
    CellValuePerNode<double> w_rj{mesh->connectivity()};
    FaceValuePerNode<double> w_rl{mesh->connectivity()};

    const NodeValuePerFace<const TinyVector<Dimension>> primal_nlr = mesh_data.nlr();
    // Orthogonal projection of x onto the (affine) plane of face face_id.
    auto project_to_face = [&](const TinyVector<Dimension>& x, const FaceId face_id) -> const TinyVector<Dimension> {
      TinyVector<Dimension> proj;
      const TinyVector<Dimension> nil = primal_nlr(face_id, 0);
      proj                            = x - dot((x - xl[face_id]), nil) * nil;
      return proj;
    };

    for (size_t i = 0; i < w_rl.numberOfValues(); ++i) {
      w_rl[i] = std::numeric_limits<double>::signaling_NaN();
    }

    // For each node, compute weights reproducing constants and linear
    // functions at the node (least-squares solve of A x = [1, x_r]).
    for (NodeId i_node = 0; i_node < mesh->numberOfNodes(); ++i_node) {
      SmallVector<double> b{Dimension + 1};
      b[0] = 1;
      for (size_t i = 1; i < Dimension + 1; i++) {
        b[i] = xr[i_node][i - 1];
      }
      const auto& node_to_cell = node_to_cell_matrix[i_node];

      if (not primal_node_is_on_boundary[i_node]) {
        SmallMatrix<double> A{Dimension + 1, node_to_cell.size()};
        for (size_t j = 0; j < node_to_cell.size(); j++) {
          A(0, j) = 1;
        }
        for (size_t i = 1; i < Dimension + 1; i++) {
          for (size_t j = 0; j < node_to_cell.size(); j++) {
            const CellId J = node_to_cell[j];
            A(i, j)        = xj[J][i - 1];
          }
        }

        SmallVector<double> x{node_to_cell.size()};
        x = zero;

        LeastSquareSolver ls_solver;
        ls_solver.solveLocalSystem(A, x, b);

        for (size_t j = 0; j < node_to_cell.size(); j++) {
          w_rj(i_node, j) = x[j];
        }
      } else {
        // Boundary node: neighbouring boundary faces also contribute points
        // (face centers, or projected cell centers for symmetry faces).
        int nb_face_used = 0;
        for (size_t i_face = 0; i_face < node_to_face_matrix[i_node].size(); ++i_face) {
          FaceId face_id = node_to_face_matrix[i_node][i_face];
          if (primal_face_is_on_boundary[face_id]) {
            nb_face_used++;
          }
        }
        SmallMatrix<double> A{Dimension + 1, node_to_cell.size() + nb_face_used};
        for (size_t j = 0; j < node_to_cell.size() + nb_face_used; j++) {
          A(0, j) = 1;
        }
        for (size_t i = 1; i < Dimension + 1; i++) {
          for (size_t j = 0; j < node_to_cell.size(); j++) {
            const CellId J = node_to_cell[j];
            A(i, j)        = xj[J][i - 1];
          }
        }
        for (size_t i = 1; i < Dimension + 1; i++) {
          int cpt_face = 0;
          for (size_t i_face = 0; i_face < node_to_face_matrix[i_node].size(); ++i_face) {
            FaceId face_id = node_to_face_matrix[i_node][i_face];
            if (primal_face_is_on_boundary[face_id]) {
              if (primal_face_is_symmetry[face_id]) {
                for (size_t j = 0; j < face_to_cell_matrix[face_id].size(); ++j) {
                  const CellId cell_id                 = face_to_cell_matrix[face_id][j];
                  TinyVector<Dimension> xproj          = project_to_face(xj[cell_id], face_id);
                  A(i, node_to_cell.size() + cpt_face) = xproj[i - 1];
                }
              } else {
                A(i, node_to_cell.size() + cpt_face) = xl[face_id][i - 1];
              }
              cpt_face++;
            }
          }
        }

        SmallVector<double> x{node_to_cell.size() + nb_face_used};
        x = zero;

        LeastSquareSolver ls_solver;
        ls_solver.solveLocalSystem(A, x, b);

        for (size_t j = 0; j < node_to_cell.size(); j++) {
          w_rj(i_node, j) = x[j];
        }
        int cpt_face = node_to_cell.size();
        for (size_t i_face = 0; i_face < node_to_face_matrix[i_node].size(); ++i_face) {
          FaceId face_id = node_to_face_matrix[i_node][i_face];
          if (primal_face_is_on_boundary[face_id]) {
            w_rl(i_node, i_face) = x[cpt_face++];
          }
        }
      }
    }

    {
      std::shared_ptr diamond_mesh = DualMeshManager::instance().getDiamondDualMesh(*mesh);

      MeshDataType& diamond_mesh_data = MeshDataManager::instance().getMeshData(*diamond_mesh);

      std::shared_ptr mapper =
        DualConnectivityManager::instance().getPrimalToDiamondDualConnectivityDataMapper(mesh->connectivity());

      CellValue<double> dual_muj =
        InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(mu_id,
                                                                                                  diamond_mesh_data
                                                                                                    .xj());

      CellValue<double> dual_lambdaj =
        InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(lambda_id,
                                                                                                  diamond_mesh_data
                                                                                                    .xj());

      CellValue<TinyVector<Dimension>> Uj = InterpolateItemValue<TinyVector<Dimension>(
        TinyVector<Dimension>)>::template interpolate<ItemType::cell>(U_id, mesh_data.xj());

      CellValue<TinyVector<Dimension>> fj = InterpolateItemValue<TinyVector<Dimension>(
        TinyVector<Dimension>)>::template interpolate<ItemType::cell>(f_id, mesh_data.xj());

      const CellValue<const double> dual_Vj = diamond_mesh_data.Vj();

      const FaceValue<const double> mes_l = [&] {
        if constexpr (Dimension == 1) {
          FaceValue<double> compute_mes_l{mesh->connectivity()};
          compute_mes_l.fill(1);
          return compute_mes_l;
        } else {
          return mesh_data.ll();
        }
      }();

      const CellValue<const double> dual_mes_l_j = [=] {
        CellValue<double> compute_mes_j{diamond_mesh->connectivity()};
        mapper->toDualCell(mes_l, compute_mes_j);

        return compute_mes_j;
      }();

      const CellValue<const double> primal_Vj = mesh_data.Vj();
      // Diamond cell associated with each primal face.
      FaceValue<const CellId> face_dual_cell_id = [=]() {
        FaceValue<CellId> computed_face_dual_cell_id{mesh->connectivity()};
        CellValue<CellId> dual_cell_id{diamond_mesh->connectivity()};
        parallel_for(
          diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { dual_cell_id[cell_id] = cell_id; });

        mapper->fromDualCell(dual_cell_id, computed_face_dual_cell_id);

        return computed_face_dual_cell_id;
      }();

      // Primal node id of each diamond node (primal cell centers get the
      // sentinel value max()).
      NodeValue<const NodeId> dual_node_primal_node_id = [=]() {
        CellValue<NodeId> cell_ignored_id{mesh->connectivity()};
        cell_ignored_id.fill(NodeId{std::numeric_limits<unsigned int>::max()});

        NodeValue<NodeId> node_primal_id{mesh->connectivity()};

        parallel_for(
          mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { node_primal_id[node_id] = node_id; });

        NodeValue<NodeId> computed_dual_node_primal_node_id{diamond_mesh->connectivity()};

        mapper->toDualNode(node_primal_id, cell_ignored_id, computed_dual_node_primal_node_id);

        return computed_dual_node_primal_node_id;
      }();

      // Diamond node id associated with each primal cell.
      CellValue<NodeId> primal_cell_dual_node_id = [=]() {
        CellValue<NodeId> cell_id{mesh->connectivity()};
        NodeValue<NodeId> node_ignored_id{mesh->connectivity()};
        node_ignored_id.fill(NodeId{std::numeric_limits<unsigned int>::max()});

        NodeValue<NodeId> dual_node_id{diamond_mesh->connectivity()};

        parallel_for(
          diamond_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { dual_node_id[node_id] = node_id; });

        mapper->fromDualNode(dual_node_id, node_ignored_id, cell_id);

        return cell_id;
      }();
      const auto& dual_Cjr                     = diamond_mesh_data.Cjr();
      FaceValue<TinyVector<Dimension>> dualClj = [&] {
        FaceValue<TinyVector<Dimension>> computedClj{mesh->connectivity()};
        const auto& dual_node_to_cell_matrix = diamond_mesh->connectivity().nodeToCellMatrix();
        const auto& dual_cell_to_node_matrix = diamond_mesh->connectivity().cellToNodeMatrix();
        parallel_for(
          mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
            const auto& primal_face_to_cell = face_to_cell_matrix[face_id];
            for (size_t i = 0; i < primal_face_to_cell.size(); i++) {
              CellId cell_id            = primal_face_to_cell[i];
              const NodeId dual_node_id = primal_cell_dual_node_id[cell_id];
              for (size_t i_dual_cell = 0; i_dual_cell < dual_node_to_cell_matrix[dual_node_id].size(); i_dual_cell++) {
                const CellId dual_cell_id = dual_node_to_cell_matrix[dual_node_id][i_dual_cell];
                if (face_dual_cell_id[face_id] == dual_cell_id) {
                  for (size_t i_dual_node = 0; i_dual_node < dual_cell_to_node_matrix[dual_cell_id].size();
                       i_dual_node++) {
                    const NodeId final_dual_node_id = dual_cell_to_node_matrix[dual_cell_id][i_dual_node];
                    if (final_dual_node_id == dual_node_id) {
                      computedClj[face_id] = dual_Cjr(dual_cell_id, i_dual_node);
                    }
                  }
                }
              }
            }
          });
        return computedClj;
      }();

      // Unit vectors along dualClj.
      FaceValue<TinyVector<Dimension>> nlj = [&] {
        FaceValue<TinyVector<Dimension>> computedNlj{mesh->connectivity()};
        parallel_for(
          mesh->numberOfFaces(),
          PUGS_LAMBDA(FaceId face_id) { computedNlj[face_id] = 1. / l2Norm(dualClj[face_id]) * dualClj[face_id]; });
        return computedNlj;
      }();

      FaceValue<const double> alpha_lambda_l = [&] {
        CellValue<double> alpha_j{diamond_mesh->connectivity()};

        parallel_for(
          diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
            alpha_j[diamond_cell_id] = dual_lambdaj[diamond_cell_id] / dual_Vj[diamond_cell_id];
          });

        FaceValue<double> computed_alpha_l{mesh->connectivity()};
        mapper->fromDualCell(alpha_j, computed_alpha_l);
        return computed_alpha_l;
      }();

      FaceValue<const double> alpha_mu_l = [&] {
        CellValue<double> alpha_j{diamond_mesh->connectivity()};

        parallel_for(
          diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
            alpha_j[diamond_cell_id] = dual_muj[diamond_cell_id] / dual_Vj[diamond_cell_id];
          });

        FaceValue<double> computed_alpha_l{mesh->connectivity()};
        mapper->fromDualCell(alpha_j, computed_alpha_l);
        return computed_alpha_l;
      }();

      const TinyMatrix<Dimension> I = identity;

      const Array<int> non_zeros{number_of_dof * Dimension};
      non_zeros.fill(Dimension * Dimension);
      CRSMatrixDescriptor<double> S(number_of_dof * Dimension, number_of_dof * Dimension, non_zeros);
      // Two-point part of the flux across each primal face.
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        const double beta_mu_l          = l2Norm(dualClj[face_id]) * alpha_mu_l[face_id] * mes_l[face_id];
        const double beta_lambda_l      = l2Norm(dualClj[face_id]) * alpha_lambda_l[face_id] * mes_l[face_id];
        const auto& primal_face_to_cell = face_to_cell_matrix[face_id];
        for (size_t i_cell = 0; i_cell < primal_face_to_cell.size(); ++i_cell) {
          const CellId i_id                      = primal_face_to_cell[i_cell];
          const bool is_face_reversed_for_cell_i = (dot(dualClj[face_id], xl[face_id] - xj[i_id]) < 0);

          const TinyVector<Dimension> nil = [&] {
            if (is_face_reversed_for_cell_i) {
              return -nlj[face_id];
            } else {
              return nlj[face_id];
            }
          }();
          for (size_t j_cell = 0; j_cell < primal_face_to_cell.size(); ++j_cell) {
            const CellId j_id = primal_face_to_cell[j_cell];
            TinyMatrix<Dimension> M =
              beta_mu_l * I + beta_mu_l * tensorProduct(nil, nil) + beta_lambda_l * tensorProduct(nil, nil);
            TinyMatrix<Dimension> N = tensorProduct(nil, nil);

            if (i_cell == j_cell) {
              for (size_t i = 0; i < Dimension; ++i) {
                for (size_t j = 0; j < Dimension; ++j) {
                  S((cell_dof_number[i_id] * Dimension) + i, (cell_dof_number[j_id] * Dimension) + j) += M(i, j);
                  if (primal_face_is_neumann[face_id]) {
                    S(face_dof_number[face_id] * Dimension + i, cell_dof_number[j_id] * Dimension + j) -= M(i, j);
                  }
                  if (primal_face_is_symmetry[face_id]) {
                    S(face_dof_number[face_id] * Dimension + i, cell_dof_number[j_id] * Dimension + j) +=
                      -((i == j) ? 1 : 0) + N(i, j);
                    S(face_dof_number[face_id] * Dimension + i, face_dof_number[face_id] * Dimension + j) +=
                      (i == j) ? 1 : 0;
                  }
                }
              }
            } else {
              for (size_t i = 0; i < Dimension; ++i) {
                for (size_t j = 0; j < Dimension; ++j) {
                  S((cell_dof_number[i_id] * Dimension) + i, (cell_dof_number[j_id] * Dimension) + j) -= M(i, j);
                }
              }
            }
          }
        }
      }

      // Node-based part of the flux: node values are expressed through the
      // interpolation weights w_rj (cells) and w_rl (boundary faces).
      const auto& dual_cell_to_node_matrix   = diamond_mesh->connectivity().cellToNodeMatrix();
      const auto& primal_node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        const double alpha_mu_face_id     = mes_l[face_id] * alpha_mu_l[face_id];
        const double alpha_lambda_face_id = mes_l[face_id] * alpha_lambda_l[face_id];

        for (size_t i_face_cell = 0; i_face_cell < face_to_cell_matrix[face_id].size(); ++i_face_cell) {
          CellId i_id                            = face_to_cell_matrix[face_id][i_face_cell];
          const bool is_face_reversed_for_cell_i = (dot(dualClj[face_id], xl[face_id] - xj[i_id]) < 0);

          for (size_t i_node = 0; i_node < primal_face_to_node_matrix[face_id].size(); ++i_node) {
            NodeId node_id = primal_face_to_node_matrix[face_id][i_node];

            const TinyVector<Dimension> nil = [&] {
              if (is_face_reversed_for_cell_i) {
                return -nlj[face_id];
              } else {
                return nlj[face_id];
              }
            }();

            CellId dual_cell_id = face_dual_cell_id[face_id];

            for (size_t i_dual_node = 0; i_dual_node < dual_cell_to_node_matrix[dual_cell_id].size(); ++i_dual_node) {
              const NodeId dual_node_id = dual_cell_to_node_matrix[dual_cell_id][i_dual_node];
              if (dual_node_primal_node_id[dual_node_id] == node_id) {
                const TinyVector<Dimension> Clr = dual_Cjr(dual_cell_id, i_dual_node);

                TinyMatrix<Dimension> M = alpha_mu_face_id * dot(Clr, nil) * I +
                                          alpha_mu_face_id * tensorProduct(Clr, nil) +
                                          alpha_lambda_face_id * tensorProduct(nil, Clr);

                for (size_t j_cell = 0; j_cell < primal_node_to_cell_matrix[node_id].size(); ++j_cell) {
                  CellId j_id = primal_node_to_cell_matrix[node_id][j_cell];
                  for (size_t i = 0; i < Dimension; ++i) {
                    for (size_t j = 0; j < Dimension; ++j) {
                      S((cell_dof_number[i_id] * Dimension) + i, (cell_dof_number[j_id] * Dimension) + j) -=
                        w_rj(node_id, j_cell) * M(i, j);
                      if (primal_face_is_neumann[face_id]) {
                        S(face_dof_number[face_id] * Dimension + i, cell_dof_number[j_id] * Dimension + j) +=
                          w_rj(node_id, j_cell) * M(i, j);
                      }
                    }
                  }
                }
                if (primal_node_is_on_boundary[node_id]) {
                  for (size_t l_face = 0; l_face < node_to_face_matrix[node_id].size(); ++l_face) {
                    FaceId l_id = node_to_face_matrix[node_id][l_face];
                    if (primal_face_is_on_boundary[l_id]) {
                      for (size_t i = 0; i < Dimension; ++i) {
                        for (size_t j = 0; j < Dimension; ++j) {
                          S(cell_dof_number[i_id] * Dimension + i, face_dof_number[l_id] * Dimension + j) -=
                            w_rl(node_id, l_face) * M(i, j);
                        }
                      }
                      if (primal_face_is_neumann[face_id]) {
                        for (size_t i = 0; i < Dimension; ++i) {
                          for (size_t j = 0; j < Dimension; ++j) {
                            S(face_dof_number[face_id] * Dimension + i, face_dof_number[l_id] * Dimension + j) +=
                              w_rl(node_id, l_face) * M(i, j);
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
      // Dirichlet faces: the face unknown is set directly (identity row).
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        if (primal_face_is_dirichlet[face_id]) {
          for (size_t i = 0; i < Dimension; ++i) {
            S(face_dof_number[face_id] * Dimension + i, face_dof_number[face_id] * Dimension + i) += 1;
          }
        }
      }

      CRSMatrix A{S.getCRSMatrix()};
      Vector<double> b{number_of_dof * Dimension};
      b = zero;
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        for (size_t i = 0; i < Dimension; ++i) {
          b[(cell_dof_number[cell_id] * Dimension) + i] = primal_Vj[cell_id] * fj[cell_id][i];
        }
      }

      // Dirichlet
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();
              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id = face_list[i_face];

                for (size_t i = 0; i < Dimension; ++i) {
                  b[(face_dof_number[face_id] * Dimension) + i] += value_list[i_face][i];
                }
              }
            }
          },
          boundary_condition);
      }

      // Normal strain
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>)) {
              const auto& face_list  = bc.faceList();
              const auto& value_list = bc.valueList();
              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                FaceId face_id = face_list[i_face];
                for (size_t i = 0; i < Dimension; ++i) {
                  b[face_dof_number[face_id] * Dimension + i] += mes_l[face_id] * value_list[i_face][i];   // sign
                }
              }
            }
          },
          boundary_condition);
      }

      Vector<double> U{number_of_dof * Dimension};
      U = zero;
      CellValue<TinyVector<Dimension>> Speed{mesh->connectivity()};
      FaceValue<TinyVector<Dimension>> Ul = InterpolateItemValue<TinyVector<Dimension>(
        TinyVector<Dimension>)>::template interpolate<ItemType::face>(U_id, mesh_data.xl());
      FaceValue<TinyVector<Dimension>> Speed_face{mesh->connectivity()};

      Vector r = A * U - b;
      std::cout << "initial (real) residu = " << std::sqrt(dot(r, r)) << '\n';

      LinearSolver solver;
      solver.solveLocalSystem(A, U, b);

      r = A * U - b;

      std::cout << "final (real) residu = " << std::sqrt(dot(r, r)) << '\n';

      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        for (size_t i = 0; i < Dimension; ++i) {
          Speed[cell_id][i] = U[(cell_dof_number[cell_id] * Dimension) + i];
        }
      }
      // Boundary faces take the computed unknown, interior faces the reference value.
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        for (size_t i = 0; i < Dimension; ++i) {
          if (primal_face_is_on_boundary[face_id]) {
            Speed_face[face_id][i] = U[(face_dof_number[face_id] * Dimension) + i];
          } else {
            Speed_face[face_id][i] = Ul[face_id][i];
          }
        }
      }

      Vector<double> error{mesh->numberOfCells() * Dimension};
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        for (size_t i = 0; i < Dimension; ++i) {
          error[(cell_id * Dimension) + i] = (Speed[cell_id][i] - Uj[cell_id][i]) * sqrt(primal_Vj[cell_id]);
        }
      }
      // Face error: only boundary faces contribute; every component of
      // interior faces is zeroed so the whole vector is initialized.
      Vector<double> error_face{mesh->numberOfFaces() * Dimension};
      parallel_for(
        mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
          if (primal_face_is_on_boundary[face_id]) {
            for (size_t i = 0; i < Dimension; ++i) {
              error_face[face_id * Dimension + i] = (Speed_face[face_id][i] - Ul[face_id][i]) * sqrt(mes_l[face_id]);
            }
          } else {
            for (size_t i = 0; i < Dimension; ++i) {
              error_face[face_id * Dimension + i] = 0;
            }
          }
        });

      std::cout << "||Error||_2 (cell)= " << std::sqrt(dot(error, error)) << "\n";
      std::cout << "||Error||_2 (face)= " << std::sqrt(dot(error_face, error_face)) << "\n";
      std::cout << "||Error||_2 (total)= " << std::sqrt(dot(error, error)) + std::sqrt(dot(error_face, error_face))
                << "\n";

      // Node values reconstructed from cell and boundary-face values.
      NodeValue<TinyVector<3>> ur3d{mesh->connectivity()};
      ur3d.fill(zero);

      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
          TinyVector<Dimension> x = zero;
          const auto node_cells   = node_to_cell_matrix[node_id];
          for (size_t i_cell = 0; i_cell < node_cells.size(); ++i_cell) {
            CellId cell_id = node_cells[i_cell];
            x += w_rj(node_id, i_cell) * Speed[cell_id];
          }
          const auto node_faces = node_to_face_matrix[node_id];
          for (size_t i_face = 0; i_face < node_faces.size(); ++i_face) {
            FaceId face_id = node_faces[i_face];
            if (primal_face_is_on_boundary[face_id]) {
              x += w_rl(node_id, i_face) * Speed_face[face_id];
            }
          }
          for (size_t i = 0; i < Dimension; ++i) {
            ur3d[node_id][i] = x[i];
          }
        });
    }
  } else {
    throw NotImplementedError("not done in 1d");
  }
}

template <size_t Dimension>
class ElasticityDiamondScheme<Dimension>::DirichletBoundaryCondition
{
  // Dirichlet boundary data: a list of boundary faces together with one
  // value per face (presumably the prescribed displacement — TODO confirm
  // against the code that builds this object).
 public:
  // Builds the condition from matching face / value lists.
  // The two arrays must have the same size (checked with Assert).
  DirichletBoundaryCondition(const Array<const FaceId>& face_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list{value_list}, m_face_list{face_list}
  {
    Assert(m_value_list.size() == m_face_list.size());
  }

  ~DirichletBoundaryCondition() = default;

  // Faces on which the condition is imposed.
  auto faceList() const -> const Array<const FaceId>&
  {
    return m_face_list;
  }

  // Values attached to the faces, same length as faceList().
  auto valueList() const -> const Array<const TinyVector<Dimension>>&
  {
    return m_value_list;
  }

 private:
  // Declaration order (values first, then faces) matches the
  // member-initializer order used in the constructor.
  const Array<const TinyVector<Dimension>> m_value_list;
  const Array<const FaceId> m_face_list;
};

template <size_t Dimension>
class ElasticityDiamondScheme<Dimension>::NormalStrainBoundaryCondition
{
  // Normal-strain boundary data: a list of boundary faces together with one
  // vector value per face (presumably the imposed normal strain/traction —
  // TODO confirm against the code that builds this object).
 public:
  // Builds the condition from matching face / value lists.
  // The two arrays must have the same size (checked with Assert).
  NormalStrainBoundaryCondition(const Array<const FaceId>& face_list,
                                const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list{value_list}, m_face_list{face_list}
  {
    Assert(m_value_list.size() == m_face_list.size());
  }

  ~NormalStrainBoundaryCondition() = default;

  // Faces on which the condition is imposed.
  auto faceList() const -> const Array<const FaceId>&
  {
    return m_face_list;
  }

  // Values attached to the faces, same length as faceList().
  auto valueList() const -> const Array<const TinyVector<Dimension>>&
  {
    return m_value_list;
  }

 private:
  // Declaration order (values first, then faces) matches the
  // member-initializer order used in the constructor.
  const Array<const TinyVector<Dimension>> m_value_list;
  const Array<const FaceId> m_face_list;
};

template <size_t Dimension>
class ElasticityDiamondScheme<Dimension>::SymmetryBoundaryCondition
{
  // Symmetry boundary data: only the list of boundary faces is stored.
  // Unlike the Dirichlet / normal-strain conditions, no per-face values
  // are attached.
  //
  // A value-list member used to be declared here. It was never initialised
  // by the constructor and never read, so it has been removed.
 private:
  const Array<const FaceId> m_face_list;

 public:
  // Faces on which the symmetry condition is imposed.
  const Array<const FaceId>&
  faceList() const
  {
    return m_face_list;
  }

  // Left non-explicit on purpose so that existing implicit conversions from a
  // face list (if any caller relies on them) keep compiling.
  SymmetryBoundaryCondition(const Array<const FaceId>& face_list) : m_face_list{face_list} {}

  ~SymmetryBoundaryCondition() = default;
};

// Explicit instantiations of the (template) constructor for the supported
// mesh dimensions. The constructor is defined in this .cpp file, so these
// lines are what make it available to other translation units.
// NOTE: the 1D instantiation only exists so the scheme links for every
// dimension. The constructor body reports 1D as unsupported by throwing
// NotImplementedError("not done in 1d").
template ElasticityDiamondScheme<1>::ElasticityDiamondScheme(
  std::shared_ptr<const IMesh>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&);

template ElasticityDiamondScheme<2>::ElasticityDiamondScheme(
  std::shared_ptr<const IMesh>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&);

template ElasticityDiamondScheme<3>::ElasticityDiamondScheme(
  std::shared_ptr<const IMesh>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&);