diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp
index f47c26354426ce62402425cedd3bb432da7c3b86..84d263e6817151729a2ec3345aee46826b6adc03 100644
--- a/src/language/modules/SchemeModule.cpp
+++ b/src/language/modules/SchemeModule.cpp
@@ -493,7 +493,7 @@ SchemeModule::SchemeModule()
 
   this->_addBuiltinFunction("moleculardiffusion",
                             std::make_shared<BuiltinFunctionEmbedder<std::tuple<
-                              std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>,
+                              std::shared_ptr<const IDiscreteFunction>,
                               std::shared_ptr<const IDiscreteFunction>>(const std::shared_ptr<const IDiscreteFunction>&,
                                                                         const std::shared_ptr<const IDiscreteFunction>&,
                                                                         const std::shared_ptr<const IDiscreteFunction>&,
@@ -510,7 +510,6 @@ SchemeModule::SchemeModule()
                                  const std::shared_ptr<const IDiscreteFunction> f,
                                  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                    bc_descriptor_list) -> const std::tuple<std::shared_ptr<const IDiscreteFunction>,
-                                                                           std::shared_ptr<const IDiscreteFunction>,
                                                                            std::shared_ptr<const IDiscreteFunction>> {
                                 return VectorDiamondSchemeHandler{alpha, lambdab,           mub, lambda, mu,
                                                                   f,     bc_descriptor_list}
@@ -519,6 +518,25 @@ SchemeModule::SchemeModule()
 
                               ));
 
+  this->_addBuiltinFunction(
+    "energybalance",
+    std::make_shared<BuiltinFunctionEmbedder<std::tuple<
+      std::shared_ptr<const IDiscreteFunction>,
+      std::shared_ptr<const IDiscreteFunction>>(const std::shared_ptr<const IDiscreteFunction>&,
+                                                const std::shared_ptr<const IDiscreteFunction>&,
+                                                const std::shared_ptr<const IDiscreteFunction>&,
+                                                const std::shared_ptr<const IDiscreteFunction>&,
+                                                const std::vector<
+                                                  std::shared_ptr<const IBoundaryConditionDescriptor>>&)>>(
+      [](const std::shared_ptr<const IDiscreteFunction> lambdab, const std::shared_ptr<const IDiscreteFunction> mub,
+         const std::shared_ptr<const IDiscreteFunction> U, const std::shared_ptr<const IDiscreteFunction> source,
+         const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
+        -> const std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>> {
+        return EnergyComputerHandler{lambdab, mub, U, source, bc_descriptor_list}.computeEnergyUpdate();
+      }
+
+      ));
+
   this->_addBuiltinFunction("heat2",
                             std::make_shared<BuiltinFunctionEmbedder<
                               void(std::shared_ptr<const IMesh>,
diff --git a/src/scheme/VectorDiamondScheme.cpp b/src/scheme/VectorDiamondScheme.cpp
index eb8429eaf3902902fa3d5bac37541c83982b78e9..87173047e0bf25d5c7f0980de7b9c414922b55ea 100644
--- a/src/scheme/VectorDiamondScheme.cpp
+++ b/src/scheme/VectorDiamondScheme.cpp
@@ -158,10 +158,10 @@ class VectorDiamondSchemeHandler::IVectorDiamondScheme
  public:
   virtual std::shared_ptr<const IDiscreteFunction> getSolution() const     = 0;
   virtual std::shared_ptr<const IDiscreteFunction> getDualSolution() const = 0;
-  virtual std::tuple<std::shared_ptr<const IDiscreteFunction>,
-                     std::shared_ptr<const IDiscreteFunction>,
-                     std::shared_ptr<const IDiscreteFunction>>
-  apply() const = 0;
+  virtual std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>> apply()
+    const = 0;
+  // virtual std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
+  // computeEnergyUpdate() const = 0;
 
   IVectorDiamondScheme()          = default;
   virtual ~IVectorDiamondScheme() = default;
@@ -177,7 +177,7 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
 
   std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_solution;
   std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_dual_solution;
-  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_energy_delta;
+  //  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_energy_delta;
 
   class DirichletBoundaryCondition
   {
@@ -269,12 +269,10 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
     return m_dual_solution;
   }
 
-  std::tuple<std::shared_ptr<const IDiscreteFunction>,
-             std::shared_ptr<const IDiscreteFunction>,
-             std::shared_ptr<const IDiscreteFunction>>
+  std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
   apply() const final
   {
-    return {m_solution, m_dual_solution, m_energy_delta};
+    return {m_solution, m_dual_solution};
   }
 
   VectorDiamondScheme(const std::shared_ptr<const MeshType>& mesh,
@@ -924,6 +922,7 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
               solution[cell_id][i] = U[(cell_dof_number[cell_id] * Dimension) + i];
             }
           });
+
         m_dual_solution     = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>(diamond_mesh);
         auto& dual_solution = *m_dual_solution;
         dual_solution.fill(zero);
@@ -945,9 +944,9 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
         }
       }
       // provide a source for E?
-      computeEnergyUpdate(mesh, dual_lambdab, dual_mub, m_solution,
-                          m_dual_solution,   // f,
-                          bc_descriptor_list);
+      // computeEnergyUpdate(mesh, dual_lambdab, dual_mub, m_solution,
+      //                     m_dual_solution,   // f,
+      //                     bc_descriptor_list);
       // computeEnergyUpdate(mesh, alpha, dual_lambdab, dual_mub, dual_lambda, dual_mu, m_solution,
       //                     m_dual_solution,   // f,
       //                     bc_descriptor_list);
@@ -955,25 +954,294 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
       throw NotImplementedError("not done in 1d");
     }
   }
+};
+// NEW CLASS
+template <size_t Dimension>
+class EnergyComputerHandler::InterpolationWeightsManager
+{
+ private:
+  std::shared_ptr<const Mesh<Connectivity<Dimension>>> m_mesh;
+  FaceValue<const bool> m_primal_face_is_on_boundary;
+  NodeValue<const bool> m_primal_node_is_on_boundary;
+  FaceValue<const bool> m_primal_face_is_symmetry;
+  CellValuePerNode<double> m_w_rj;
+  FaceValuePerNode<double> m_w_rl;
+
+ public:
+  InterpolationWeightsManager(std::shared_ptr<const Mesh<Connectivity<Dimension>>> mesh,
+                              FaceValue<const bool> primal_face_is_on_boundary,
+                              NodeValue<const bool> primal_node_is_on_boundary,
+                              FaceValue<const bool> primal_face_is_symmetry)
+    : m_mesh(mesh),
+      m_primal_face_is_on_boundary(primal_face_is_on_boundary),
+      m_primal_node_is_on_boundary(primal_node_is_on_boundary),
+      m_primal_face_is_symmetry(primal_face_is_symmetry)
+  {}
+  ~InterpolationWeightsManager() = default;
+  CellValuePerNode<double>&
+  wrj()
+  {
+    return m_w_rj;
+  }
+  FaceValuePerNode<double>&
+  wrl()
+  {
+    return m_w_rl;
+  }
+  void
+  compute()
+  {
+    using MeshDataType      = MeshData<Dimension>;
+    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);
+
+    const NodeValue<const TinyVector<Dimension>>& xr = m_mesh->xr();
+
+    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
+    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
+    const auto& node_to_cell_matrix                  = m_mesh->connectivity().nodeToCellMatrix();
+    const auto& node_to_face_matrix                  = m_mesh->connectivity().nodeToFaceMatrix();
+    const auto& face_to_cell_matrix                  = m_mesh->connectivity().faceToCellMatrix();
+
+    CellValuePerNode<double> w_rj{m_mesh->connectivity()};
+    FaceValuePerNode<double> w_rl{m_mesh->connectivity()};
+
+    const NodeValuePerFace<const TinyVector<Dimension>> primal_nlr = mesh_data.nlr();
+    auto project_to_face = [&](const TinyVector<Dimension>& x, const FaceId face_id) -> const TinyVector<Dimension> {
+      TinyVector<Dimension> proj;
+      const TinyVector<Dimension> nil = primal_nlr(face_id, 0);
+      proj                            = x - dot((x - xl[face_id]), nil) * nil;
+      return proj;
+    };
+
+    for (size_t i = 0; i < w_rl.numberOfValues(); ++i) {
+      w_rl[i] = std::numeric_limits<double>::signaling_NaN();
+    }
+
+    for (NodeId i_node = 0; i_node < m_mesh->numberOfNodes(); ++i_node) {
+      SmallVector<double> b{Dimension + 1};
+      b[0] = 1;
+      for (size_t i = 1; i < Dimension + 1; i++) {
+        b[i] = xr[i_node][i - 1];
+      }
+      const auto& node_to_cell = node_to_cell_matrix[i_node];
+
+      if (not m_primal_node_is_on_boundary[i_node]) {
+        SmallMatrix<double> A{Dimension + 1, node_to_cell.size()};
+        for (size_t j = 0; j < node_to_cell.size(); j++) {
+          A(0, j) = 1;
+        }
+        for (size_t i = 1; i < Dimension + 1; i++) {
+          for (size_t j = 0; j < node_to_cell.size(); j++) {
+            const CellId J = node_to_cell[j];
+            A(i, j)        = xj[J][i - 1];
+          }
+        }
+
+        SmallVector<double> x{node_to_cell.size()};
+        x = zero;
+
+        LeastSquareSolver ls_solver;
+        ls_solver.solveLocalSystem(A, x, b);
+
+        for (size_t j = 0; j < node_to_cell.size(); j++) {
+          w_rj(i_node, j) = x[j];
+        }
+      } else {
+        int nb_face_used = 0;
+        for (size_t i_face = 0; i_face < node_to_face_matrix[i_node].size(); ++i_face) {
+          FaceId face_id = node_to_face_matrix[i_node][i_face];
+          if (m_primal_face_is_on_boundary[face_id]) {
+            nb_face_used++;
+          }
+        }
+        SmallMatrix<double> A{Dimension + 1, node_to_cell.size() + nb_face_used};
+        for (size_t j = 0; j < node_to_cell.size() + nb_face_used; j++) {
+          A(0, j) = 1;
+        }
+        for (size_t i = 1; i < Dimension + 1; i++) {
+          for (size_t j = 0; j < node_to_cell.size(); j++) {
+            const CellId J = node_to_cell[j];
+            A(i, j)        = xj[J][i - 1];
+          }
+        }
+        for (size_t i = 1; i < Dimension + 1; i++) {
+          int cpt_face = 0;
+          for (size_t i_face = 0; i_face < node_to_face_matrix[i_node].size(); ++i_face) {
+            FaceId face_id = node_to_face_matrix[i_node][i_face];
+            if (m_primal_face_is_on_boundary[face_id]) {
+              if (m_primal_face_is_symmetry[face_id]) {
+                for (size_t j = 0; j < face_to_cell_matrix[face_id].size(); ++j) {
+                  const CellId cell_id                 = face_to_cell_matrix[face_id][j];
+                  TinyVector<Dimension> xproj          = project_to_face(xj[cell_id], face_id);
+                  A(i, node_to_cell.size() + cpt_face) = xproj[i - 1];
+                }
+              } else {
+                A(i, node_to_cell.size() + cpt_face) = xl[face_id][i - 1];
+              }
+              cpt_face++;
+            }
+          }
+        }
+
+        SmallVector<double> x{node_to_cell.size() + nb_face_used};
+        x = zero;
+
+        LeastSquareSolver ls_solver;
+        ls_solver.solveLocalSystem(A, x, b);
+
+        for (size_t j = 0; j < node_to_cell.size(); j++) {
+          w_rj(i_node, j) = x[j];
+        }
+        int cpt_face = node_to_cell.size();
+        for (size_t i_face = 0; i_face < node_to_face_matrix[i_node].size(); ++i_face) {
+          FaceId face_id = node_to_face_matrix[i_node][i_face];
+          if (m_primal_face_is_on_boundary[face_id]) {
+            w_rl(i_node, i_face) = x[cpt_face++];
+          }
+        }
+      }
+    }
+    m_w_rj = w_rj;
+    m_w_rl = w_rl;
+  }
+};
+class EnergyComputerHandler::IEnergyComputer
+{
+ public:
+  // virtual std::shared_ptr<const IDiscreteFunction> getSolution() const     = 0;
+  // virtual std::shared_ptr<const IDiscreteFunction> getDualSolution() const = 0;
+  // virtual std::shared_ptr<const IDiscreteFunction> apply() const           = 0;
+  virtual std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>> apply()
+    const = 0;
+
+  IEnergyComputer()          = default;
+  virtual ~IEnergyComputer() = default;
+};
+
+template <size_t Dimension>
+class EnergyComputerHandler::EnergyComputer : public EnergyComputerHandler::IEnergyComputer
+{
+ private:
+  using ConnectivityType = Connectivity<Dimension>;
+  using MeshType         = Mesh<ConnectivityType>;
+  using MeshDataType     = MeshData<Dimension>;
+
+  std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_solution;
+  std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>> m_dual_solution;
+  std::shared_ptr<const DiscreteFunctionP0<Dimension, double>> m_energy_delta;
+
+  class DirichletBoundaryCondition
+  {
+   private:
+    const Array<const TinyVector<Dimension>> m_value_list;
+    const Array<const FaceId> m_face_list;
+
+   public:
+    const Array<const FaceId>&
+    faceList() const
+    {
+      return m_face_list;
+    }
+
+    const Array<const TinyVector<Dimension>>&
+    valueList() const
+    {
+      return m_value_list;
+    }
+
+    DirichletBoundaryCondition(const Array<const FaceId>& face_list,
+                               const Array<const TinyVector<Dimension>>& value_list)
+      : m_value_list{value_list}, m_face_list{face_list}
+    {
+      Assert(m_value_list.size() == m_face_list.size());
+    }
+
+    ~DirichletBoundaryCondition() = default;
+  };
+
+  class NormalStrainBoundaryCondition
+  {
+   private:
+    const Array<const TinyVector<Dimension>> m_value_list;
+    const Array<const FaceId> m_face_list;
+
+   public:
+    const Array<const FaceId>&
+    faceList() const
+    {
+      return m_face_list;
+    }
+
+    const Array<const TinyVector<Dimension>>&
+    valueList() const
+    {
+      return m_value_list;
+    }
+
+    NormalStrainBoundaryCondition(const Array<const FaceId>& face_list,
+                                  const Array<const TinyVector<Dimension>>& value_list)
+      : m_value_list{value_list}, m_face_list{face_list}
+    {
+      Assert(m_value_list.size() == m_face_list.size());
+    }
+
+    ~NormalStrainBoundaryCondition() = default;
+  };
+
+  class SymmetryBoundaryCondition
+  {
+   private:
+    const Array<const TinyVector<Dimension>> m_value_list;
+    const Array<const FaceId> m_face_list;
+
+   public:
+    const Array<const FaceId>&
+    faceList() const
+    {
+      return m_face_list;
+    }
+
+   public:
+    SymmetryBoundaryCondition(const Array<const FaceId>& face_list) : m_face_list{face_list} {}
+
+    ~SymmetryBoundaryCondition() = default;
+  };
+
+ public:
+  std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
+  apply() const final
+  {
+    return {m_energy_delta, m_dual_solution};
+  }
+
+  // std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
+  // computeEnergyUpdate() const final
+  // {
+  //   m_scheme->computeEnergyUpdate(mesh, dual_lambdab, dual_mub, m_solution,
+  //                                 m_dual_solution,   // f,
+  //                                 bc_descriptor_list);
+
+  //   return {m_dual_solution, m_energy_delta};
+  // }
+
   // compute the fluxes
-  void   // std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>
-  computeEnergyUpdate(const std::shared_ptr<const MeshType>& mesh,
-                      // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& alpha,
-                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambdab,
-                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_mub,
-                      // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambda,
-                      // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_mu,
-                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& U,
-                      // const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& source,
-                      const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& dual_U,
-                      const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
+  // std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>
+  EnergyComputer(const std::shared_ptr<const MeshType>& mesh,
+                 // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& alpha,
+                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambdab,
+                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_mub,
+                 // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_lambda,
+                 // const std::shared_ptr<const DiscreteFunctionP0<Dimension, double>>& dual_mu,
+                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& U,
+                 const std::shared_ptr<const DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>& source,
+                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
   {
     // Assert(mesh == alpha->mesh());
     Assert(mesh == U->mesh());
     // Assert(dual_lambda->mesh() == dual_mu->mesh());
     // Assert(dual_lambdab->mesh() == dual_mu->mesh());
-    Assert(dual_mub->mesh() == dual_U->mesh());
-    Assert(dual_lambdab->mesh() == dual_U->mesh());
+    Assert(U->mesh() == source->mesh());
+    Assert(dual_lambdab->mesh() == dual_mub->mesh());
     if (DiamondDualMeshManager::instance().getDiamondDualMesh(mesh) != dual_mub->mesh()) {
       throw NormalError("diffusion coefficient is not defined on the dual mesh!");
     }
@@ -1216,8 +1484,8 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
         CellValue<const double> dual_mubj     = dual_mub->cellValues();
         CellValue<const double> dual_lambdabj = dual_lambdab->cellValues();
         // attention, fj not in this context
-        CellValue<const TinyVector<Dimension>> velocity      = U->cellValues();
-        CellValue<const TinyVector<Dimension>> dual_velocity = dual_U->cellValues();
+        CellValue<const TinyVector<Dimension>> velocity = U->cellValues();
+        // CellValue<const TinyVector<Dimension>> dual_velocity = dual_U->cellValues();
 
         const CellValue<const double> dual_Vj = diamond_mesh_data.Vj();
 
@@ -1371,7 +1639,40 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
           return computed_alpha_l;
         }();
 
-        const TinyMatrix<Dimension> I = identity;
+        const TinyMatrix<Dimension> I             = identity;
+        CellValue<const FaceId> dual_cell_face_id = [=]() {
+          CellValue<FaceId> computed_dual_cell_face_id{diamond_mesh->connectivity()};
+          FaceValue<FaceId> primal_face_id{mesh->connectivity()};
+          parallel_for(
+            mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) { primal_face_id[face_id] = face_id; });
+
+          mapper->toDualCell(primal_face_id, computed_dual_cell_face_id);
+
+          return computed_dual_cell_face_id;
+        }();
+
+        m_dual_solution      = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>(diamond_mesh);
+        m_solution           = std::make_shared<DiscreteFunctionP0<Dimension, TinyVector<Dimension>>>(mesh);
+        const auto& solution = *U;
+        auto& dual_solution  = *m_dual_solution;
+        dual_solution.fill(zero);
+        const auto& face_to_cell_matrix = mesh->connectivity().faceToCellMatrix();
+        for (CellId cell_id = 0; cell_id < diamond_mesh->numberOfCells(); ++cell_id) {
+          const FaceId face_id = dual_cell_face_id[cell_id];
+          CellId cell_id1      = face_to_cell_matrix[face_id][0];
+          if (primal_face_is_on_boundary[face_id]) {
+            for (size_t i = 0; i < Dimension; ++i) {
+              // A revoir!!
+              dual_solution[cell_id][i] = solution[cell_id1][i];
+            }
+          } else {
+            CellId cell_id1 = face_to_cell_matrix[face_id][0];
+            CellId cell_id2 = face_to_cell_matrix[face_id][1];
+            for (size_t i = 0; i < Dimension; ++i) {
+              dual_solution[cell_id][i] = 0.5 * (solution[cell_id1][i] + solution[cell_id2][i]);
+            }
+          }
+        }
 
         // const Array<int> non_zeros{number_of_dof * Dimension};
         // non_zeros.fill(Dimension * Dimension);
@@ -1407,9 +1708,9 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
             for (size_t j_cell = 0; j_cell < primal_face_to_cell.size(); ++j_cell) {
               const CellId j_id = primal_face_to_cell[j_cell];
               if (i_cell == j_cell) {
-                flux(face_id, i_cell) += dot(M * velocity[i_id], dual_velocity[face_dual_cell_id[face_id]]);
+                flux(face_id, i_cell) += dot(M * velocity[i_id], dual_solution[face_dual_cell_id[face_id]]);
               } else {
-                flux(face_id, i_cell) -= dot(M * velocity[j_id], dual_velocity[face_dual_cell_id[face_id]]);
+                flux(face_id, i_cell) -= dot(M * velocity[j_id], dual_solution[face_dual_cell_id[face_id]]);
               }
             }
           }
@@ -1455,7 +1756,7 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
                   for (size_t j_cell = 0; j_cell < primal_node_to_cell_matrix[node_id].size(); ++j_cell) {
                     CellId j_id = primal_node_to_cell_matrix[node_id][j_cell];
                     flux(face_id, i_face_cell) -=
-                      w_rj(node_id, j_cell) * dot(M * velocity[j_id], dual_velocity[dual_cell_id]);
+                      w_rj(node_id, j_cell) * dot(M * velocity[j_id], dual_solution[dual_cell_id]);
                   }
                   if (primal_node_is_on_boundary[node_id]) {
                     for (size_t l_face = 0; l_face < node_to_face_matrix[node_id].size(); ++l_face) {
@@ -1463,7 +1764,7 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
                       if (primal_face_is_on_boundary[l_id]) {
                         flux(face_id, i_face_cell) -=
                           w_rl(node_id, l_face) *
-                          dot(M * dual_velocity[face_dual_cell_id[l_id]], dual_velocity[dual_cell_id]);
+                          dot(M * dual_solution[face_dual_cell_id[l_id]], dual_solution[dual_cell_id]);
                       }
                     }
                   }
@@ -1483,7 +1784,7 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
         //         for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
         //           FaceId face_id      = face_list[i_face];
         //           CellId dual_cell_id = face_dual_cell_id[face_id];
-        //           flux(face_id, 0)    = mes_l[face_id] * dot(value_list[i_face], dual_velocity[dual_cell_id]);   //
+        //           flux(face_id, 0)    = mes_l[face_id] * dot(value_list[i_face], dual_solution[dual_cell_id]);   //
         //           sign
         //         }
         //       }
@@ -1530,14 +1831,18 @@ class VectorDiamondSchemeHandler::VectorDiamondScheme : public VectorDiamondSche
   }
 };
 
-std::tuple<std::shared_ptr<const IDiscreteFunction>,
-           std::shared_ptr<const IDiscreteFunction>,
-           std::shared_ptr<const IDiscreteFunction>>
+std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
 VectorDiamondSchemeHandler::apply() const
 {
   return m_scheme->apply();
 }
 
+std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>>
+EnergyComputerHandler::computeEnergyUpdate() const
+{
+  return m_energy_computer->apply();
+}
+
 std::shared_ptr<const IDiscreteFunction>
 VectorDiamondSchemeHandler::solution() const
 {
@@ -1637,3 +1942,83 @@ VectorDiamondSchemeHandler::VectorDiamondSchemeHandler(
 }
 
 VectorDiamondSchemeHandler::~VectorDiamondSchemeHandler() = default;
+
+EnergyComputerHandler::EnergyComputerHandler(
+  const std::shared_ptr<const IDiscreteFunction>& dual_lambdab,
+  const std::shared_ptr<const IDiscreteFunction>& dual_mub,
+  const std::shared_ptr<const IDiscreteFunction>& U,
+  const std::shared_ptr<const IDiscreteFunction>& source,
+  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
+{
+  const std::shared_ptr i_mesh      = getCommonMesh({U, source});
+  const std::shared_ptr i_dual_mesh = getCommonMesh({dual_lambdab, dual_mub});
+  checkDiscretizationType({dual_lambdab, dual_mub, U, source}, DiscreteFunctionType::P0);
+
+  switch (i_mesh->dimension()) {
+  case 1: {
+    using MeshType                   = Mesh<Connectivity<1>>;
+    using DiscreteScalarFunctionType = DiscreteFunctionP0<1, double>;
+    using DiscreteVectorFunctionType = DiscreteFunctionP0<1, TinyVector<1>>;
+
+    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
+
+    if (DiamondDualMeshManager::instance().getDiamondDualMesh(mesh) != dual_mub->mesh()) {
+      throw NormalError("mu_dual is not defined on the dual mesh");
+    }
+
+    m_energy_computer =
+      std::make_unique<EnergyComputer<1>>(mesh,
+                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambdab),
+                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
+                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(U),
+                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(source),
+                                          bc_descriptor_list);
+    break;
+  }
+  case 2: {
+    using MeshType                   = Mesh<Connectivity<2>>;
+    using DiscreteScalarFunctionType = DiscreteFunctionP0<2, double>;
+    using DiscreteVectorFunctionType = DiscreteFunctionP0<2, TinyVector<2>>;
+
+    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
+
+    if (DiamondDualMeshManager::instance().getDiamondDualMesh(mesh) != dual_mub->mesh()) {
+      throw NormalError("mu_dual is not defined on the dual mesh");
+    }
+
+    m_energy_computer =
+      std::make_unique<EnergyComputer<2>>(mesh,
+                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambdab),
+                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
+                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(U),
+                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(source),
+                                          bc_descriptor_list);
+    break;
+  }
+  case 3: {
+    using MeshType                   = Mesh<Connectivity<3>>;
+    using DiscreteScalarFunctionType = DiscreteFunctionP0<3, double>;
+    using DiscreteVectorFunctionType = DiscreteFunctionP0<3, TinyVector<3>>;
+
+    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
+
+    if (DiamondDualMeshManager::instance().getDiamondDualMesh(mesh) != dual_mub->mesh()) {
+      throw NormalError("mu_dual is not defined on the dual mesh");
+    }
+
+    m_energy_computer =
+      std::make_unique<EnergyComputer<3>>(mesh,
+                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_lambdab),
+                                          std::dynamic_pointer_cast<const DiscreteScalarFunctionType>(dual_mub),
+                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(U),
+                                          std::dynamic_pointer_cast<const DiscreteVectorFunctionType>(source),
+                                          bc_descriptor_list);
+    break;
+  }
+  default: {
+    throw UnexpectedError("invalid mesh dimension");
+  }
+  }
+}
+
+EnergyComputerHandler::~EnergyComputerHandler() = default;
diff --git a/src/scheme/VectorDiamondScheme.hpp b/src/scheme/VectorDiamondScheme.hpp
index e2c2e0fef484529ea858c98d9715efb6811be3ce..9b56d86dbb3942c2ba51459c1f18394a56913ba8 100644
--- a/src/scheme/VectorDiamondScheme.hpp
+++ b/src/scheme/VectorDiamondScheme.hpp
@@ -25,7 +25,6 @@ class VectorDiamondSchemeHandler
 {
  private:
   class IVectorDiamondScheme;
-
   template <size_t Dimension>
   class VectorDiamondScheme;
 
@@ -39,10 +38,7 @@ class VectorDiamondSchemeHandler
 
   std::shared_ptr<const IDiscreteFunction> dual_solution() const;
 
-  std::tuple<std::shared_ptr<const IDiscreteFunction>,
-             std::shared_ptr<const IDiscreteFunction>,
-             std::shared_ptr<const IDiscreteFunction>>
-  apply() const;
+  std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>> apply() const;
 
   VectorDiamondSchemeHandler(
     const std::shared_ptr<const IDiscreteFunction>& alphab,
@@ -56,6 +52,32 @@ class VectorDiamondSchemeHandler
   ~VectorDiamondSchemeHandler();
 };
 
+class EnergyComputerHandler
+{
+ private:
+  class IEnergyComputer;
+  template <size_t Dimension>
+  class EnergyComputer;
+
+  template <size_t Dimension>
+  class InterpolationWeightsManager;
+
+ public:
+  std::unique_ptr<IEnergyComputer> m_energy_computer;
+  std::shared_ptr<const IDiscreteFunction> dual_solution() const;
+
+  std::tuple<std::shared_ptr<const IDiscreteFunction>, std::shared_ptr<const IDiscreteFunction>> computeEnergyUpdate()
+    const;
+
+  EnergyComputerHandler(const std::shared_ptr<const IDiscreteFunction>& lambdab,
+                        const std::shared_ptr<const IDiscreteFunction>& mub,
+                        const std::shared_ptr<const IDiscreteFunction>& U,
+                        const std::shared_ptr<const IDiscreteFunction>& source,
+                        const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list);
+
+  ~EnergyComputerHandler();
+};
+
 template <size_t Dimension>
 class LegacyVectorDiamondScheme
 {