diff --git a/src/scheme/GKSNavier.cpp b/src/scheme/GKSNavier.cpp
index 1bfbdace68b597bdc777d87fb1e345083181ff99..29c4d8cf0d49025771b92c440d551283c062b351 100644
--- a/src/scheme/GKSNavier.cpp
+++ b/src/scheme/GKSNavier.cpp
@@ -31,7 +31,8 @@ gks_inv_dt(const std::shared_ptr<const DiscreteFunctionVariant>& c_v,
 
         CellValue<double> local_inv_dt{mesh.connectivity()};
         parallel_for(
-          mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { local_inv_dt[j] = (c[j] + abs(U[j])) / Vj[j]; });
+          mesh.numberOfCells(),
+          PUGS_LAMBDA(CellId cell_id) { local_inv_dt[cell_id] = (c[cell_id] + abs(U[cell_id])) / Vj[cell_id]; });
 
         return max(local_inv_dt);
       } else {
@@ -55,7 +56,217 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
   using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
   using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;
 
-  const double pi = std::acos(-1);
+  const double pi    = std::acos(-1);
+  const double sqrt2 = std::sqrt(2);
+
+  TinyMatrix<2 + Dimension>
+  _computeInvMatrixM(const double& rho, const double& U, const double& lambda, const double& delta) const
+  {
+    const double sqrt_rho    = std::sqrt(rho);
+    const double sqrt_lambda = std::sqrt(lambda);
+    const double U_2         = U * U;
+
+    TinyMatrix<2 + Dimension> invMatrixM(zero);
+    TinyMatrix<2 + Dimension> MatrixP(zero);
+    TinyMatrix<2 + Dimension> transposeMatrixP(zero);
+    TinyMatrix<2 + Dimension> invMatrixMtilde(zero);
+
+    invMatrixMtilde(1, 1) = 1;
+    invMatrixMtilde(2, 2) = 1;
+    invMatrixMtilde(3, 3) = 2 * sqrt2 * lambda / (std::sqrt(delta + 1) * sqrt_rho);
+
+    MatrixP(1, 1) = 1 / sqrt_rho;
+    MatrixP(2, 2) = sqrt2 * sqrt_lambda / sqrt_rho;
+    MatrixP(3, 3) = 1;
+    MatrixP(2, 1) = -sqrt2 * sqrt_lambda * U / sqrt_rho;
+    MatrixP(3, 1) = 0.5 * U_2 - 0.25 * (delta + 1) / lambda;
+    MatrixP(3, 2) = -U;
+
+    invMatrixM = MatrixP * invMatrixMtilde * transpose(MatrixP);
+
+    return invMatrixM;
+  }
+
+  TinyVector<2 + Dimension>
+  _compute_al(const TinyMatrix<2 + Dimension>& invMatrixM,
+              const DiscreteScalarFunction& rho,
+              const DiscreteVectorFunction& rhoU,
+              const DiscreteScalarFunction& rhoE)
+  {}
+
+  TinyVector<2 + Dimension>
+  _compute_ar(const TinyMatrix<2 + Dimension>& invMatrixM)
+  {}
+
+  TinyVector<2 + Dimension>
+  _compute_Al(const TinyMatrix<2 + Dimension>& invMatrixM)
+  {}
+
+  TinyVector<2 + Dimension>
+  _compute_Ar(const TinyMatrix<2 + Dimension>& invMatrixM)
+  {}
+
+  TinyVector<2 + Dimension>
+  _compute_bl(const TinyMatrix<2 + Dimension>& invMatrixM)
+  {}
+
+  TinyVector<2 + Dimension>
+  _compute_br(const TinyMatrix<2 + Dimension>& invMatrixM)
+  {}
+
+  TinyVector<2 + Dimension>
+  _compute_B(const TinyMatrix<2 + Dimension>& invMatrixM)
+  {}
+
+  TinyVector<7>
+  _compute_u_moments(const double& rho, const Rd& U, const double& lambda)
+  {
+    const double U_2      = U[0] * U[0];
+    const double U_4      = U_2 * U_2;
+    const double lambda_2 = lambda * lambda;
+
+    TinyVector<7> u_moments;
+    u_moments[0] = rho;
+    u_moments[1] = rho * U[0];
+    u_moments[2] = rho * (U_2 + 0.5 / lambda);
+    u_moments[3] = rho * U[0] * (U_2 + 1.5 / lambda);
+    u_moments[4] = rho * (U_4 + 3 * U_2 / lambda + 0.75 / lambda_2);
+    u_moments[5] = rho * U[0] * (U_4 + 5 * U_2 / lambda + 3.75 / lambda_2);
+    u_moments[6] = rho * (U_4 * U_2 + 7.5 * U_4 / lambda + 11.25 * U_2 / lambda_2 + 1.875 / (lambda_2 * lambda));
+
+    return u_moments;
+  }
+  TinyVector<2>
+  _compute_xi2_moments(const double& lambda, const double& delta)
+  {
+    TinyVector<2> xi2_moments;
+    xi2_moments[0] = delta * 0.5 / lambda;
+    xi2_moments[1] = 0.25 * (delta * delta + 2 * delta) / (lambda * lambda);
+
+    return xi2_moments;
+  }
+
+  TinyMatrix<2 + Dimension>
+  _computeMatrixC1(const double& rho,
+                   const Rd& U,
+                   const double& lambda,
+                   TinyVector<7> u_moments,
+                   TinyVector<2> xi2_moments)
+  {
+    const double U_2      = U[0] * U[0];
+    const double U_4      = U_2 * U_2;
+    const double lambda_2 = lambda * lambda;
+
+    const double erf_term = (1 + std::erf(std::sqrt(lambda) * U[0]));
+    const double exp_term = std::exp(-lambda * U_2) / std::sqrt(pi * lambda);
+
+    TinyMatrix<2 + Dimension> MatrixC1(zero);
+
+    MatrixC1(1, 1) = u_moments[2] * erf_term + rho * U * exp_term;
+    MatrixC1(2, 2) = u_moments[4] * erf_term + rho * U * (U_2 + 2.5 / lambda) * exp_term;
+    MatrixC1(3, 3) =
+      (u_moments[6] + 2 * xi2_moments[0] * u_moments[4] + xi2_moments[1] * u_moments[2]) * erf_term +
+      rho * (U * (U_4 + 7 * U_2 / lambda + 8.25 / lambda_2) + 2 * xi2_moments[0] * U * (U_2 + 2.5 / lambda)) * exp_term;
+    MatrixC1(3, 3) *= 0.25;
+
+    MatrixC1(1, 1) *= 0.25;
+    MatrixC1(2, 2) *= 0.25;
+    MatrixC1(3, 3) *= 0.25;
+
+    MatrixC1(2, 1) = u_moments[3] + rho * (U_2 + 1 / lambda) * exp_term;
+    MatrixC1(2, 1) *= 0.5;
+    MatrixC1(3, 1) = 0.5 * (u_moments[4] + u_moments[2] * xi2_moments[1]) * erf_term +
+                     rho * U[0] * ((U_2 + 2.5 / lambda) + xi2_moments[0]) * exp_term;
+    MatrixC1(3, 1) *= 0.5;
+    MatrixC1(3, 2) = (u_moments[5] + u_moments[3] * xi2_moments[0]) * erf_term +
+                     rho * (U_4 + 4.5 * U_2 / lambda + 2 / lambda_2 + (U_2 + 1 / lambda) * xi2_moments[0]) * exp_term;
+    MatrixC1(3, 2) *= 0.5;
+
+    return (MatrixC1 + transpose(MatrixC1));
+  }
+
+  TinyMatrix<2 + Dimension>
+  _computeMatrixC2(const double& rho,
+                   const Rd& U,
+                   const double& lambda,
+                   TinyVector<7> u_moments,
+                   TinyVector<2> xi2_moments)
+  {
+    const double U_2      = U[0] * U[0];
+    const double U_4      = U_2 * U_2;
+    const double lambda_2 = lambda * lambda;
+
+    const double erf_term = (1 - std::erf(std::sqrt(lambda) * U[0]));
+    const double exp_term = std::exp(-lambda * U_2) / std::sqrt(pi * lambda);
+
+    TinyMatrix<2 + Dimension> MatrixC2(zero);
+
+    MatrixC2(1, 1) = u_moments[2] * erf_term - rho * U * exp_term;
+    MatrixC2(2, 2) = u_moments[4] * erf_term - rho * U * (U_2 + 2.5 / lambda) * exp_term;
+    MatrixC2(3, 3) =
+      (u_moments[6] + 2 * xi2_moments[0] * u_moments[4] + xi2_moments[1] * u_moments[2]) * erf_term -
+      rho * (U * (U_4 + 7 * U_2 / lambda + 8.25 / lambda_2) + 2 * xi2_moments[0] * U * (U_2 + 2.5 / lambda)) * exp_term;
+    MatrixC2(3, 3) *= 0.25;
+
+    MatrixC2(1, 1) *= 0.25;
+    MatrixC2(2, 2) *= 0.25;
+    MatrixC2(3, 3) *= 0.25;
+
+    MatrixC2(2, 1) = u_moments[3] * erf_term - rho * (U_2 + 1 / lambda) * exp_term;
+    MatrixC2(2, 1) *= 0.5;
+    MatrixC2(3, 1) = 0.5 * (u_moments[4] + u_moments[2] * xi2_moments[1]) * erf_term -
+                     rho * U[0] * ((U_2 + 2.5 / lambda) + xi2_moments[0]) * exp_term;
+    MatrixC2(3, 1) *= 0.5;
+    MatrixC2(3, 2) = (u_moments[5] + u_moments[3] * xi2_moments[0]) * erf_term -
+                     rho * (U_4 + 4.5 * U_2 / lambda + 2 / lambda_2 + (U_2 + 1 / lambda) * xi2_moments[0]) * exp_term;
+    MatrixC2(3, 2) *= 0.5;
+
+    return (MatrixC2 + transpose(MatrixC2));
+  }
+
+  TinyMatrix<2 + Dimension>
+  _computeMatrixC3(const double& rho,
+                   const Rd& U,
+                   const double& lambda,
+                   TinyVector<7> u_moments,
+                   TinyVector<2> xi2_moments)
+  {
+    const double U_2      = U[0] * U[0];
+    const double U_4      = U_2 * U_2;
+    const double lambda_2 = lambda * lambda;
+
+    const double erf_term = (1 - std::erf(std::sqrt(lambda) * U[0]));
+    const double exp_term = std::exp(-lambda * U_2) / std::sqrt(pi * lambda);
+
+    TinyMatrix<2 + Dimension> MatrixC3(zero);
+
+    MatrixC3(1, 1) = u_moments[2] * erf_term - rho * U * exp_term;
+    MatrixC3(2, 2) = u_moments[4] * erf_term - rho * U * (U_2 + 2.5 / lambda) * exp_term;
+    MatrixC3(3, 3) =
+      (u_moments[6] + 2 * xi2_moments[0] * u_moments[4] + xi2_moments[1] * u_moments[2]) * erf_term -
+      rho * (U * (U_4 + 7 * U_2 / lambda + 8.25 / lambda_2) + 2 * xi2_moments[0] * U * (U_2 + 2.5 / lambda)) * exp_term;
+    MatrixC3(3, 3) *= 0.25;
+
+    MatrixC3(1, 1) *= 0.5;
+    MatrixC3(2, 2) *= 0.5;
+    MatrixC3(3, 3) *= 0.5;
+
+    MatrixC3(2, 1) = u_moments[3] * erf_term - rho * (U_2 + 1 / lambda) * exp_term;
+    MatrixC3(2, 1) *= 0.5;
+    MatrixC3(3, 1) = 0.5 * (u_moments[4] + u_moments[2] * xi2_moments[1]) * erf_term -
+                     rho * U[0] * ((U_2 + 2.5 / lambda) + xi2_moments[0]) * exp_term;
+    MatrixC3(3, 1) *= 0.5;
+    MatrixC3(3, 2) = (u_moments[5] + u_moments[3] * xi2_moments[0]) * erf_term -
+                     rho * (U_4 + 4.5 * U_2 / lambda + 2 / lambda_2 + (U_2 + 1 / lambda) * xi2_moments[0]) * exp_term;
+    MatrixC3(3, 2) *= 0.5;
+
+    // Divide by 2 the diagonal because there are double terms on it after the addition of the matrix and its transpose
+    MatrixC3(1, 1) *= 0.5;
+    MatrixC3(2, 2) *= 0.5;
+    MatrixC3(3, 3) *= 0.5;
+
+    return (MatrixC3 + transpose(MatrixC3));
+  }
 
  public:
   std::tuple<const std::shared_ptr<const ItemValueVariant>,
@@ -65,7 +276,8 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
                  const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
                  const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
                  const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
-                 const double& delta) const
+                 const double& delta,
+                 const double& dt) const
   {
     auto mesh_v          = getCommonMesh({rho_v, rhoU_v, rhoE_v});
     const MeshType& mesh = *mesh_v->get<MeshType>();
@@ -76,7 +288,7 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
       if (tau_n[cell_id] == 0)
         eta[cell_id] = 0;
       else
-        eta[cell_id] = tau_n[cell_id];   //(tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
+        eta[cell_id] = (tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
     }
 
     DiscreteScalarFunction rho_n  = rho_v->get<DiscreteScalarFunction>();
@@ -97,39 +309,39 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
     NodeValuePerCell<Rd> rho_flux_G(mesh.connectivity());
     NodeValuePerCell<Rd> rhoU_flux_G(mesh.connectivity());
     NodeValuePerCell<Rd> rhoE_flux_G(mesh.connectivity());
-    rho_flux_G.fill(Rd());
-    rhoU_flux_G.fill(Rd());
-    rhoE_flux_G.fill(Rd());
+    rho_flux_G.fill(zero);
+    rhoU_flux_G.fill(zero);
+    rhoE_flux_G.fill(zero);
 
     NodeValuePerCell<Rd> rho_flux_Ffn(mesh.connectivity());
     NodeValuePerCell<Rd> rhoU_flux_Ffn(mesh.connectivity());
     NodeValuePerCell<Rd> rhoE_flux_Ffn(mesh.connectivity());
-    rho_flux_Ffn.fill(Rd());
-    rhoU_flux_Ffn.fill(Rd());
-    rhoE_flux_Ffn.fill(Rd());
+    rho_flux_Ffn.fill(zero);
+    rhoU_flux_Ffn.fill(zero);
+    rhoE_flux_Ffn.fill(zero);
 
     CellValue<double> rho_fluxes(mesh.connectivity());
     CellValue<Rd> rhoU_fluxes(mesh.connectivity());
     CellValue<double> rhoE_fluxes(mesh.connectivity());
     rho_fluxes.fill(0);
-    rhoU_fluxes.fill(Rd());
+    rhoU_fluxes.fill(zero);
     rhoE_fluxes.fill(0);
 
     NodeValue<Rd> F2fn(mesh.connectivity());
     NodeValue<Rd> F3fn(mesh.connectivity());
-    F2fn.fill(Rd());
-    F3fn.fill(Rd());
+    F2fn.fill(zero);
+    F3fn.fill(zero);
 
     NodeValue<double> rho_node(mesh.connectivity());
     NodeValue<Rd> rhoU_node(mesh.connectivity());
     NodeValue<double> rhoE_node(mesh.connectivity());
     rho_node.fill(0);
-    rhoU_node.fill(Rd());
+    rhoU_node.fill(zero);
     rhoE_node.fill(0);
     CellValue<double> lambda{mesh.connectivity()};
-    lambda.fill(0);
+    lambda.fill(1);
     CellValue<Rd> U{mesh.connectivity()};
-    U.fill(Rd());
+    U.fill(zero);
 
     CellValue<double> err_function_term(mesh.connectivity());
     CellValue<double> exp_term(mesh.connectivity());
@@ -142,19 +354,10 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
       });
 
     parallel_for(
-      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
-        const auto& node_cells = node_to_cell_matrix[node_id];
-        for (size_t l = 0; l < node_cells.size(); l++) {
-          if (node_cells.size() == 1)
-            continue;
-
-          if (l == 0) {
-            double U_2_left                  = U[node_cells[l]][0] * U[node_cells[l]][0];
-            err_function_term[node_cells[l]] = std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]);
-            exp_term[node_cells[l]] =
-              std::exp(-lambda[node_cells[l]] * U_2_left) / std::sqrt(pi * lambda[node_cells[l]]);
-          }
-        }
+      mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
+        double U_2                 = U[cell_id][0] * U[cell_id][0];
+        err_function_term[cell_id] = std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0]);
+        exp_term[cell_id]          = std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);
       });
 
     parallel_for(
@@ -258,7 +461,7 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
         }
       });
 
-    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
+    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
       const size_t& nb_nodes = rho_flux_G.numberOfSubValues(cell_id);
       for (size_t r = 0; r < nb_nodes; r++) {
         rho_fluxes[cell_id] += rho_flux_G(cell_id, r)[0];
@@ -297,7 +500,7 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
     auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);
     auto Vj         = mesh_data.Vj();
 
-    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
+    for (CellId cell_id = 2; cell_id < mesh.numberOfCells() - 2; ++cell_id) {
       new_rho[cell_id] -= dt / Vj[cell_id] * rho_fluxes[cell_id];
       new_rhoU[cell_id] -= dt / Vj[cell_id] * rhoU_fluxes[cell_id];
       new_rhoE[cell_id] -= dt / Vj[cell_id] * rhoE_fluxes[cell_id];
@@ -339,8 +542,8 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
   }
 
   std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
-             std::shared_ptr<const DiscreteFunctionVariant>,   // U
-             std::shared_ptr<const DiscreteFunctionVariant>>   // E
+             std::shared_ptr<const DiscreteFunctionVariant>,   // rho U
+             std::shared_ptr<const DiscreteFunctionVariant>>   // rho E
   gksNavier(const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
             const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
             const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
@@ -350,7 +553,7 @@ class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
   {
     std::shared_ptr mesh_v = getCommonMesh({rho_v, rhoU_v, rhoE_v});
 
-    auto [rho_fluxes, rhoU_fluxes, rhoE_fluxes] = compute_fluxes(rho_v, rhoU_v, rhoE_v, tau_v, delta);
+    auto [rho_fluxes, rhoU_fluxes, rhoE_fluxes] = compute_fluxes(rho_v, rhoU_v, rhoE_v, tau_v, delta, dt);
     return apply_fluxes(rho_v, rhoU_v, rhoE_v, rho_fluxes, rhoU_fluxes, rhoE_fluxes, dt);
   }
 
diff --git a/src/scheme/GKSNavier.hpp b/src/scheme/GKSNavier.hpp
index 06103195a1582116b467bacf4978bd8f4fd51122..571f87d8391e9a46c2eece2d95281659122446c3 100644
--- a/src/scheme/GKSNavier.hpp
+++ b/src/scheme/GKSNavier.hpp
@@ -36,7 +36,8 @@ class GKSHandler
                    const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
                    const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
                    const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
-                   const double& delta) const = 0;
+                   const double& delta,
+                   const double& dt) const = 0;
 
     virtual std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
                        std::shared_ptr<const DiscreteFunctionVariant>,