#include <scheme/VectorDiamondScheme.hpp>

#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>

template <size_t Dimension>
class VectorDiamondSchemeHandler::InterpolationWeightsManager
{
  // Computes, for every primal node r, interpolation weights w_rj (on the cells around r) and,
  // for nodes flagged as boundary nodes, additional weights w_rl (on the boundary faces around r).
  // The weights are obtained by a local least-squares solve of A x = b where the first row of A
  // is all ones and the remaining rows hold the coordinates of the interpolation points, and
  // b = (1, x_r): i.e. the weights are asked to sum to 1 and to reproduce the node position.
 private:
  std::shared_ptr<const Mesh<Connectivity<Dimension>>> m_mesh;
  // Flags supplied by the caller; they select which faces/nodes get extra face weights.
  FaceValue<const bool> m_primal_face_is_on_boundary;
  NodeValue<const bool> m_primal_node_is_on_boundary;
  FaceValue<const bool> m_primal_face_is_symmetry;
  // Results of compute(); default-constructed until compute() has been called.
  CellValuePerNode<double> m_w_rj;
  FaceValuePerNode<double> m_w_rl;

 public:
  InterpolationWeightsManager(std::shared_ptr<const Mesh<Connectivity<Dimension>>> mesh,
                              FaceValue<const bool> primal_face_is_on_boundary,
                              NodeValue<const bool> primal_node_is_on_boundary,
                              FaceValue<const bool> primal_face_is_symmetry)
    : m_mesh(mesh),
      m_primal_face_is_on_boundary(primal_face_is_on_boundary),
      m_primal_node_is_on_boundary(primal_node_is_on_boundary),
      m_primal_face_is_symmetry(primal_face_is_symmetry)
  {}
  ~InterpolationWeightsManager() = default;

  /// @return node->cell weights computed by compute() (empty before compute() is called)
  CellValuePerNode<double>&
  wrj()
  {
    return m_w_rj;
  }

  /// @return node->face weights computed by compute(); entries for faces that are not flagged
  ///         as boundary faces (or belong to non-boundary nodes) hold signaling NaNs
  FaceValuePerNode<double>&
  wrl()
  {
    return m_w_rl;
  }

  /// Computes m_w_rj and m_w_rl for every node of the mesh.
  void
  compute()
  {
    using MeshDataType      = MeshData<Dimension>;
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValue<const TinyVector<Dimension>>& xr = m_mesh->xr();

    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
    const auto& node_to_cell_matrix                  = m_mesh->connectivity().nodeToCellMatrix();
    const auto& node_to_face_matrix                  = m_mesh->connectivity().nodeToFaceMatrix();
    const auto& face_to_cell_matrix                  = m_mesh->connectivity().faceToCellMatrix();

    CellValuePerNode<double> w_rj{m_mesh->connectivity()};
    FaceValuePerNode<double> w_rl{m_mesh->connectivity()};

    const NodeValuePerFace<const TinyVector<Dimension>> primal_nlr = mesh_data.nlr();
    // Projects x onto the plane through xl[face_id] orthogonal to the normal stored for the face's
    // first node (an orthogonal projection assuming nlr is a unit normal — TODO confirm).
    auto project_to_face = [&](const TinyVector<Dimension>& x, const FaceId face_id) -> TinyVector<Dimension> {
      const TinyVector<Dimension> nil = primal_nlr(face_id, 0);
      return x - dot((x - xl[face_id]), nil) * nil;
    };

    // Only (boundary node, boundary face) entries are written below; poison everything else.
    for (size_t i = 0; i < w_rl.numberOfValues(); ++i) {
      w_rl[i] = std::numeric_limits<double>::signaling_NaN();
    }

    for (NodeId i_node = 0; i_node < m_mesh->numberOfNodes(); ++i_node) {
      // Right-hand side (1, x_r): weights sum to one and reproduce the node coordinates.
      SmallVector<double> b{Dimension + 1};
      b[0] = 1;
      for (size_t i = 1; i < Dimension + 1; i++) {
        b[i] = xr[i_node][i - 1];
      }
      const auto& node_to_cell = node_to_cell_matrix[i_node];

      if (not m_primal_node_is_on_boundary[i_node]) {
        // Interior node: interpolation points are the centers of the surrounding cells.
        SmallMatrix<double> A{Dimension + 1, node_to_cell.size()};
        for (size_t j = 0; j < node_to_cell.size(); j++) {
          A(0, j) = 1;
        }
        for (size_t i = 1; i < Dimension + 1; i++) {
          for (size_t j = 0; j < node_to_cell.size(); j++) {
            const CellId J = node_to_cell[j];
            A(i, j)        = xj[J][i - 1];
          }
        }

        SmallVector<double> x{node_to_cell.size()};
        x = zero;

        LeastSquareSolver ls_solver;
        ls_solver.solveLocalSystem(A, x, b);

        for (size_t j = 0; j < node_to_cell.size(); j++) {
          w_rj(i_node, j) = x[j];
        }
      } else {
        // Boundary node: cell centers plus one extra point per boundary face around the node.
        const auto& node_to_face = node_to_face_matrix[i_node];

        size_t nb_face_used = 0;
        for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
          const FaceId face_id = node_to_face[i_face];
          if (m_primal_face_is_on_boundary[face_id]) {
            ++nb_face_used;
          }
        }
        const size_t nb_columns = node_to_cell.size() + nb_face_used;

        SmallMatrix<double> A{Dimension + 1, nb_columns};
        for (size_t j = 0; j < nb_columns; j++) {
          A(0, j) = 1;
        }
        for (size_t j = 0; j < node_to_cell.size(); j++) {
          const CellId J = node_to_cell[j];
          for (size_t i = 1; i < Dimension + 1; i++) {
            A(i, j) = xj[J][i - 1];
          }
        }

        // Face columns follow the cell columns, in node_to_face order.
        size_t column = node_to_cell.size();
        for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
          const FaceId face_id = node_to_face[i_face];
          if (not m_primal_face_is_on_boundary[face_id]) {
            continue;
          }
          // Non-symmetry faces use the face center; symmetry faces use the projection of the
          // adjacent cell center onto the face (last adjacent cell, as in the original loop).
          TinyVector<Dimension> face_point = xl[face_id];
          if (m_primal_face_is_symmetry[face_id]) {
            const auto& face_to_cell = face_to_cell_matrix[face_id];
            for (size_t j = 0; j < face_to_cell.size(); ++j) {
              const CellId cell_id = face_to_cell[j];
              face_point           = project_to_face(xj[cell_id], face_id);
            }
          }
          for (size_t i = 1; i < Dimension + 1; i++) {
            A(i, column) = face_point[i - 1];
          }
          ++column;
        }

        SmallVector<double> x{nb_columns};
        x = zero;

        LeastSquareSolver ls_solver;
        ls_solver.solveLocalSystem(A, x, b);

        for (size_t j = 0; j < node_to_cell.size(); j++) {
          w_rj(i_node, j) = x[j];
        }
        size_t i_weight = node_to_cell.size();
        for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
          const FaceId face_id = node_to_face[i_face];
          if (m_primal_face_is_on_boundary[face_id]) {
            w_rl(i_node, i_face) = x[i_weight++];
          }
        }
      }
    }
    m_w_rj = w_rj;
    m_w_rl = w_rl;
  }
};

class VectorDiamondSchemeHandler::IVectorDiamondScheme
{
  // Dimension-independent interface to a VectorDiamondScheme<Dimension> instance, so the handler
  // can hold a scheme without knowing the mesh dimension at compile time.
 private:
  using DiscreteFunctionPtr = std::shared_ptr<const IDiscreteFunction>;

 public:
  /// @return the (primal) solution of the scheme
  virtual DiscreteFunctionPtr getSolution() const = 0;
  /// @return the dual solution of the scheme
  virtual DiscreteFunctionPtr getDualSolution() const = 0;
  /// @return (solution, dual solution); the visible implementation returns exactly these two
  virtual std::tuple<DiscreteFunctionPtr, DiscreteFunctionPtr> apply() const = 0;
  // virtual std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
  // computeEnergyUpdate() const = 0;

  IVectorDiamondScheme()          = default;
  virtual ~IVectorDiamondScheme() = default;
};

template <size_t Dimension>
class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSchemeHandler::IVectorDiamondScheme
{
 private:
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_solution;
  std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_dual_solution;
  //  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_energy_delta;

  // Dirichlet condition: one prescribed vector value per listed boundary face
  // (value_list[i] belongs to face_list[i]; sizes are asserted equal).
  class DirichletBoundaryCondition
  {
   private:
    const Array<const TinyVector<Dimension>> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    DirichletBoundaryCondition(const Array<const FaceId>& face_list,
                               const Array<const TinyVector<Dimension>>& value_list)
      : m_value_list(value_list), m_face_list(face_list)
    {
      Assert(m_face_list.size() == m_value_list.size());
    }

    ~DirichletBoundaryCondition() = default;

    /// @return the faces carrying the condition
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    /// @return the prescribed values, aligned with faceList()
    const Array<const TinyVector<Dimension>>&
    valueList() const
    {
      return m_value_list;
    }
  };

  // "normal_strain" condition (treated as a Neumann-type face condition by the scheme):
  // one vector datum per listed boundary face, value_list[i] belonging to face_list[i].
  class NormalStrainBoundaryCondition
  {
   private:
    const Array<const TinyVector<Dimension>> m_value_list;
    const Array<const FaceId> m_face_list;

   public:
    NormalStrainBoundaryCondition(const Array<const FaceId>& face_list,
                                  const Array<const TinyVector<Dimension>>& value_list)
      : m_value_list(value_list), m_face_list(face_list)
    {
      Assert(m_face_list.size() == m_value_list.size());
    }

    ~NormalStrainBoundaryCondition() = default;

    /// @return the faces carrying the condition
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    /// @return the face data, aligned with faceList()
    const Array<const TinyVector<Dimension>>&
    valueList() const
    {
      return m_value_list;
    }
  };

  // Symmetry condition: carries only the list of boundary faces it applies to
  // (the visible constructor builds it from a flat face boundary).
  class SymmetryBoundaryCondition
  {
   private:
    const Array<const FaceId> m_face_list;

   public:
    /// @return the faces carrying the condition
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    SymmetryBoundaryCondition(const Array<const FaceId>& face_list) : m_face_list{face_list} {}

    ~SymmetryBoundaryCondition() = default;
  };

 public:
  std::shared_ptr<const IDiscreteFunction>
  getSolution() const final
  {
    return m_solution;
  }

  std::shared_ptr<const IDiscreteFunction>
  getDualSolution() const final
  {
    return m_dual_solution;
  }

  std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
  apply() const final
  {
    return {m_solution, m_dual_solution};
  }

  VectorDiamondScheme(const std::shared_ptr<const MeshType>& mesh,
                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& alpha,
                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambdab,
                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_mub,
                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambda,
                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_mu,
                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& source,
                      const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    Assert(mesh == alpha->mesh());
    Assert(mesh == source->mesh());
    Assert(dual_lambda->mesh() == dual_mu->mesh());
    Assert(dual_lambdab->mesh() == dual_mu->mesh());
    Assert(dual_mub->mesh() == dual_mu->mesh());
    Assert(DualMeshManager::instance().getDiamondDualMesh(*mesh) == dual_mu->mesh(),
           "diffusion coefficient is not defined on the dual mesh!");

    using MeshDataType = MeshData<Dimension>;

    using BoundaryCondition =
      std::variant<DirichletBoundaryCondition, NormalStrainBoundaryCondition, SymmetryBoundaryCondition>;

    using BoundaryConditionList = std::vector<BoundaryCondition>;

    BoundaryConditionList boundary_condition_list;

    NodeValue<bool> is_dirichlet{mesh->connectivity()};
    is_dirichlet.fill(false);
    NodeValue<TinyVector<Dimension>> dirichlet_value{mesh->connectivity()};
    {
      TinyVector<Dimension> nan_tiny_vector;
      for (size_t i = 0; i < Dimension; ++i) {
        nan_tiny_vector[i] = std::numeric_limits<double>::signaling_NaN();
      }
      dirichlet_value.fill(nan_tiny_vector);
    }

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::symmetry: {
        const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor =
          dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);

        if constexpr (Dimension > 1) {
          MeshFlatFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFlatFaceBoundary(*mesh, sym_bc_descriptor.boundaryDescriptor());
          boundary_condition_list.push_back(SymmetryBoundaryCondition{mesh_face_boundary.faceList()});
        } else {
          throw NotImplementedError("Symmetry conditions are not supported in 1d");
        }

        break;
      }
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "dirichlet") {
          if constexpr (Dimension > 1) {
            MeshFaceBoundary<Dimension> mesh_face_boundary =
              getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

            const FunctionSymbolId g_id                   = dirichlet_bc_descriptor.rhsSymbolId();
            Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
              TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id, mesh_data.xl(),
                                                                            mesh_face_boundary.faceList());
            boundary_condition_list.push_back(DirichletBoundaryCondition{mesh_face_boundary.faceList(), value_list});
          } else {
            throw NotImplementedError("Neumann conditions are not supported in 1d");
          }
        } else if (dirichlet_bc_descriptor.name() == "normal_strain") {
          if constexpr (Dimension > 1) {
            MeshFaceBoundary<Dimension> mesh_face_boundary =
              getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

            const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();

            Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
              TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id, mesh_data.xl(),
                                                                            mesh_face_boundary.faceList());
            boundary_condition_list.push_back(NormalStrainBoundaryCondition{mesh_face_boundary.faceList(), value_list});

          } else {
            throw NotImplementedError("Normal strain conditions are not supported in 1d");
          }
        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for elasticity equation";
        throw NormalError(error_msg.str());
      }
    }

    if constexpr (Dimension > 1) {
      const CellValue<const size_t> cell_dof_number = [&] {
        CellValue<size_t> compute_cell_dof_number{mesh->connectivity()};
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { compute_cell_dof_number[cell_id] = cell_id; });
        return compute_cell_dof_number;
      }();
      size_t number_of_dof = mesh->numberOfCells();

      const FaceValue<const size_t> face_dof_number = [&] {
        FaceValue<size_t> compute_face_dof_number{mesh->connectivity()};
        compute_face_dof_number.fill(std::numeric_limits<size_t>::max());
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>) or
                            (std::is_same_v<T, SymmetryBoundaryCondition>) or
                            (std::is_same_v<T, DirichletBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id = face_list[i_face];
                  if (compute_face_dof_number[face_id] != std::numeric_limits<size_t>::max()) {
                    std::ostringstream os;
                    os << "The face " << face_id << " is used at least twice for boundary conditions";
                    throw NormalError(os.str());
                  } else {
                    compute_face_dof_number[face_id] = number_of_dof++;
                  }
                }
              }
            },
            boundary_condition);
        }

        return compute_face_dof_number;
      }();

      const auto& primal_face_to_node_matrix             = mesh->connectivity().faceToNodeMatrix();
      const auto& face_to_cell_matrix                    = mesh->connectivity().faceToCellMatrix();
      const FaceValue<const bool> primal_face_is_neumann = [&] {
        FaceValue<bool> face_is_neumann{mesh->connectivity()};
        face_is_neumann.fill(false);
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id     = face_list[i_face];
                  face_is_neumann[face_id] = true;
                }
              }
            },
            boundary_condition);
        }

        return face_is_neumann;
      }();

      const FaceValue<const bool> primal_face_is_symmetry = [&] {
        FaceValue<bool> face_is_symmetry{mesh->connectivity()};
        face_is_symmetry.fill(false);
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, SymmetryBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id      = face_list[i_face];
                  face_is_symmetry[face_id] = true;
                }
              }
            },
            boundary_condition);
        }

        return face_is_symmetry;
      }();

      NodeValue<bool> primal_node_is_on_boundary(mesh->connectivity());
      if (parallel::size() > 1) {
        throw NotImplementedError("Calculation of node_is_on_boundary is incorrect");
      }

      primal_node_is_on_boundary.fill(false);
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        if (face_to_cell_matrix[face_id].size() == 1) {
          for (size_t i_node = 0; i_node < primal_face_to_node_matrix[face_id].size(); ++i_node) {
            NodeId node_id                      = primal_face_to_node_matrix[face_id][i_node];
            primal_node_is_on_boundary[node_id] = true;
          }
        }
      }

      FaceValue<bool> primal_face_is_on_boundary(mesh->connectivity());
      if (parallel::size() > 1) {
        throw NotImplementedError("Calculation of face_is_on_boundary is incorrect");
      }

      primal_face_is_on_boundary.fill(false);
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        if (face_to_cell_matrix[face_id].size() == 1) {
          primal_face_is_on_boundary[face_id] = true;
        }
      }

      FaceValue<bool> primal_face_is_dirichlet(mesh->connectivity());
      if (parallel::size() > 1) {
        throw NotImplementedError("Calculation of face_is_neumann is incorrect");
      }

      primal_face_is_dirichlet.fill(false);
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        primal_face_is_dirichlet[face_id] = (primal_face_is_on_boundary[face_id] &&
                                             (!primal_face_is_neumann[face_id]) && (!primal_face_is_symmetry[face_id]));
      }

      InterpolationWeightsManager iwm(mesh, primal_face_is_on_boundary, primal_node_is_on_boundary,
                                      primal_face_is_symmetry);
      iwm.compute();
      CellValuePerNode<double> w_rj = iwm.wrj();
      FaceValuePerNode<double> w_rl = iwm.wrl();

      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

      const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
      const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
      // const auto& node_to_cell_matrix                                = mesh->connectivity().nodeToCellMatrix();
      const auto& node_to_face_matrix                                = mesh->connectivity().nodeToFaceMatrix();
      const NodeValuePerFace<const TinyVector<Dimension>> primal_nlr = mesh_data.nlr();

      {
        std::shared_ptr diamond_mesh = DualMeshManager::instance().getDiamondDualMesh(*mesh);

        MeshDataType& diamond_mesh_data = MeshDataManager::instance().getMeshData(*diamond_mesh);

        std::shared_ptr mapper =
          DualConnectivityManager::instance().getPrimalToDiamondDualConnectivityDataMapper(mesh->connectivity());

        CellValue<const double> dual_muj      = dual_mu->cellValues();
        CellValue<const double> dual_lambdaj  = dual_lambda->cellValues();
        CellValue<const double> dual_mubj     = dual_mub->cellValues();
        CellValue<const double> dual_lambdabj = dual_lambdab->cellValues();

        CellValue<const TinyVector<Dimension>> fj = source->cellValues();
        // for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
        //   std::cout << xj[cell_id] << "-> fj[" << cell_id << "]=" << fj[cell_id] << '\n';
        // }

        const CellValue<const double> dual_Vj = diamond_mesh_data.Vj();

        const FaceValue<const double> mes_l = [&] {
          if constexpr (Dimension == 1) {
            FaceValue<double> compute_mes_l{mesh->connectivity()};
            compute_mes_l.fill(1);
            return compute_mes_l;
          } else {
            return mesh_data.ll();
          }
        }();

        const CellValue<const double> dual_mes_l_j = [=] {
          CellValue<double> compute_mes_j{diamond_mesh->connectivity()};
          mapper->toDualCell(mes_l, compute_mes_j);

          return compute_mes_j;
        }();

        const CellValue<const double> primal_Vj = mesh_data.Vj();

        CellValue<const FaceId> dual_cell_face_id = [=]() {
          CellValue<FaceId> computed_dual_cell_face_id{diamond_mesh->connectivity()};
          FaceValue<FaceId> primal_face_id{mesh->connectivity()};
          parallel_for(
            mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) { primal_face_id[face_id] = face_id; });

          mapper->toDualCell(primal_face_id, computed_dual_cell_face_id);

          return computed_dual_cell_face_id;
        }();

        FaceValue<const CellId> face_dual_cell_id = [=]() {
          FaceValue<CellId> computed_face_dual_cell_id{mesh->connectivity()};
          CellValue<CellId> dual_cell_id{diamond_mesh->connectivity()};
          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { dual_cell_id[cell_id] = cell_id; });

          mapper->fromDualCell(dual_cell_id, computed_face_dual_cell_id);

          return computed_face_dual_cell_id;
        }();

        NodeValue<const NodeId> dual_node_primal_node_id = [=]() {
          CellValue<NodeId> cell_ignored_id{mesh->connectivity()};
          cell_ignored_id.fill(NodeId{std::numeric_limits<unsigned int>::max()});

          NodeValue<NodeId> node_primal_id{mesh->connectivity()};

          parallel_for(
            mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { node_primal_id[node_id] = node_id; });

          NodeValue<NodeId> computed_dual_node_primal_node_id{diamond_mesh->connectivity()};

          mapper->toDualNode(node_primal_id, cell_ignored_id, computed_dual_node_primal_node_id);

          return computed_dual_node_primal_node_id;
        }();

        CellValue<NodeId> primal_cell_dual_node_id = [=]() {
          CellValue<NodeId> cell_id{mesh->connectivity()};
          NodeValue<NodeId> node_ignored_id{mesh->connectivity()};
          node_ignored_id.fill(NodeId{std::numeric_limits<unsigned int>::max()});

          NodeValue<NodeId> dual_node_id{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { dual_node_id[node_id] = node_id; });

          CellValue<NodeId> computed_primal_cell_dual_node_id{mesh->connectivity()};

          mapper->fromDualNode(dual_node_id, node_ignored_id, cell_id);

          return cell_id;
        }();
        const auto& dual_Cjr                     = diamond_mesh_data.Cjr();
        FaceValue<TinyVector<Dimension>> dualClj = [&] {
          FaceValue<TinyVector<Dimension>> computedClj{mesh->connectivity()};
          const auto& dual_node_to_cell_matrix = diamond_mesh->connectivity().nodeToCellMatrix();
          const auto& dual_cell_to_node_matrix = diamond_mesh->connectivity().cellToNodeMatrix();
          parallel_for(
            mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
              const auto& primal_face_to_cell = face_to_cell_matrix[face_id];
              for (size_t i = 0; i < primal_face_to_cell.size(); i++) {
                CellId cell_id            = primal_face_to_cell[i];
                const NodeId dual_node_id = primal_cell_dual_node_id[cell_id];
                for (size_t i_dual_cell = 0; i_dual_cell < dual_node_to_cell_matrix[dual_node_id].size();
                     i_dual_cell++) {
                  const CellId dual_cell_id = dual_node_to_cell_matrix[dual_node_id][i_dual_cell];
                  if (face_dual_cell_id[face_id] == dual_cell_id) {
                    for (size_t i_dual_node = 0; i_dual_node < dual_cell_to_node_matrix[dual_cell_id].size();
                         i_dual_node++) {
                      const NodeId final_dual_node_id = dual_cell_to_node_matrix[dual_cell_id][i_dual_node];
                      if (final_dual_node_id == dual_node_id) {
                        computedClj[face_id] = dual_Cjr(dual_cell_id, i_dual_node);
                      }
                    }
                  }
                }
              }
            });
          return computedClj;
        }();

        FaceValue<TinyVector<Dimension>> nlj = [&] {
          FaceValue<TinyVector<Dimension>> computedNlj{mesh->connectivity()};
          parallel_for(
            mesh->numberOfFaces(),
            PUGS_LAMBDA(FaceId face_id) { computedNlj[face_id] = 1. / l2Norm(dualClj[face_id]) * dualClj[face_id]; });
          return computedNlj;
        }();

        FaceValue<const double> alpha_lambda_l = [&] {
          CellValue<double> alpha_j{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
              alpha_j[diamond_cell_id] = dual_lambdaj[diamond_cell_id] / dual_Vj[diamond_cell_id];
            });

          FaceValue<double> computed_alpha_l{mesh->connectivity()};
          mapper->fromDualCell(alpha_j, computed_alpha_l);
          return computed_alpha_l;
        }();

        FaceValue<const double> alpha_mu_l = [&] {
          CellValue<double> alpha_j{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
              alpha_j[diamond_cell_id] = dual_muj[diamond_cell_id] / dual_Vj[diamond_cell_id];
            });

          FaceValue<double> computed_alpha_l{mesh->connectivity()};
          mapper->fromDualCell(alpha_j, computed_alpha_l);
          return computed_alpha_l;
        }();

        FaceValue<const double> alpha_lambdab_l = [&] {
          CellValue<double> alpha_j{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
              alpha_j[diamond_cell_id] = dual_lambdabj[diamond_cell_id] / dual_Vj[diamond_cell_id];
            });

          FaceValue<double> computed_alpha_l{mesh->connectivity()};
          mapper->fromDualCell(alpha_j, computed_alpha_l);
          return computed_alpha_l;
        }();

        FaceValue<const double> alpha_mub_l = [&] {
          CellValue<double> alpha_j{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
              alpha_j[diamond_cell_id] = dual_mubj[diamond_cell_id] / dual_Vj[diamond_cell_id];
            });

          FaceValue<double> computed_alpha_l{mesh->connectivity()};
          mapper->fromDualCell(alpha_j, computed_alpha_l);
          return computed_alpha_l;
        }();

        const TinyMatrix<Dimension> I = identity;

        const Array<int> non_zeros{number_of_dof * Dimension};
        non_zeros.fill(Dimension * Dimension);
        CRSMatrixDescriptor<double> S(number_of_dof * Dimension, number_of_dof * Dimension, non_zeros);

        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          const double beta_mu_l          = l2Norm(dualClj[face_id]) * alpha_mu_l[face_id] * mes_l[face_id];
          const double beta_lambda_l      = l2Norm(dualClj[face_id]) * alpha_lambda_l[face_id] * mes_l[face_id];
          const double beta_mub_l         = l2Norm(dualClj[face_id]) * alpha_mub_l[face_id] * mes_l[face_id];
          const double beta_lambdab_l     = l2Norm(dualClj[face_id]) * alpha_lambdab_l[face_id] * mes_l[face_id];
          const auto& primal_face_to_cell = face_to_cell_matrix[face_id];
          for (size_t i_cell = 0; i_cell < primal_face_to_cell.size(); ++i_cell) {
            const CellId i_id                      = primal_face_to_cell[i_cell];
            const bool is_face_reversed_for_cell_i = (dot(dualClj[face_id], xl[face_id] - xj[i_id]) < 0);

            const TinyVector<Dimension> nil = [&] {
              if (is_face_reversed_for_cell_i) {
                return -nlj[face_id];
              } else {
                return nlj[face_id];
              }
            }();
            TinyMatrix<Dimension> M =
              beta_mu_l * I + beta_mu_l * tensorProduct(nil, nil) + beta_lambda_l * tensorProduct(nil, nil);
            TinyMatrix<Dimension> Mb =
              beta_mub_l * I + beta_mub_l * tensorProduct(nil, nil) + beta_lambdab_l * tensorProduct(nil, nil);
            TinyMatrix<Dimension> N = 1.e0 * tensorProduct(nil, nil);
            double coef_adim        = beta_mu_l + beta_lambdab_l;
            for (size_t j_cell = 0; j_cell < primal_face_to_cell.size(); ++j_cell) {
              const CellId j_id = primal_face_to_cell[j_cell];
              if (i_cell == j_cell) {
                for (size_t i = 0; i < Dimension; ++i) {
                  for (size_t j = 0; j < Dimension; ++j) {
                    S((cell_dof_number[i_id] * Dimension) + i, (cell_dof_number[j_id] * Dimension) + j) += M(i, j);
                    if (primal_face_is_neumann[face_id]) {
                      S(face_dof_number[face_id] * Dimension + i, cell_dof_number[j_id] * Dimension + j) -=
                        1.e0 * Mb(i, j);
                      // S(face_dof_number[face_id] * Dimension + i, face_dof_number[face_id] * Dimension + j) +=
                      //   1.e-10 * Mb(i, j);
                    }
                    if (primal_face_is_symmetry[face_id]) {
                      S(face_dof_number[face_id] * Dimension + i, cell_dof_number[j_id] * Dimension + j) +=
                        ((i == j) ? -coef_adim : 0) + coef_adim * N(i, j);
                      S(face_dof_number[face_id] * Dimension + i, face_dof_number[face_id] * Dimension + j) +=
                        (i == j) ? coef_adim : 0;
                    }
                  }
                }
              } else {
                for (size_t i = 0; i < Dimension; ++i) {
                  for (size_t j = 0; j < Dimension; ++j) {
                    S((cell_dof_number[i_id] * Dimension) + i, (cell_dof_number[j_id] * Dimension) + j) -= M(i, j);
                  }
                }
              }
            }
          }
        }

        for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
          for (size_t i = 0; i < Dimension; ++i) {
            const size_t j = cell_dof_number[cell_id] * Dimension + i;
            S(j, j) += (*alpha)[cell_id] * primal_Vj[cell_id];
          }
        }

        const auto& dual_cell_to_node_matrix   = diamond_mesh->connectivity().cellToNodeMatrix();
        const auto& primal_node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          const double alpha_mu_face_id      = mes_l[face_id] * alpha_mu_l[face_id];
          const double alpha_lambda_face_id  = mes_l[face_id] * alpha_lambda_l[face_id];
          const double alpha_mub_face_id     = mes_l[face_id] * alpha_mub_l[face_id];
          const double alpha_lambdab_face_id = mes_l[face_id] * alpha_lambdab_l[face_id];

          for (size_t i_face_cell = 0; i_face_cell < face_to_cell_matrix[face_id].size(); ++i_face_cell) {
            CellId i_id                            = face_to_cell_matrix[face_id][i_face_cell];
            const bool is_face_reversed_for_cell_i = (dot(dualClj[face_id], xl[face_id] - xj[i_id]) < 0);

            for (size_t i_node = 0; i_node < primal_face_to_node_matrix[face_id].size(); ++i_node) {
              NodeId node_id = primal_face_to_node_matrix[face_id][i_node];

              const TinyVector<Dimension> nil = [&] {
                if (is_face_reversed_for_cell_i) {
                  return -nlj[face_id];
                } else {
                  return nlj[face_id];
                }
              }();

              CellId dual_cell_id = face_dual_cell_id[face_id];

              for (size_t i_dual_node = 0; i_dual_node < dual_cell_to_node_matrix[dual_cell_id].size(); ++i_dual_node) {
                const NodeId dual_node_id = dual_cell_to_node_matrix[dual_cell_id][i_dual_node];
                if (dual_node_primal_node_id[dual_node_id] == node_id) {
                  const TinyVector<Dimension> Clr = dual_Cjr(dual_cell_id, i_dual_node);

                  TinyMatrix<Dimension> M = alpha_mu_face_id * dot(Clr, nil) * I +
                                            alpha_mu_face_id * tensorProduct(Clr, nil) +
                                            alpha_lambda_face_id * tensorProduct(nil, Clr);
                  TinyMatrix<Dimension> Mb = alpha_mub_face_id * dot(Clr, nil) * I +
                                             alpha_mub_face_id * tensorProduct(Clr, nil) +
                                             alpha_lambdab_face_id * tensorProduct(nil, Clr);

                  for (size_t j_cell = 0; j_cell < primal_node_to_cell_matrix[node_id].size(); ++j_cell) {
                    CellId j_id = primal_node_to_cell_matrix[node_id][j_cell];
                    for (size_t i = 0; i < Dimension; ++i) {
                      for (size_t j = 0; j < Dimension; ++j) {
                        S((cell_dof_number[i_id] * Dimension) + i, (cell_dof_number[j_id] * Dimension) + j) -=
                          w_rj(node_id, j_cell) * M(i, j);
                        if (primal_face_is_neumann[face_id]) {
                          S(face_dof_number[face_id] * Dimension + i, cell_dof_number[j_id] * Dimension + j) +=
                            1.e0 * w_rj(node_id, j_cell) * Mb(i, j);
                        }
                      }
                    }
                  }
                  if (primal_node_is_on_boundary[node_id]) {
                    for (size_t l_face = 0; l_face < node_to_face_matrix[node_id].size(); ++l_face) {
                      FaceId l_id = node_to_face_matrix[node_id][l_face];
                      if (primal_face_is_on_boundary[l_id]) {
                        for (size_t i = 0; i < Dimension; ++i) {
                          for (size_t j = 0; j < Dimension; ++j) {
                            // Mb?
                            S(cell_dof_number[i_id] * Dimension + i, face_dof_number[l_id] * Dimension + j) -=
                              w_rl(node_id, l_face) * M(i, j);
                          }
                        }
                        if (primal_face_is_neumann[face_id]) {
                          for (size_t i = 0; i < Dimension; ++i) {
                            for (size_t j = 0; j < Dimension; ++j) {
                              S(face_dof_number[face_id] * Dimension + i, face_dof_number[l_id] * Dimension + j) +=
                                1.e0 * w_rl(node_id, l_face) * Mb(i, j);
                            }
                          }
                        }
                      }
                    }
                  }
                }
              }
              //            }
            }
          }
        }
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          if (primal_face_is_dirichlet[face_id]) {
            for (size_t i = 0; i < Dimension; ++i) {
              S(face_dof_number[face_id] * Dimension + i, face_dof_number[face_id] * Dimension + i) += 1.e0;
            }
          }
        }

        Vector<double> b{number_of_dof * Dimension};
        b = zero;
        for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
          for (size_t i = 0; i < Dimension; ++i) {
            b[(cell_dof_number[cell_id] * Dimension) + i] = primal_Vj[cell_id] * fj[cell_id][i];
          }
        }

        // Dirichlet
        NodeValue<bool> node_tag{mesh->connectivity()};
        node_tag.fill(false);
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr (std::is_same_v<T, DirichletBoundaryCondition>) {
                const auto& face_list  = bc.faceList();
                const auto& value_list = bc.valueList();
                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id = face_list[i_face];

                  for (size_t i = 0; i < Dimension; ++i) {
                    b[(face_dof_number[face_id] * Dimension) + i] += 1.e0 * value_list[i_face][i];
                  }
                }
              }
            },
            boundary_condition);
        }

        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>)) {
                const auto& face_list  = bc.faceList();
                const auto& value_list = bc.valueList();
                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  FaceId face_id = face_list[i_face];
                  for (size_t i = 0; i < Dimension; ++i) {
                    b[face_dof_number[face_id] * Dimension + i] +=
                      1.e0 * mes_l[face_id] * value_list[i_face][i];   // sign
                  }
                }
              }
            },
            boundary_condition);
        }

        CRSMatrix A{S.getCRSMatrix()};
        Vector<double> U{number_of_dof * Dimension};
        U        = zero;
        Vector r = A * U - b;
        std::cout << "initial (real) residu = " << std::sqrt(dot(r, r)) << '\n';

        LinearSolver solver;
        solver.solveLocalSystem(A, U, b);

        r = A * U - b;

        std::cout << "final (real) residu = " << std::sqrt(dot(r, r)) << '\n';

        m_solution     = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>(mesh);
        auto& solution = *m_solution;
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
            for (size_t i = 0; i < Dimension; ++i) {
              solution[cell_id][i] = U[(cell_dof_number[cell_id] * Dimension) + i];
            }
          });

        m_dual_solution     = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>(diamond_mesh);
        auto& dual_solution = *m_dual_solution;
        dual_solution.fill(zero);
        const auto& face_to_cell_matrix = mesh->connectivity().faceToCellMatrix();
        for (CellId cell_id = 0; cell_id < diamond_mesh->numberOfCells(); ++cell_id) {
          const FaceId face_id = dual_cell_face_id[cell_id];
          if (primal_face_is_on_boundary[face_id]) {
            for (size_t i = 0; i < Dimension; ++i) {
              dual_solution[cell_id][i] = U[(face_dof_number[face_id] * Dimension) + i];
            }
          } else {
            CellId cell_id1 = face_to_cell_matrix[face_id][0];
            CellId cell_id2 = face_to_cell_matrix[face_id][1];
            for (size_t i = 0; i < Dimension; ++i) {
              dual_solution[cell_id][i] =
                0.5 * (U[(cell_dof_number[cell_id1] * Dimension) + i] + U[(cell_dof_number[cell_id2] * Dimension) + i]);
            }
          }
        }
      }
      // provide a source for E?
      // computeEnergyUpdate(mesh, dual_lambdab, dual_mub, m_solution,
      //                     m_dual_solution,   // f,
      //                     bc_descriptor_list);
      // computeEnergyUpdate(mesh, alpha, dual_lambdab, dual_mub, dual_lambda, dual_mu, m_solution,
      //                     m_dual_solution,   // f,
      //                     bc_descriptor_list);
    } else {
      throw NotImplementedError("not done in 1d");
    }
  }
};
// EnergyComputerHandler: interpolation weights, energy-computer interface and implementation
template <size_t Dimension>
class EnergyComputerHandler::InterpolationWeightsManager
{
  // Computes node-centred interpolation weights on the primal mesh:
  //  - w_rj(r, j): weight of the j-th cell around node r (all nodes);
  //  - w_rl(r, l): weight of the l-th face around node r (only boundary faces
  //    of boundary nodes; every other entry is left as signaling NaN).
  // For each node r, the weights are the solution x of the small system
  // A x = b returned by LeastSquareSolver. Each column of A is [1, p_k], where
  // p_k runs over the cell centres, then over the boundary-face points.
  // b = [1, x_r].
 private:
  std::shared_ptr<const Mesh<Connectivity<Dimension>>> m_mesh;
  FaceValue<const bool> m_primal_face_is_on_boundary;
  NodeValue<const bool> m_primal_node_is_on_boundary;
  FaceValue<const bool> m_primal_face_is_symmetry;
  CellValuePerNode<double> m_w_rj;
  FaceValuePerNode<double> m_w_rl;

 public:
  InterpolationWeightsManager(std::shared_ptr<const Mesh<Connectivity<Dimension>>> mesh,
                              FaceValue<const bool> primal_face_is_on_boundary,
                              NodeValue<const bool> primal_node_is_on_boundary,
                              FaceValue<const bool> primal_face_is_symmetry)
    : m_mesh(mesh),
      m_primal_face_is_on_boundary(primal_face_is_on_boundary),
      m_primal_node_is_on_boundary(primal_node_is_on_boundary),
      m_primal_face_is_symmetry(primal_face_is_symmetry)
  {}
  ~InterpolationWeightsManager() = default;

  // Cell weights per node; only filled after compute() has been called.
  CellValuePerNode<double>&
  wrj()
  {
    return m_w_rj;
  }

  // Boundary-face weights per node; only filled after compute() has been called.
  FaceValuePerNode<double>&
  wrl()
  {
    return m_w_rl;
  }

  // Builds and solves, node by node, the local system described above, then
  // stores the results in m_w_rj / m_w_rl.
  void
  compute()
  {
    using MeshDataType      = MeshData<Dimension>;
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValue<const TinyVector<Dimension>>& xr = m_mesh->xr();

    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
    const auto& node_to_cell_matrix                  = m_mesh->connectivity().nodeToCellMatrix();
    const auto& node_to_face_matrix                  = m_mesh->connectivity().nodeToFaceMatrix();
    const auto& face_to_cell_matrix                  = m_mesh->connectivity().faceToCellMatrix();

    const NodeValuePerFace<const TinyVector<Dimension>> primal_nlr = mesh_data.nlr();

    // Projects x onto the plane through xl[face_id] orthogonal to the normal
    // stored at the face's first node.
    // NOTE(review): this is an orthogonal projection only if nlr is unit-length — confirm.
    auto project_to_face = [&](const TinyVector<Dimension>& x, const FaceId face_id) -> TinyVector<Dimension> {
      const TinyVector<Dimension> nil = primal_nlr(face_id, 0);
      return x - dot((x - xl[face_id]), nil) * nil;
    };

    CellValuePerNode<double> w_rj{m_mesh->connectivity()};
    FaceValuePerNode<double> w_rl{m_mesh->connectivity()};

    // Only boundary-face entries of boundary nodes get written below. The
    // others are set to signaling NaN, presumably so accidental use is visible.
    for (size_t i = 0; i < w_rl.numberOfValues(); ++i) {
      w_rl[i] = std::numeric_limits<double>::signaling_NaN();
    }

    for (NodeId i_node = 0; i_node < m_mesh->numberOfNodes(); ++i_node) {
      const auto& node_to_cell = node_to_cell_matrix[i_node];
      const auto& node_to_face = node_to_face_matrix[i_node];
      const size_t nb_cell     = node_to_cell.size();

      // Interior nodes are interpolated from their cells only. Boundary nodes
      // also use their adjacent boundary faces as interpolation points.
      const bool use_boundary_faces = m_primal_node_is_on_boundary[i_node];

      size_t nb_face_used = 0;
      if (use_boundary_faces) {
        for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
          const FaceId face_id = node_to_face[i_face];
          if (m_primal_face_is_on_boundary[face_id]) {
            ++nb_face_used;
          }
        }
      }
      const size_t nb_unknowns = nb_cell + nb_face_used;

      // Right-hand side [1, x_r].
      SmallVector<double> b{Dimension + 1};
      b[0] = 1;
      for (size_t i = 1; i < Dimension + 1; ++i) {
        b[i] = xr[i_node][i - 1];
      }

      // Row 0 enforces that the weights sum to one; rows 1..Dimension reproduce
      // the coordinates of the node.
      SmallMatrix<double> A{Dimension + 1, nb_unknowns};
      for (size_t j = 0; j < nb_unknowns; ++j) {
        A(0, j) = 1;
      }
      for (size_t j = 0; j < nb_cell; ++j) {
        const CellId J = node_to_cell[j];
        for (size_t i = 1; i < Dimension + 1; ++i) {
          A(i, j) = xj[J][i - 1];
        }
      }

      if (use_boundary_faces) {
        size_t column = nb_cell;
        for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
          const FaceId face_id = node_to_face[i_face];
          if (not m_primal_face_is_on_boundary[face_id]) {
            continue;
          }
          if (m_primal_face_is_symmetry[face_id]) {
            // Symmetry faces use the adjacent cell centre projected onto the
            // face. The projection is computed once, not once per coordinate.
            for (size_t j = 0; j < face_to_cell_matrix[face_id].size(); ++j) {
              const CellId cell_id              = face_to_cell_matrix[face_id][j];
              const TinyVector<Dimension> xproj = project_to_face(xj[cell_id], face_id);
              for (size_t i = 1; i < Dimension + 1; ++i) {
                A(i, column) = xproj[i - 1];
              }
            }
          } else {
            // Other boundary faces use the face centre.
            for (size_t i = 1; i < Dimension + 1; ++i) {
              A(i, column) = xl[face_id][i - 1];
            }
          }
          ++column;
        }
      }

      SmallVector<double> x{nb_unknowns};
      x = zero;

      LeastSquareSolver ls_solver;
      ls_solver.solveLocalSystem(A, x, b);

      for (size_t j = 0; j < nb_cell; ++j) {
        w_rj(i_node, j) = x[j];
      }
      if (use_boundary_faces) {
        // Boundary-face unknowns follow the cell unknowns, in node_to_face order.
        size_t column = nb_cell;
        for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
          const FaceId face_id = node_to_face[i_face];
          if (m_primal_face_is_on_boundary[face_id]) {
            w_rl(i_node, i_face) = x[column++];
          }
        }
      }
    }
    m_w_rj = w_rj;
    m_w_rl = w_rl;
  }
};
// Type-erased interface behind which the dimension-specific energy computer
// is hidden; concrete implementations override apply().
class EnergyComputerHandler::IEnergyComputer
{
 public:
  IEnergyComputer() = default;

  virtual ~IEnergyComputer() = default;

  // Returns the pair of discrete functions produced by the concrete computer.
  virtual std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>> apply()
    const = 0;
};

template <size_t Dimension>
class EnergyComputerHandler::EnergyComputer : public EnergyComputerHandler::IEnergyComputer
{
 private:
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_solution;
  std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_dual_solution;
  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_energy_delta;

  // Dirichlet data: one prescribed vector value per boundary face.
  class DirichletBoundaryCondition
  {
   public:
    // face_list and value_list are parallel arrays (checked in debug builds).
    DirichletBoundaryCondition(const Array<const FaceId>& face_list,
                               const Array<const TinyVector<Dimension>>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~DirichletBoundaryCondition() = default;

    // Faces carrying the condition.
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    // Prescribed value for each face of faceList(), same ordering.
    const Array<const TinyVector<Dimension>>&
    valueList() const
    {
      return m_value_list;
    }

   private:
    const Array<const TinyVector<Dimension>> m_value_list;
    const Array<const FaceId> m_face_list;
  };

  // Normal-strain (Neumann-type) data: one prescribed vector per boundary face.
  class NormalStrainBoundaryCondition
  {
   public:
    // face_list and value_list are parallel arrays (checked in debug builds).
    NormalStrainBoundaryCondition(const Array<const FaceId>& face_list,
                                  const Array<const TinyVector<Dimension>>& value_list)
      : m_value_list{value_list}, m_face_list{face_list}
    {
      Assert(m_value_list.size() == m_face_list.size());
    }

    ~NormalStrainBoundaryCondition() = default;

    // Faces carrying the condition.
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    // Prescribed vector for each face of faceList(), same ordering.
    const Array<const TinyVector<Dimension>>&
    valueList() const
    {
      return m_value_list;
    }

   private:
    const Array<const TinyVector<Dimension>> m_value_list;
    const Array<const FaceId> m_face_list;
  };

  // Symmetry condition: only the set of faces is needed, no prescribed value.
  // (The former m_value_list member was never initialised nor readable and has
  // been removed.)
  class SymmetryBoundaryCondition
  {
   private:
    const Array<const FaceId> m_face_list;

   public:
    // Faces carrying the condition.
    const Array<const FaceId>&
    faceList() const
    {
      return m_face_list;
    }

    // explicit: a face list must not silently convert into a boundary condition.
    explicit SymmetryBoundaryCondition(const Array<const FaceId>& face_list) : m_face_list{face_list} {}

    ~SymmetryBoundaryCondition() = default;
  };

 public:
  std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
  apply() const final
  {
    // Upcast the stored results to the type-erased interface, energy first.
    const std::shared_ptr<const IDiscreteFunction> energy_delta  = m_energy_delta;
    const std::shared_ptr<const IDiscreteFunction> dual_solution = m_dual_solution;
    return {energy_delta, dual_solution};
  }

  // std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
  // computeEnergyUpdate() const final
  // {
  //   m_scheme->computeEnergyUpdate(mesh, dual_lambdab, dual_mub, m_solution,
  //                                 m_dual_solution,   // f,
  //                                 bc_descriptor_list);

  //   return {m_dual_solution, m_energy_delta};
  // }

  // compute the fluxes
  // std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>
  EnergyComputer(const std::shared_ptr<const MeshType>& mesh,
                 // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& alpha,
                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambdab,
                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_mub,
                 // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambda,
                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& U,
                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& dual_U,
                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& source,
                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    // Assert(mesh == alpha->mesh());
    Assert(mesh == U->mesh());
    // Assert(dual_lambda->mesh() == dual_mu->mesh());
    Assert(dual_lambdab->mesh() == dual_U->mesh());
    Assert(U->mesh() == source->mesh());
    Assert(dual_lambdab->mesh() == dual_mub->mesh());
    Assert(DualMeshManager::instance().getDiamondDualMesh(*mesh) == dual_mub->mesh(),
           "diffusion coefficient is not defined on the dual mesh!");

    using MeshDataType = MeshData<Dimension>;

    using BoundaryCondition =
      std::variant<DirichletBoundaryCondition, NormalStrainBoundaryCondition, SymmetryBoundaryCondition>;

    using BoundaryConditionList = std::vector<BoundaryCondition>;

    BoundaryConditionList boundary_condition_list;

    NodeValue<bool> is_dirichlet{mesh->connectivity()};
    is_dirichlet.fill(false);
    NodeValue<TinyVector<Dimension>> dirichlet_value{mesh->connectivity()};
    {
      TinyVector<Dimension> nan_tiny_vector;
      for (size_t i = 0; i < Dimension; ++i) {
        nan_tiny_vector[i] = std::numeric_limits<double>::signaling_NaN();
      }
      dirichlet_value.fill(nan_tiny_vector);
    }

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::symmetry: {
        const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor =
          dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);

        if constexpr (Dimension > 1) {
          MeshFlatFaceBoundary<Dimension> mesh_face_boundary =
            getMeshFlatFaceBoundary(*mesh, sym_bc_descriptor.boundaryDescriptor());
          boundary_condition_list.push_back(SymmetryBoundaryCondition{mesh_face_boundary.faceList()});
        } else {
          throw NotImplementedError("Symmetry conditions are not supported in 1d");
        }

        break;
      }
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "dirichlet") {
          if constexpr (Dimension > 1) {
            MeshFaceBoundary<Dimension> mesh_face_boundary =
              getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

            const FunctionSymbolId g_id                   = dirichlet_bc_descriptor.rhsSymbolId();
            Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
              TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id, mesh_data.xl(),
                                                                            mesh_face_boundary.faceList());
            boundary_condition_list.push_back(DirichletBoundaryCondition{mesh_face_boundary.faceList(), value_list});
          } else {
            throw NotImplementedError("Neumann conditions are not supported in 1d");
          }
        } else if (dirichlet_bc_descriptor.name() == "normal_strain") {
          if constexpr (Dimension > 1) {
            MeshFaceBoundary<Dimension> mesh_face_boundary =
              getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

            MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

            const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();

            Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
              TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id, mesh_data.xl(),
                                                                            mesh_face_boundary.faceList());
            boundary_condition_list.push_back(NormalStrainBoundaryCondition{mesh_face_boundary.faceList(), value_list});

          } else {
            throw NotImplementedError("Normal strain conditions are not supported in 1d");
          }
        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for elasticity equation";
        throw NormalError(error_msg.str());
      }
    }

    if constexpr (Dimension > 1) {
      const CellValue<const size_t> cell_dof_number = [&] {
        CellValue<size_t> compute_cell_dof_number{mesh->connectivity()};
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { compute_cell_dof_number[cell_id] = cell_id; });
        return compute_cell_dof_number;
      }();
      size_t number_of_dof = mesh->numberOfCells();

      const FaceValue<const size_t> face_dof_number = [&] {
        FaceValue<size_t> compute_face_dof_number{mesh->connectivity()};
        compute_face_dof_number.fill(std::numeric_limits<size_t>::max());
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>) or
                            (std::is_same_v<T, SymmetryBoundaryCondition>) or
                            (std::is_same_v<T, DirichletBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id = face_list[i_face];
                  if (compute_face_dof_number[face_id] != std::numeric_limits<size_t>::max()) {
                    std::ostringstream os;
                    os << "The face " << face_id << " is used at least twice for boundary conditions";
                    throw NormalError(os.str());
                  } else {
                    compute_face_dof_number[face_id] = number_of_dof++;
                  }
                }
              }
            },
            boundary_condition);
        }

        return compute_face_dof_number;
      }();

      const auto& primal_face_to_node_matrix             = mesh->connectivity().faceToNodeMatrix();
      const auto& face_to_cell_matrix                    = mesh->connectivity().faceToCellMatrix();
      const FaceValue<const bool> primal_face_is_neumann = [&] {
        FaceValue<bool> face_is_neumann{mesh->connectivity()};
        face_is_neumann.fill(false);
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id     = face_list[i_face];
                  face_is_neumann[face_id] = true;
                }
              }
            },
            boundary_condition);
        }

        return face_is_neumann;
      }();

      const FaceValue<const bool> primal_face_is_symmetry = [&] {
        FaceValue<bool> face_is_symmetry{mesh->connectivity()};
        face_is_symmetry.fill(false);
        for (const auto& boundary_condition : boundary_condition_list) {
          std::visit(
            [&](auto&& bc) {
              using T = std::decay_t<decltype(bc)>;
              if constexpr ((std::is_same_v<T, SymmetryBoundaryCondition>)) {
                const auto& face_list = bc.faceList();

                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  const FaceId face_id      = face_list[i_face];
                  face_is_symmetry[face_id] = true;
                }
              }
            },
            boundary_condition);
        }

        return face_is_symmetry;
      }();

      NodeValue<bool> primal_node_is_on_boundary(mesh->connectivity());
      if (parallel::size() > 1) {
        throw NotImplementedError("Calculation of node_is_on_boundary is incorrect");
      }

      primal_node_is_on_boundary.fill(false);
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        if (face_to_cell_matrix[face_id].size() == 1) {
          for (size_t i_node = 0; i_node < primal_face_to_node_matrix[face_id].size(); ++i_node) {
            NodeId node_id                      = primal_face_to_node_matrix[face_id][i_node];
            primal_node_is_on_boundary[node_id] = true;
          }
        }
      }

      FaceValue<bool> primal_face_is_on_boundary(mesh->connectivity());
      if (parallel::size() > 1) {
        throw NotImplementedError("Calculation of face_is_on_boundary is incorrect");
      }

      primal_face_is_on_boundary.fill(false);
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        if (face_to_cell_matrix[face_id].size() == 1) {
          primal_face_is_on_boundary[face_id] = true;
        }
      }

      FaceValue<bool> primal_face_is_dirichlet(mesh->connectivity());
      if (parallel::size() > 1) {
        throw NotImplementedError("Calculation of face_is_neumann is incorrect");
      }

      primal_face_is_dirichlet.fill(false);
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        primal_face_is_dirichlet[face_id] = (primal_face_is_on_boundary[face_id] &&
                                             (!primal_face_is_neumann[face_id]) && (!primal_face_is_symmetry[face_id]));
      }

      InterpolationWeightsManager iwm(mesh, primal_face_is_on_boundary, primal_node_is_on_boundary,
                                      primal_face_is_symmetry);
      iwm.compute();
      CellValuePerNode<double> w_rj = iwm.wrj();
      FaceValuePerNode<double> w_rl = iwm.wrl();

      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

      const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
      const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
      // const auto& node_to_cell_matrix                                = mesh->connectivity().nodeToCellMatrix();
      const auto& node_to_face_matrix                                = mesh->connectivity().nodeToFaceMatrix();
      const NodeValuePerFace<const TinyVector<Dimension>> primal_nlr = mesh_data.nlr();

      {
        std::shared_ptr diamond_mesh = DualMeshManager::instance().getDiamondDualMesh(*mesh);

        MeshDataType& diamond_mesh_data = MeshDataManager::instance().getMeshData(*diamond_mesh);

        std::shared_ptr mapper =
          DualConnectivityManager::instance().getPrimalToDiamondDualConnectivityDataMapper(mesh->connectivity());

        CellValue<const double> dual_mubj     = dual_mub->cellValues();
        CellValue<const double> dual_lambdabj = dual_lambdab->cellValues();
        // attention, fj not in this context
        CellValue<const TinyVector<Dimension>> velocity = U->cellValues();
        // CellValue<const TinyVector<Dimension>> dual_velocity = dual_U->cellValues();

        const CellValue<const double> dual_Vj = diamond_mesh_data.Vj();

        const FaceValue<const double> mes_l = [&] {
          if constexpr (Dimension == 1) {
            FaceValue<double> compute_mes_l{mesh->connectivity()};
            compute_mes_l.fill(1);
            return compute_mes_l;
          } else {
            return mesh_data.ll();
          }
        }();

        const CellValue<const double> dual_mes_l_j = [=] {
          CellValue<double> compute_mes_j{diamond_mesh->connectivity()};
          mapper->toDualCell(mes_l, compute_mes_j);

          return compute_mes_j;
        }();

        const CellValue<const double> primal_Vj   = mesh_data.Vj();
        FaceValue<const CellId> face_dual_cell_id = [=]() {
          FaceValue<CellId> computed_face_dual_cell_id{mesh->connectivity()};
          CellValue<CellId> dual_cell_id{diamond_mesh->connectivity()};
          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { dual_cell_id[cell_id] = cell_id; });

          mapper->fromDualCell(dual_cell_id, computed_face_dual_cell_id);

          return computed_face_dual_cell_id;
        }();

        NodeValue<const NodeId> dual_node_primal_node_id = [=]() {
          CellValue<NodeId> cell_ignored_id{mesh->connectivity()};
          cell_ignored_id.fill(NodeId{std::numeric_limits<unsigned int>::max()});

          NodeValue<NodeId> node_primal_id{mesh->connectivity()};

          parallel_for(
            mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { node_primal_id[node_id] = node_id; });

          NodeValue<NodeId> computed_dual_node_primal_node_id{diamond_mesh->connectivity()};

          mapper->toDualNode(node_primal_id, cell_ignored_id, computed_dual_node_primal_node_id);

          return computed_dual_node_primal_node_id;
        }();

        CellValue<NodeId> primal_cell_dual_node_id = [=]() {
          CellValue<NodeId> cell_id{mesh->connectivity()};
          NodeValue<NodeId> node_ignored_id{mesh->connectivity()};
          node_ignored_id.fill(NodeId{std::numeric_limits<unsigned int>::max()});

          NodeValue<NodeId> dual_node_id{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) { dual_node_id[node_id] = node_id; });

          CellValue<NodeId> computed_primal_cell_dual_node_id{mesh->connectivity()};

          mapper->fromDualNode(dual_node_id, node_ignored_id, cell_id);

          return cell_id;
        }();
        const auto& dual_Cjr                     = diamond_mesh_data.Cjr();
        FaceValue<TinyVector<Dimension>> dualClj = [&] {
          FaceValue<TinyVector<Dimension>> computedClj{mesh->connectivity()};
          const auto& dual_node_to_cell_matrix = diamond_mesh->connectivity().nodeToCellMatrix();
          const auto& dual_cell_to_node_matrix = diamond_mesh->connectivity().cellToNodeMatrix();
          parallel_for(
            mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
              const auto& primal_face_to_cell = face_to_cell_matrix[face_id];
              for (size_t i = 0; i < primal_face_to_cell.size(); i++) {
                CellId cell_id            = primal_face_to_cell[i];
                const NodeId dual_node_id = primal_cell_dual_node_id[cell_id];
                for (size_t i_dual_cell = 0; i_dual_cell < dual_node_to_cell_matrix[dual_node_id].size();
                     i_dual_cell++) {
                  const CellId dual_cell_id = dual_node_to_cell_matrix[dual_node_id][i_dual_cell];
                  if (face_dual_cell_id[face_id] == dual_cell_id) {
                    for (size_t i_dual_node = 0; i_dual_node < dual_cell_to_node_matrix[dual_cell_id].size();
                         i_dual_node++) {
                      const NodeId final_dual_node_id = dual_cell_to_node_matrix[dual_cell_id][i_dual_node];
                      if (final_dual_node_id == dual_node_id) {
                        computedClj[face_id] = dual_Cjr(dual_cell_id, i_dual_node);
                      }
                    }
                  }
                }
              }
            });
          return computedClj;
        }();

        FaceValue<TinyVector<Dimension>> nlj = [&] {
          FaceValue<TinyVector<Dimension>> computedNlj{mesh->connectivity()};
          parallel_for(
            mesh->numberOfFaces(),
            PUGS_LAMBDA(FaceId face_id) { computedNlj[face_id] = 1. / l2Norm(dualClj[face_id]) * dualClj[face_id]; });
          return computedNlj;
        }();

        // FaceValue<const double> alpha_lambda_l = [&] {
        //   CellValue<double> alpha_j{diamond_mesh->connectivity()};

        //   parallel_for(
        //     diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
        //       alpha_j[diamond_cell_id] = dual_lambdaj[diamond_cell_id] / dual_Vj[diamond_cell_id];
        //     });

        //   FaceValue<double> computed_alpha_l{mesh->connectivity()};
        //   mapper->fromDualCell(alpha_j, computed_alpha_l);
        //   return computed_alpha_l;
        // }();

        // FaceValue<const double> alpha_mu_l = [&] {
        //   CellValue<double> alpha_j{diamond_mesh->connectivity()};

        //   parallel_for(
        //     diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
        //       alpha_j[diamond_cell_id] = dual_muj[diamond_cell_id] / dual_Vj[diamond_cell_id];
        //     });

        //   FaceValue<double> computed_alpha_l{mesh->connectivity()};
        //   mapper->fromDualCell(alpha_j, computed_alpha_l);
        //   return computed_alpha_l;
        // }();

        FaceValue<const double> alpha_lambdab_l = [&] {
          CellValue<double> alpha_j{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
              alpha_j[diamond_cell_id] = dual_lambdabj[diamond_cell_id] / dual_Vj[diamond_cell_id];
            });

          FaceValue<double> computed_alpha_l{mesh->connectivity()};
          mapper->fromDualCell(alpha_j, computed_alpha_l);
          return computed_alpha_l;
        }();

        FaceValue<const double> alpha_mub_l = [&] {
          CellValue<double> alpha_j{diamond_mesh->connectivity()};

          parallel_for(
            diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
              alpha_j[diamond_cell_id] = dual_mubj[diamond_cell_id] / dual_Vj[diamond_cell_id];
            });

          FaceValue<double> computed_alpha_l{mesh->connectivity()};
          mapper->fromDualCell(alpha_j, computed_alpha_l);
          return computed_alpha_l;
        }();

        const TinyMatrix<Dimension> I             = identity;
        CellValue<const FaceId> dual_cell_face_id = [=]() {
          CellValue<FaceId> computed_dual_cell_face_id{diamond_mesh->connectivity()};
          FaceValue<FaceId> primal_face_id{mesh->connectivity()};
          parallel_for(
            mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) { primal_face_id[face_id] = face_id; });

          mapper->toDualCell(primal_face_id, computed_dual_cell_face_id);

          return computed_dual_cell_face_id;
        }();

        m_dual_solution = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>(diamond_mesh);
        // m_solution           = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>(mesh);
        // const auto& solution = *U;
        auto& dual_solution = *dual_U;
        //  dual_solution.fill(zero);
        // const auto& face_to_cell_matrix = mesh->connectivity().faceToCellMatrix();
        // for (CellId cell_id = 0; cell_id < diamond_mesh->numberOfCells(); ++cell_id) {
        //   const FaceId face_id = dual_cell_face_id[cell_id];
        //   CellId cell_id1      = face_to_cell_matrix[face_id][0];
        //   if (primal_face_is_on_boundary[face_id]) {
        //     for (size_t i = 0; i < Dimension; ++i) {
        //       // A revoir!!
        //       dual_solution[cell_id][i] = solution[cell_id1][i];
        //     }
        //   } else {
        //     CellId cell_id1 = face_to_cell_matrix[face_id][0];
        //     CellId cell_id2 = face_to_cell_matrix[face_id][1];
        //     for (size_t i = 0; i < Dimension; ++i) {
        //       dual_solution[cell_id][i] = 0.5 * (solution[cell_id1][i] + solution[cell_id2][i]);
        //     }
        //   }
        // }

        // const Array<int> non_zeros{number_of_dof * Dimension};
        // non_zeros.fill(Dimension * Dimension);
        // CRSMatrixDescriptor<double> S(number_of_dof * Dimension, number_of_dof * Dimension, non_zeros);
        // Begining of main
        CellValuePerFace<double> flux{mesh->connectivity()};
        parallel_for(
          flux.numberOfValues(), PUGS_LAMBDA(size_t jl) { flux[jl] = 0; });

        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          // const double beta_mu_l     = l2Norm(dualClj[face_id]) * alpha_mu_l[face_id] * mes_l[face_id];
          // const double beta_lambda_l = l2Norm(dualClj[face_id]) * alpha_lambda_l[face_id] * mes_l[face_id];
          const double beta_mub_l         = l2Norm(dualClj[face_id]) * alpha_mub_l[face_id] * mes_l[face_id];
          const double beta_lambdab_l     = l2Norm(dualClj[face_id]) * alpha_lambdab_l[face_id] * mes_l[face_id];
          const auto& primal_face_to_cell = face_to_cell_matrix[face_id];
          for (size_t i_cell = 0; i_cell < primal_face_to_cell.size(); ++i_cell) {
            const CellId i_id                      = primal_face_to_cell[i_cell];
            const bool is_face_reversed_for_cell_i = (dot(dualClj[face_id], xl[face_id] - xj[i_id]) < 0);

            const TinyVector<Dimension> nil = [&] {
              if (is_face_reversed_for_cell_i) {
                return -nlj[face_id];
              } else {
                return nlj[face_id];
              }
            }();
            TinyMatrix<Dimension> M =
              beta_mub_l * I + beta_mub_l * tensorProduct(nil, nil) + beta_lambdab_l * tensorProduct(nil, nil);
            //            TinyMatrix<Dimension, double> Mb =
            //  beta_mub_l * I + beta_mub_l * tensorProduct(nil, nil) + beta_lambdab_l * tensorProduct(nil, nil);
            // TinyMatrix<Dimension, double> N = 1.e0 * tensorProduct(nil, nil);

            for (size_t j_cell = 0; j_cell < primal_face_to_cell.size(); ++j_cell) {
              const CellId j_id = primal_face_to_cell[j_cell];
              if (i_cell == j_cell) {
                flux(face_id, i_cell) += dot(M * velocity[i_id], dual_solution[face_dual_cell_id[face_id]]);
              } else {
                flux(face_id, i_cell) -= dot(M * velocity[j_id], dual_solution[face_dual_cell_id[face_id]]);
              }
            }
          }
        }

        const auto& dual_cell_to_node_matrix   = diamond_mesh->connectivity().cellToNodeMatrix();
        const auto& primal_node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          // const double alpha_mu_face_id     = mes_l[face_id] * alpha_mu_l[face_id];
          // const double alpha_lambda_face_id = mes_l[face_id] * alpha_lambda_l[face_id];
          const double alpha_mub_face_id     = mes_l[face_id] * alpha_mub_l[face_id];
          const double alpha_lambdab_face_id = mes_l[face_id] * alpha_lambdab_l[face_id];

          for (size_t i_face_cell = 0; i_face_cell < face_to_cell_matrix[face_id].size(); ++i_face_cell) {
            CellId i_id                            = face_to_cell_matrix[face_id][i_face_cell];
            const bool is_face_reversed_for_cell_i = (dot(dualClj[face_id], xl[face_id] - xj[i_id]) < 0);

            for (size_t i_node = 0; i_node < primal_face_to_node_matrix[face_id].size(); ++i_node) {
              NodeId node_id = primal_face_to_node_matrix[face_id][i_node];

              const TinyVector<Dimension> nil = [&] {
                if (is_face_reversed_for_cell_i) {
                  return -nlj[face_id];
                } else {
                  return nlj[face_id];
                }
              }();

              CellId dual_cell_id = face_dual_cell_id[face_id];

              for (size_t i_dual_node = 0; i_dual_node < dual_cell_to_node_matrix[dual_cell_id].size(); ++i_dual_node) {
                const NodeId dual_node_id = dual_cell_to_node_matrix[dual_cell_id][i_dual_node];
                if (dual_node_primal_node_id[dual_node_id] == node_id) {
                  const TinyVector<Dimension> Clr = dual_Cjr(dual_cell_id, i_dual_node);

                  TinyMatrix<Dimension> M = alpha_mub_face_id * dot(Clr, nil) * I +
                                            alpha_mub_face_id * tensorProduct(Clr, nil) +
                                            alpha_lambdab_face_id * tensorProduct(nil, Clr);
                  // TinyMatrix<Dimension, double> Mb = alpha_mub_face_id * dot(Clr, nil) * I +
                  //                                    alpha_mub_face_id * tensorProduct(Clr, nil) +
                  //                                    alpha_lambdab_face_id * tensorProduct(nil, Clr);

                  for (size_t j_cell = 0; j_cell < primal_node_to_cell_matrix[node_id].size(); ++j_cell) {
                    CellId j_id = primal_node_to_cell_matrix[node_id][j_cell];
                    flux(face_id, i_face_cell) -=
                      w_rj(node_id, j_cell) * dot(M * velocity[j_id], dual_solution[dual_cell_id]);
                  }
                  if (primal_node_is_on_boundary[node_id]) {
                    for (size_t l_face = 0; l_face < node_to_face_matrix[node_id].size(); ++l_face) {
                      FaceId l_id = node_to_face_matrix[node_id][l_face];
                      if (primal_face_is_on_boundary[l_id]) {
                        flux(face_id, i_face_cell) -=
                          w_rl(node_id, l_face) *
                          dot(M * dual_solution[face_dual_cell_id[l_id]], dual_solution[dual_cell_id]);
                      }
                    }
                  }
                }
              }
              //            }
            }
          }
        }
        // for (const auto& boundary_condition : boundary_condition_list) {
        //   std::visit(
        //     [&](auto&& bc) {
        //       using T = std::decay_t<decltype(bc)>;
        //       if constexpr ((std::is_same_v<T, NormalStrainBoundaryCondition>)) {
        //         const auto& face_list  = bc.faceList();
        //         const auto& value_list = bc.valueList();
        //         for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
        //           FaceId face_id      = face_list[i_face];
        //           CellId dual_cell_id = face_dual_cell_id[face_id];
        //           flux(face_id, 0)    = mes_l[face_id] * dot(value_list[i_face], dual_solution[dual_cell_id]);   //
        //           sign
        //         }
        //       }
        //     },
        //     boundary_condition);
        // }
        // for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        //   if (face_to_cell_matrix[face_id].size() == 2) {
        //     CellId i_id = face_to_cell_matrix[face_id][0];
        //     CellId j_id = face_to_cell_matrix[face_id][1];
        //     if (flux(face_id, 0) != -flux(face_id, 1)) {
        //       std::cout << "flux(" << i_id << "," << face_id << ")=" << flux(face_id, 0) << " not equal to -flux("
        //                 << j_id << "," << face_id << ")=" << -flux(face_id, 1) << "\n";
        //     }
        //   }
        //   // exit(0);
        // }
        // Assemble
        m_energy_delta     = std::make_shared<DiscreteFunctionP0<Dimension, double>>(mesh);
        auto& energy_delta = *m_energy_delta;
        // CellValue<const TinyVector<Dimension>> fj = source->cellValues();

        double sum_deltae = 0.;
        for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); ++cell_id) {
          energy_delta[cell_id] = 0.;   // dot(fj[cell_id], velocity[cell_id]);
          sum_deltae += energy_delta[cell_id];
        }
        // CellValue<double>& deltae = m_energy_delta->cellValues();
        for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
          for (size_t j = 0; j < face_to_cell_matrix[face_id].size(); j++) {
            CellId i_id = face_to_cell_matrix[face_id][j];
            energy_delta[i_id] -= flux(face_id, j) / primal_Vj[i_id];
            sum_deltae -= flux(face_id, j);
          }
          // exit(0);
        }

        std::cout << "sum deltaej " << sum_deltae << "\n";
      }
    } else {
      throw NotImplementedError("not done in 1d");
    }
    //    return m_energy_delta;
  }
};

// Delegates to the dimension-specific scheme chosen by the constructor and hands back,
// as is, the pair of discrete functions that scheme's apply() returns.
std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
VectorDiamondSchemeHandler::apply() const
{
  auto& scheme = *m_scheme;
  return scheme.apply();
}

// Delegates to the dimension-specific energy computer chosen by the constructor and hands
// back, as is, the pair of discrete functions that its apply() returns.
std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
EnergyComputerHandler::computeEnergyUpdate() const
{
  auto& energy_computer = *m_energy_computer;
  return energy_computer.apply();
}

std::shared_ptr<const IDiscreteFunction>
VectorDiamondSchemeHandler::solution() const
{
  return m_scheme->getSolution();
}

std::shared_ptr<const IDiscreteFunction>
VectorDiamondSchemeHandler::dual_solution() const
{
  return m_scheme->getDualSolution();
}

/// @brief Builds the dimension-specific VectorDiamondScheme from type-erased discrete functions.
///
/// @param alpha              primal function, forwarded as DiscreteFunctionP0<Dimension, double>
/// @param dual_lambdab       dual function, forwarded as DiscreteFunctionP0<Dimension, double>
/// @param dual_mub           dual function, forwarded as DiscreteFunctionP0<Dimension, double>
/// @param dual_lambda        dual function, forwarded as DiscreteFunctionP0<Dimension, double>
/// @param dual_mu            dual function, forwarded as DiscreteFunctionP0<Dimension, double>
/// @param f                  primal function, forwarded as DiscreteFunctionP0<Dimension, TinyVector<Dimension>>
/// @param bc_descriptor_list boundary condition descriptors, forwarded unchanged
///
/// All six functions must use the P0 discretization (enforced by checkDiscretizationType).
///
/// @throws NormalError if alpha and f do not share a mesh.
/// @throws NormalError if the four dual functions do not share a mesh.
/// @throws NormalError (2D/3D only) if that shared mesh is not the diamond dual of the primal mesh.
/// @throws UnexpectedError if the primal mesh dimension is not 1, 2 or 3.
///
/// NOTE(review): the dynamic_pointer_cast results are not checked. An argument whose value type
/// does not match (e.g. a vector-valued alpha) reaches the scheme as a null pointer.
/// NOTE(review): the 1D branch does not compare the dual mesh with the diamond dual of the primal
/// mesh, unlike the 2D/3D branches. Confirm this is intended.
VectorDiamondSchemeHandler::VectorDiamondSchemeHandler(
  const std::shared_ptr<const IDiscreteFunction>& alpha,
  const std::shared_ptr<const IDiscreteFunction>& dual_lambdab,
  const std::shared_ptr<const IDiscreteFunction>& dual_mub,
  const std::shared_ptr<const IDiscreteFunction>& dual_lambda,
  const std::shared_ptr<const IDiscreteFunction>& dual_mu,
  const std::shared_ptr<const IDiscreteFunction>& f,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
{
  const std::shared_ptr i_mesh = getCommonMesh({alpha, f});
  if (not i_mesh) {
    throw NormalError("primal discrete functions are not defined on the same mesh");
  }
  const std::shared_ptr i_dual_mesh = getCommonMesh({dual_lambda, dual_lambdab, dual_mu, dual_mub});
  if (not i_dual_mesh) {
    throw NormalError("dual discrete functions are not defined on the same mesh");
  }
  checkDiscretizationType({alpha, dual_lambdab, dual_mub, dual_lambda, dual_mu, f}, DiscreteFunctionType::P0);

  switch (i_mesh->dimension()) {
  case 1: {
    using MeshType                   = Mesh<Connectivity<1>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<1, double>;
    using DiscreteVectorFunctionType = DiscreteFunctionP0<1, TinyVector<1>>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    m_scheme =
      std::make_unique<VectorDiamondScheme<1>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(
                                                 dual_lambdab),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambda),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mu),
                                               std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(f),
                                               bc_descriptor_list);
    break;
  }
  case 2: {
    using MeshType                   = Mesh<Connectivity<2>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<2, double>;
    using DiscreteVectorFunctionType = DiscreteFunctionP0<2, TinyVector<2>>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    // The dual functions must live on the diamond dual built from this very primal mesh.
    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_scheme =
      std::make_unique<VectorDiamondScheme<2>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(
                                                 dual_lambdab),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambda),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mu),
                                               std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(f),
                                               bc_descriptor_list);
    break;
  }
  case 3: {
    using MeshType                   = Mesh<Connectivity<3>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<3, double>;
    using DiscreteVectorFunctionType = DiscreteFunctionP0<3, TinyVector<3>>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    // The dual functions must live on the diamond dual built from this very primal mesh.
    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_scheme =
      std::make_unique<VectorDiamondScheme<3>>(mesh, std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(alpha),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(
                                                 dual_lambdab),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambda),
                                               std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mu),
                                               std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(f),
                                               bc_descriptor_list);
    break;
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

// Defaulted out of line in this translation unit.
// NOTE(review): presumably kept here rather than in VectorDiamondScheme.hpp so that the scheme type
// owned by m_scheme is complete where it is destroyed (required if m_scheme is a std::unique_ptr).
// Confirm against the header before moving it.
VectorDiamondSchemeHandler::~VectorDiamondSchemeHandler() = default;

/// @brief Builds the dimension-specific EnergyComputer from type-erased discrete functions.
///
/// @param dual_lambdab       dual function, forwarded as DiscreteFunctionP0<Dimension, double>
/// @param dual_mub           dual function, forwarded as DiscreteFunctionP0<Dimension, double>
/// @param U                  primal function, forwarded as DiscreteFunctionP0<Dimension, TinyVector<Dimension>>
/// @param dual_U             dual function, forwarded as DiscreteFunctionP0<Dimension, TinyVector<Dimension>>
/// @param source             primal function, forwarded as DiscreteFunctionP0<Dimension, TinyVector<Dimension>>
/// @param bc_descriptor_list boundary condition descriptors, forwarded unchanged
///
/// All five functions must use the P0 discretization (enforced by checkDiscretizationType).
///
/// @throws NormalError if U and source do not share a mesh.
/// @throws NormalError if dual_lambdab, dual_mub and dual_U do not share a mesh.
/// @throws NormalError if that shared mesh is not the diamond dual of the primal mesh.
/// @throws UnexpectedError if the primal mesh dimension is not 1, 2 or 3.
///
/// NOTE(review): the dynamic_pointer_cast results are not checked. An argument whose value type
/// does not match (e.g. a scalar-valued U) reaches the energy computer as a null pointer.
EnergyComputerHandler::EnergyComputerHandler(
  const std::shared_ptr<const IDiscreteFunction>& dual_lambdab,
  const std::shared_ptr<const IDiscreteFunction>& dual_mub,
  const std::shared_ptr<const IDiscreteFunction>& U,
  const std::shared_ptr<const IDiscreteFunction>& dual_U,
  const std::shared_ptr<const IDiscreteFunction>& source,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
{
  const std::shared_ptr i_mesh = getCommonMesh({U, source});
  if (not i_mesh) {
    throw NormalError("primal discrete functions are not defined on the same mesh");
  }
  const std::shared_ptr i_dual_mesh = getCommonMesh({dual_lambdab, dual_mub, dual_U});
  if (not i_dual_mesh) {
    throw NormalError("dual discrete functions are not defined on the same mesh");
  }
  checkDiscretizationType({dual_lambdab, dual_mub, dual_U, U, source}, DiscreteFunctionType::P0);

  switch (i_mesh->dimension()) {
  case 1: {
    using MeshType                   = Mesh<Connectivity<1>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<1, double>;
    using DiscreteVectorFunctionType = DiscreteFunctionP0<1, TinyVector<1>>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    // The dual functions must live on the diamond dual built from this very primal mesh.
    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_energy_computer =
      std::make_unique<EnergyComputer<1>>(mesh,
                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambdab),
                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(U),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(dual_U),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(source),
                                          bc_descriptor_list);
    break;
  }
  case 2: {
    using MeshType                   = Mesh<Connectivity<2>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<2, double>;
    using DiscreteVectorFunctionType = DiscreteFunctionP0<2, TinyVector<2>>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    // The dual functions must live on the diamond dual built from this very primal mesh.
    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_energy_computer =
      std::make_unique<EnergyComputer<2>>(mesh,
                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambdab),
                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(U),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(dual_U),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(source),
                                          bc_descriptor_list);
    break;
  }
  case 3: {
    using MeshType                   = Mesh<Connectivity<3>>;
    using DiscreteScalarFunctionType = DiscreteFunctionP0<3, double>;
    using DiscreteVectorFunctionType = DiscreteFunctionP0<3, TinyVector<3>>;

    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

    // The dual functions must live on the diamond dual built from this very primal mesh.
    if (DualMeshManager::instance().getDiamondDualMesh(*mesh) != i_dual_mesh) {
      throw NormalError("dual variables are not defined on the diamond dual of the primal mesh");
    }

    m_energy_computer =
      std::make_unique<EnergyComputer<3>>(mesh,
                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambdab),
                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(U),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(dual_U),
                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(source),
                                          bc_descriptor_list);
    break;
  }
  default: {
    throw UnexpectedError("invalid mesh dimension");
  }
  }
}

// Defaulted out of line in this translation unit.
// NOTE(review): presumably kept here rather than in the header so that the computer type owned by
// m_energy_computer is complete where it is destroyed (required if m_energy_computer is a
// std::unique_ptr). Confirm against the header before moving it.
EnergyComputerHandler::~EnergyComputerHandler() = default;
