From 3eaf95a1445bf397777762f9d01b988f20437acf Mon Sep 17 00:00:00 2001
From: Clovis <clovis.schoeck@etudiant.univ-rennes.fr>
Date: Tue, 16 Jul 2024 16:26:09 +0200
Subject: [PATCH] Wrting a parallel version of the GKS first order in GKS2.cpp

---
 src/language/modules/SchemeModule.cpp |  32 +++
 src/scheme/CMakeLists.txt             |   1 +
 src/scheme/GKS.cpp                    |  34 ++--
 src/scheme/GKS2.cpp                   | 281 ++++++++++++++++++++++++++
 src/scheme/GKS2.hpp                   |  18 ++
 src/scheme/GKSNavier.cpp              | 281 ++++++++++++++++++++++++++
 src/scheme/GKSNavier.hpp              |  18 ++
 7 files changed, 643 insertions(+), 22 deletions(-)
 create mode 100644 src/scheme/GKS2.cpp
 create mode 100644 src/scheme/GKS2.hpp
 create mode 100644 src/scheme/GKSNavier.cpp
 create mode 100644 src/scheme/GKSNavier.hpp

diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp
index 4b9060790..46c417d61 100644
--- a/src/language/modules/SchemeModule.cpp
+++ b/src/language/modules/SchemeModule.cpp
@@ -36,6 +36,8 @@
 #include <scheme/FourierBoundaryConditionDescriptor.hpp>
 #include <scheme/FreeBoundaryConditionDescriptor.hpp>
 #include <scheme/GKS.hpp>
+#include <scheme/GKS2.hpp>
+#include <scheme/GKSNavier.hpp>
 #include <scheme/HyperelasticSolver.hpp>
 #include <scheme/IBoundaryConditionDescriptor.hpp>
 #include <scheme/IDiscreteFunctionDescriptor.hpp>
@@ -436,6 +438,36 @@ SchemeModule::SchemeModule()
 
                               ));
 
+  this->_addBuiltinFunction("gks2",
+                            std::function(
+
+                              [](const std::shared_ptr<const DiscreteFunctionVariant>& rho,
+                                 const std::shared_ptr<const DiscreteFunctionVariant>& rho_U,
+                                 const std::shared_ptr<const DiscreteFunctionVariant>& rho_E,
+                                 const std::shared_ptr<const DiscreteFunctionVariant>& tau, const double& delta,
+                                 const double& dt) -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
+                                                                 std::shared_ptr<const DiscreteFunctionVariant>,
+                                                                 std::shared_ptr<const DiscreteFunctionVariant>> {
+                                return gks2(rho, rho_U, rho_E, tau, delta, dt);
+                              }
+
+                              ));
+
+  this->_addBuiltinFunction("gksNavier",
+                            std::function(
+
+                              [](const std::shared_ptr<const DiscreteFunctionVariant>& rho,
+                                 const std::shared_ptr<const DiscreteFunctionVariant>& rho_U,
+                                 const std::shared_ptr<const DiscreteFunctionVariant>& rho_E,
+                                 const std::shared_ptr<const DiscreteFunctionVariant>& tau, const double& delta,
+                                 const double& dt) -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
+                                                                 std::shared_ptr<const DiscreteFunctionVariant>,
+                                                                 std::shared_ptr<const DiscreteFunctionVariant>> {
+                                return gksNavier(rho, rho_U, rho_E, tau, delta, dt);
+                              }
+
+                              ));
+
   this->_addBuiltinFunction("glace_solver",
                             std::function(
 
diff --git a/src/scheme/CMakeLists.txt b/src/scheme/CMakeLists.txt
index f3aa4dc1b..019ba8ec4 100644
--- a/src/scheme/CMakeLists.txt
+++ b/src/scheme/CMakeLists.txt
@@ -12,6 +12,7 @@ add_library(
   FluxingAdvectionSolver.cpp
   HyperelasticSolver.cpp
   GKS.cpp
+  GKS2.cpp
   )
 
 target_link_libraries(
diff --git a/src/scheme/GKS.cpp b/src/scheme/GKS.cpp
index fed7b1ae0..dad7b84a7 100644
--- a/src/scheme/GKS.cpp
+++ b/src/scheme/GKS.cpp
@@ -33,8 +33,9 @@ class GKS
       if (tau_n[cell_id] == 0)
         eta[cell_id] = 0;
       else
-        eta[cell_id] = (tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
+        eta[cell_id] = tau_n[cell_id];   //(tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
     }
+    // std::cout << "eta = " << eta << std::endl;
 
     DiscreteFunctionP0<const double> rho_n   = rho_v->get<DiscreteFunctionP0<const double>>();
     DiscreteFunctionP0<const Rd> rho_U_n     = rho_U_v->get<DiscreteFunctionP0<const Rd>>();
@@ -44,10 +45,6 @@ class GKS
     DiscreteFunctionP0<Rd> rho_U     = copy(rho_U_n);
     DiscreteFunctionP0<double> rho_E = copy(rho_E_n);
 
-    CellValue<double> lambda{p_mesh->connectivity()};
-    // lambda.fill(0);
-    CellValue<Rd> U{p_mesh->connectivity()};
-    // U.fill(0);
     auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);
     auto Vj         = mesh_data.Vj();
 
@@ -73,6 +70,10 @@ class GKS
     rho_node.fill(0);
     rho_U_node.fill(TinyVector<1>(0));
     rho_E_node.fill(0);
+    CellValue<double> lambda{p_mesh->connectivity()};
+    // lambda.fill(0);
+    CellValue<Rd> U{p_mesh->connectivity()};
+    // U.fill(0);
 
     for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
       U[cell_id][0]   = rho_U_n[cell_id][0] / rho_n[cell_id];
@@ -100,7 +101,7 @@ class GKS
       rho_E_node[node_list[1]]    = 0.5 * rho_E_cell_left;
     }
 
-    for (CellId cell_id = 1; cell_id < mesh.numberOfCells(); ++cell_id) {
+    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
       double U_2 = U[cell_id][0] * U[cell_id][0];
 
       double rho_cell_right = rho_n[cell_id] * (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0]));
@@ -120,7 +121,7 @@ class GKS
       rho_E_node[node_list[0]] += 0.5 * rho_E_cell_right;
     }
 
-    for (CellId cell_id = 1; cell_id < mesh.numberOfCells(); ++cell_id) {
+    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
       auto node_list      = cell_to_node_matrix[cell_id];
       double rho_U_2_node = rho_U_node[node_list[0]][0] * rho_U_node[node_list[0]][0] / rho_node[node_list[0]];
 
@@ -169,23 +170,13 @@ class GKS
                              (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) -
                            0.5 * rho_n[cell_id] * (U_2 + 0.5 * (delta + 2) / lambda[cell_id]) *
                              std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);
-
-      // double F3_fn_right = U[cell_id][0] * (rho_E_n[cell_id] + 0.5 * rho_n[cell_id] / lambda[cell_id]) *
-      //                        (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) -
-      //                      0.5 * (rho_E_n[cell_id] + 0.25 * rho_n[cell_id] / lambda[cell_id]) *
-      //                        std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);
-
       auto node_list = cell_to_node_matrix[cell_id];
 
       rho_U_flux_Navier[node_list[0]][0] += 0.5 * F2_fn_right[0];
       rho_E_flux_Navier[node_list[0]] += 0.5 * F3_fn_right;
     }
-    // std::cout << "lambda " << lambda << std::endl;
-    // std::cout << "rho flux " << rho_flux_Euler << std::endl;
-    // std::cout << "rhoU flux " << rho_U_flux_Euler << std::endl;
-    // std::cout << "rhoE flux " << rho_E_flux_Euler << std::endl;
-    // std::exit(0);
-    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
+
+    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
       auto node_list = cell_to_node_matrix[cell_id];
 
       const double rho_flux_Euler_sum   = (rho_flux_Euler[node_list[1]] - rho_flux_Euler[node_list[0]]);
@@ -194,13 +185,12 @@ class GKS
 
       const Rd rho_U_flux_Navier_sum     = (rho_U_flux_Navier[node_list[1]] - rho_U_flux_Navier[node_list[0]]);
       const double rho_E_flux_Navier_sum = (rho_E_flux_Navier[node_list[1]] - rho_E_flux_Navier[node_list[0]]);
-
       rho[cell_id] -= dt / Vj[cell_id] * (rho_flux_Euler_sum);
       rho_U[cell_id][0] -=
-        dt / Vj[cell_id] *
+        0 * dt / Vj[cell_id] *
         (rho_U_flux_Euler_sum[0] + eta[cell_id] * (rho_U_flux_Navier_sum[0] - rho_U_flux_Euler_sum[0]));
       rho_E[cell_id] -=
-        dt / Vj[cell_id] * (rho_E_flux_Euler_sum + 0 * eta[cell_id] * (rho_E_flux_Navier_sum - rho_E_flux_Euler_sum));
+        0 * dt / Vj[cell_id] * (rho_E_flux_Euler_sum + eta[cell_id] * (rho_E_flux_Navier_sum - rho_E_flux_Euler_sum));
     }
     return std::make_tuple(std::make_shared<DiscreteFunctionVariant>(rho),
                            std::make_shared<DiscreteFunctionVariant>(rho_U),
diff --git a/src/scheme/GKS2.cpp b/src/scheme/GKS2.cpp
new file mode 100644
index 000000000..54521fc80
--- /dev/null
+++ b/src/scheme/GKS2.cpp
@@ -0,0 +1,281 @@
+
+#include <scheme/GKS.hpp>
+
+#include <mesh/Mesh.hpp>
+#include <mesh/MeshData.hpp>
+#include <mesh/MeshDataManager.hpp>
+#include <mesh/MeshTraits.hpp>
+#include <scheme/DiscreteFunctionUtils.hpp>
+
+template <MeshConcept MeshType>
+class GKS2
+{
+ public:
+  std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>>
+  solve(std::shared_ptr<const MeshType> p_mesh,
+        std::shared_ptr<const DiscreteFunctionVariant> rho_v,
+        std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
+        std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
+        std::shared_ptr<const DiscreteFunctionVariant> tau,
+        const double delta,
+        double dt)
+  {
+    using Rd = TinyVector<MeshType::Dimension>;
+
+    const MeshType& mesh = *p_mesh;
+
+    const double pi = std::acos(-1);
+
+    DiscreteFunctionP0<const double> tau_n = tau->get<DiscreteFunctionP0<const double>>();
+    CellValue<double> eta(mesh.connectivity());
+    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
+      if (tau_n[cell_id] == 0)
+        eta[cell_id] = 0;
+      else
+        eta[cell_id] = tau_n[cell_id];   //(tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
+    }
+    // std::cout << "eta = " << eta << std::endl;
+
+    DiscreteFunctionP0<const double> rho_n   = rho_v->get<DiscreteFunctionP0<const double>>();
+    DiscreteFunctionP0<const Rd> rho_U_n     = rho_U_v->get<DiscreteFunctionP0<const Rd>>();
+    DiscreteFunctionP0<const double> rho_E_n = rho_E_v->get<DiscreteFunctionP0<const double>>();
+
+    DiscreteFunctionP0<double> rho   = copy(rho_n);
+    DiscreteFunctionP0<Rd> rho_U     = copy(rho_U_n);
+    DiscreteFunctionP0<double> rho_E = copy(rho_E_n);
+
+    auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);
+    auto Vj         = mesh_data.Vj();
+
+    auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
+    auto node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();
+
+    const NodeValuePerCell<const Rd> njr = mesh_data.njr();
+
+    NodeValuePerCell<Rd> rho_flux_G(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_U_flux_G(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_E_flux_G(mesh.connectivity());
+    rho_flux_G.fill(TinyVector<1>(0));
+    rho_U_flux_G.fill(TinyVector<1>(0));
+    rho_E_flux_G.fill(TinyVector<1>(0));
+
+    NodeValuePerCell<Rd> rho_flux_Ffn(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_U_flux_Ffn(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_E_flux_Ffn(mesh.connectivity());
+    rho_flux_Ffn.fill(TinyVector<1>(0));
+    rho_U_flux_Ffn.fill(TinyVector<1>(0));
+    rho_E_flux_Ffn.fill(TinyVector<1>(0));
+
+    NodeValue<Rd> F2fn(mesh.connectivity());
+    NodeValue<Rd> F3fn(mesh.connectivity());
+    F2fn.fill(TinyVector<1>(0));
+    F3fn.fill(TinyVector<1>(0));
+
+    NodeValue<double> rho_node(mesh.connectivity());
+    NodeValue<Rd> rho_U_node(mesh.connectivity());
+    NodeValue<double> rho_E_node(mesh.connectivity());
+    rho_node.fill(0);
+    rho_U_node.fill(TinyVector<1>(0));
+    rho_E_node.fill(0);
+    CellValue<double> lambda{p_mesh->connectivity()};
+    // lambda.fill(0);
+    CellValue<Rd> U{p_mesh->connectivity()};
+    // U.fill(TinyVector<1>(0));
+
+    parallel_for(
+      mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
+        U[cell_id][0]   = rho_U_n[cell_id][0] / rho_n[cell_id];
+        double rho_U_2  = rho_U_n[cell_id][0] * U[cell_id][0];
+        lambda[cell_id] = 0.5 * (1. + delta) * rho_n[cell_id] / (2 * rho_E_n[cell_id] - rho_U_2);
+      });
+
+    parallel_for(
+      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
+        const auto& node_cells = node_to_cell_matrix[node_id];
+
+        double rho_cell_left    = 1;
+        Rd rho_U_cell_left      = TinyVector<1>(1);
+        double rho_E_cell_left  = 1;
+        double rho_cell_right   = 1;
+        Rd rho_U_cell_right     = TinyVector<1>(1);
+        double rho_E_cell_right = 1;
+
+        for (size_t l = 0; l < node_cells.size(); l++) {
+          if (node_cells.size() == 1)
+            continue;
+
+          if (l == 0) {
+            double U_2_left = U[node_cells[l]][0] * U[node_cells[l]][0];
+
+            rho_cell_left =
+              rho_n[node_cells[l]] * (1 + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]));
+
+            rho_U_cell_left[0] =
+              rho_U_n[node_cells[l]][0] * (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+              rho_n[node_cells[l]] * std::exp(-lambda[node_cells[l]] * U_2_left) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+
+            rho_E_cell_left =
+              rho_E_n[node_cells[l]] * (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+              0.5 * rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_left) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+          } else {
+            double U_2_right = U[node_cells[l]][0] * U[node_cells[l]][0];
+
+            rho_cell_right =
+              rho_n[node_cells[l]] * (1 - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]));
+
+            rho_U_cell_right[0] =
+              rho_U_n[node_cells[l]][0] * (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+              rho_n[node_cells[l]] * std::exp(-lambda[node_cells[l]] * U_2_right) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+
+            rho_E_cell_right =
+              rho_E_n[node_cells[l]] * (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+              0.5 * rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_right) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+          }
+        }
+        rho_node[node_id]   = 0.5 * (rho_cell_left + rho_cell_right);
+        rho_U_node[node_id] = 0.5 * (rho_U_cell_left + rho_U_cell_right);
+        rho_E_node[node_id] = 0.5 * (rho_E_cell_left + rho_E_cell_right);
+      });
+
+    parallel_for(
+      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
+        const auto& node_cells = node_to_cell_matrix[node_id];
+
+        Rd F2_fn_left  = TinyVector<1>(0);
+        Rd F3_fn_left  = TinyVector<1>(0);
+        Rd F2_fn_right = TinyVector<1>(0);
+        Rd F3_fn_right = TinyVector<1>(0);
+
+        for (size_t l = 0; l < node_cells.size(); l++) {
+          if (node_cells.size() == 1)
+            continue;
+
+          if (l == 0) {
+            double U_2_left     = U[node_cells[l]][0] * U[node_cells[l]][0];
+            double rho_U_2_left = rho_U_n[node_cells[l]][0] * U[node_cells[l]][0];
+
+            F2_fn_left[0] = (rho_U_2_left + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
+                              (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+                            rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_left) /
+                              std::sqrt(pi * lambda[node_cells[l]]);
+
+            F3_fn_left[0] = 0.5 * rho_U_n[node_cells[l]][0] * (U_2_left + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
+                              (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+                            0.5 * rho_n[node_cells[l]] * (U_2_left + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
+                              std::exp(-lambda[node_cells[l]] * U_2_left) / std::sqrt(pi * lambda[node_cells[l]]);
+          } else {
+            double U_2_right     = U[node_cells[l]][0] * U[node_cells[l]][0];
+            double rho_U_2_right = rho_U_n[node_cells[l]][0] * U[node_cells[l]][0];
+
+            F2_fn_right[0] = (rho_U_2_right + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
+                               (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+                             rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_right) /
+                               std::sqrt(pi * lambda[node_cells[l]]);
+
+            F3_fn_right[0] = 0.5 * rho_U_n[node_cells[l]][0] * (U_2_right + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
+                               (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+                             0.5 * rho_n[node_cells[l]] * (U_2_right + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
+                               std::exp(-lambda[node_cells[l]] * U_2_right) / std::sqrt(pi * lambda[node_cells[l]]);
+          }
+        }
+
+        F2fn[node_id] = 0.5 * (F2_fn_left + F2_fn_right);
+        F3fn[node_id] = 0.5 * (F3_fn_left + F3_fn_right);
+      });
+
+    parallel_for(
+      mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
+        const auto& cell_nodes = cell_to_node_matrix[cell_id];
+
+        for (size_t r = 0; r < cell_nodes.size(); r++) {
+          double rho_U_2_node = rho_U_node[cell_nodes[r]][0] * rho_U_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]];
+
+          rho_flux_G(cell_id, r) = rho_U_node[cell_nodes[r]][0] * njr(cell_id, r);
+          rho_U_flux_G(cell_id, r) =
+            (delta * rho_U_2_node / (1. + delta) + 2 * rho_E_node[cell_nodes[r]] / (1. + delta)) * njr(cell_id, r);
+          rho_E_flux_G(cell_id, r) =
+            rho_U_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]] *
+            ((3. + delta) * rho_E_node[cell_nodes[r]] / (1. + delta) - rho_U_2_node / (1. + delta)) * njr(cell_id, r);
+
+          rho_U_flux_Ffn(cell_id, r) = F2fn[cell_nodes[r]][0] * njr(cell_id, r);
+          rho_E_flux_Ffn(cell_id, r) = F3fn[cell_nodes[r]][0] * njr(cell_id, r);
+        }
+      });
+
+    // // for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
+    // // std::cout << "flux_E_Ffn : " << rho_E_flux_Ffn(cell_id, 0) << std::endl;
+    // // }
+
+    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
+      const size_t& nb_nodes = rho_flux_G.numberOfSubValues(cell_id);
+      for (size_t r = 0; r < nb_nodes; r++) {
+        rho[cell_id] -= dt / Vj[cell_id] * rho_flux_G(cell_id, r)[0];
+        // rho_U[cell_id] -= dt / Vj[cell_id] *
+        //                   ((1 - eta[cell_id]) * rho_U_flux_G(cell_id, r) + eta[cell_id] *
+        // (rho_U_flux_Ffn(cell_id,
+        //                   r)));
+        // rho_E[cell_id] -=
+        //   dt / Vj[cell_id] *
+        //   ((1 - eta[cell_id]) * rho_E_flux_G(cell_id, r)[0] + eta[cell_id] * (rho_E_flux_Ffn(cell_id, r)[0]));
+        rho_U[cell_id] -=
+          dt / Vj[cell_id] *
+          (rho_U_flux_G(cell_id, r) + eta[cell_id] * (rho_U_flux_Ffn(cell_id, r) - rho_U_flux_G(cell_id, r)));
+        rho_E[cell_id] -=
+          dt / Vj[cell_id] *
+          (rho_E_flux_G(cell_id, r)[0] + eta[cell_id] * (rho_E_flux_Ffn(cell_id, r)[0] - rho_E_flux_G(cell_id, r)[0]));
+      }
+    }
+
+    return std::make_tuple(std::make_shared<DiscreteFunctionVariant>(rho),
+                           std::make_shared<DiscreteFunctionVariant>(rho_U),
+                           std::make_shared<DiscreteFunctionVariant>(rho_E));
+  }
+
+  GKS2() = default;
+};
+
+std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
+           std::shared_ptr<const DiscreteFunctionVariant>,   // U
+           std::shared_ptr<const DiscreteFunctionVariant>>   // E
+gks2(std::shared_ptr<const DiscreteFunctionVariant> rho_v,
+     std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
+     std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
+     std::shared_ptr<const DiscreteFunctionVariant> tau,
+     const double delta,
+     const double dt)
+{
+  std::shared_ptr mesh_v = getCommonMesh({rho_v, rho_U_v, rho_E_v});
+  if (not mesh_v) {
+    throw NormalError("discrete functions are not defined on the same mesh");
+  }
+
+  if (not checkDiscretizationType({rho_v, rho_U_v, rho_E_v}, DiscreteFunctionType::P0)) {
+    throw NormalError("GKS solver expects P0 functions");
+  }
+
+  return std::visit(
+    [&](auto&& p_mesh)
+      -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>, std::shared_ptr<const DiscreteFunctionVariant>,
+                    std::shared_ptr<const DiscreteFunctionVariant>> {
+      using MeshType = std::decay_t<decltype(*p_mesh)>;
+      if constexpr (is_polygonal_mesh_v<MeshType>) {
+        if constexpr (MeshType::Dimension == 1) {
+          GKS2<MeshType> gks2;
+          return gks2.solve(p_mesh, rho_v, rho_U_v, rho_E_v, tau, delta, dt);
+
+        } else {
+          throw NormalError("dimension not treated");
+        }
+
+      } else {
+        throw NormalError("unexpected mesh type");
+      }
+    },
+    mesh_v->variant());
+}
diff --git a/src/scheme/GKS2.hpp b/src/scheme/GKS2.hpp
new file mode 100644
index 000000000..8e4b25f95
--- /dev/null
+++ b/src/scheme/GKS2.hpp
@@ -0,0 +1,18 @@
+#ifndef GKS2_HPP
+#define GKS2_HPP
+
+#include <scheme/DiscreteFunctionVariant.hpp>
+
+#include <tuple>
+
+std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
+           std::shared_ptr<const DiscreteFunctionVariant>,   // rhoU
+           std::shared_ptr<const DiscreteFunctionVariant>>   // rhoE
+gks2(std::shared_ptr<const DiscreteFunctionVariant> rho,
+     std::shared_ptr<const DiscreteFunctionVariant> rhoU,
+     std::shared_ptr<const DiscreteFunctionVariant> rhoE,
+     std::shared_ptr<const DiscreteFunctionVariant> tau,
+     const double delta,
+     const double dt);
+
+#endif   // GKS2_HPP
diff --git a/src/scheme/GKSNavier.cpp b/src/scheme/GKSNavier.cpp
new file mode 100644
index 000000000..92eeff2c2
--- /dev/null
+++ b/src/scheme/GKSNavier.cpp
@@ -0,0 +1,281 @@
+
+#include <scheme/GKS.hpp>
+
+#include <mesh/Mesh.hpp>
+#include <mesh/MeshData.hpp>
+#include <mesh/MeshDataManager.hpp>
+#include <mesh/MeshTraits.hpp>
+#include <scheme/DiscreteFunctionUtils.hpp>
+
+template <MeshConcept MeshType>
+class GKSNAVIER
+{
+ public:
+  std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>>
+  solve(std::shared_ptr<const MeshType> p_mesh,
+        std::shared_ptr<const DiscreteFunctionVariant> rho_v,
+        std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
+        std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
+        std::shared_ptr<const DiscreteFunctionVariant> tau,
+        const double delta,
+        double dt)
+  {
+    using Rd = TinyVector<MeshType::Dimension>;
+
+    const MeshType& mesh = *p_mesh;
+
+    const double pi = std::acos(-1);
+
+    DiscreteFunctionP0<const double> tau_n = tau->get<DiscreteFunctionP0<const double>>();
+    CellValue<double> eta(mesh.connectivity());
+    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
+      if (tau_n[cell_id] == 0)
+        eta[cell_id] = 0;
+      else
+        eta[cell_id] = tau_n[cell_id];   //(tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
+    }
+    // std::cout << "eta = " << eta << std::endl;
+
+    DiscreteFunctionP0<const double> rho_n   = rho_v->get<DiscreteFunctionP0<const double>>();
+    DiscreteFunctionP0<const Rd> rho_U_n     = rho_U_v->get<DiscreteFunctionP0<const Rd>>();
+    DiscreteFunctionP0<const double> rho_E_n = rho_E_v->get<DiscreteFunctionP0<const double>>();
+
+    DiscreteFunctionP0<double> rho   = copy(rho_n);
+    DiscreteFunctionP0<Rd> rho_U     = copy(rho_U_n);
+    DiscreteFunctionP0<double> rho_E = copy(rho_E_n);
+
+    auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);
+    auto Vj         = mesh_data.Vj();
+
+    auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
+    auto node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();
+
+    const NodeValuePerCell<const Rd> njr = mesh_data.njr();
+
+    NodeValuePerCell<Rd> rho_flux_G(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_U_flux_G(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_E_flux_G(mesh.connectivity());
+    rho_flux_G.fill(TinyVector<1>(0));
+    rho_U_flux_G.fill(TinyVector<1>(0));
+    rho_E_flux_G.fill(TinyVector<1>(0));
+
+    NodeValuePerCell<Rd> rho_flux_Ffn(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_U_flux_Ffn(mesh.connectivity());
+    NodeValuePerCell<Rd> rho_E_flux_Ffn(mesh.connectivity());
+    rho_flux_Ffn.fill(TinyVector<1>(0));
+    rho_U_flux_Ffn.fill(TinyVector<1>(0));
+    rho_E_flux_Ffn.fill(TinyVector<1>(0));
+
+    NodeValue<Rd> F2fn(mesh.connectivity());
+    NodeValue<Rd> F3fn(mesh.connectivity());
+    F2fn.fill(TinyVector<1>(0));
+    F3fn.fill(TinyVector<1>(0));
+
+    NodeValue<double> rho_node(mesh.connectivity());
+    NodeValue<Rd> rho_U_node(mesh.connectivity());
+    NodeValue<double> rho_E_node(mesh.connectivity());
+    rho_node.fill(0);
+    rho_U_node.fill(TinyVector<1>(0));
+    rho_E_node.fill(0);
+    CellValue<double> lambda{p_mesh->connectivity()};
+    // lambda.fill(0);
+    CellValue<Rd> U{p_mesh->connectivity()};
+    // U.fill(TinyVector<1>(0));
+
+    parallel_for(
+      mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
+        U[cell_id][0]   = rho_U_n[cell_id][0] / rho_n[cell_id];
+        double rho_U_2  = rho_U_n[cell_id][0] * U[cell_id][0];
+        lambda[cell_id] = 0.5 * (1. + delta) * rho_n[cell_id] / (2 * rho_E_n[cell_id] - rho_U_2);
+      });
+
+    parallel_for(
+      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
+        const auto& node_cells = node_to_cell_matrix[node_id];
+
+        double rho_cell_left    = 1;
+        Rd rho_U_cell_left      = TinyVector<1>(1);
+        double rho_E_cell_left  = 1;
+        double rho_cell_right   = 1;
+        Rd rho_U_cell_right     = TinyVector<1>(1);
+        double rho_E_cell_right = 1;
+
+        for (size_t l = 0; l < node_cells.size(); l++) {
+          if (node_cells.size() == 1)
+            continue;
+
+          if (l == 0) {
+            double U_2_left = U[node_cells[l]][0] * U[node_cells[l]][0];
+
+            rho_cell_left =
+              rho_n[node_cells[l]] * (1 + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]));
+
+            rho_U_cell_left[0] =
+              rho_U_n[node_cells[l]][0] * (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+              rho_n[node_cells[l]] * std::exp(-lambda[node_cells[l]] * U_2_left) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+
+            rho_E_cell_left =
+              rho_E_n[node_cells[l]] * (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+              0.5 * rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_left) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+          } else {
+            double U_2_right = U[node_cells[l]][0] * U[node_cells[l]][0];
+
+            rho_cell_right =
+              rho_n[node_cells[l]] * (1 - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]));
+
+            rho_U_cell_right[0] =
+              rho_U_n[node_cells[l]][0] * (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+              rho_n[node_cells[l]] * std::exp(-lambda[node_cells[l]] * U_2_right) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+
+            rho_E_cell_right =
+              rho_E_n[node_cells[l]] * (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+              0.5 * rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_right) /
+                std::sqrt(pi * lambda[node_cells[l]]);
+          }
+        }
+        rho_node[node_id]   = 0.5 * (rho_cell_left + rho_cell_right);
+        rho_U_node[node_id] = 0.5 * (rho_U_cell_left + rho_U_cell_right);
+        rho_E_node[node_id] = 0.5 * (rho_E_cell_left + rho_E_cell_right);
+      });
+
+    parallel_for(
+      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
+        const auto& node_cells = node_to_cell_matrix[node_id];
+
+        Rd F2_fn_left  = TinyVector<1>(0);
+        Rd F3_fn_left  = TinyVector<1>(0);
+        Rd F2_fn_right = TinyVector<1>(0);
+        Rd F3_fn_right = TinyVector<1>(0);
+
+        for (size_t l = 0; l < node_cells.size(); l++) {
+          if (node_cells.size() == 1)
+            continue;
+
+          if (l == 0) {
+            double U_2_left     = U[node_cells[l]][0] * U[node_cells[l]][0];
+            double rho_U_2_left = rho_U_n[node_cells[l]][0] * U[node_cells[l]][0];
+
+            F2_fn_left[0] = (rho_U_2_left + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
+                              (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+                            rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_left) /
+                              std::sqrt(pi * lambda[node_cells[l]]);
+
+            F3_fn_left[0] = 0.5 * rho_U_n[node_cells[l]][0] * (U_2_left + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
+                              (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
+                            0.5 * rho_n[node_cells[l]] * (U_2_left + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
+                              std::exp(-lambda[node_cells[l]] * U_2_left) / std::sqrt(pi * lambda[node_cells[l]]);
+          } else {
+            double U_2_right     = U[node_cells[l]][0] * U[node_cells[l]][0];
+            double rho_U_2_right = rho_U_n[node_cells[l]][0] * U[node_cells[l]][0];
+
+            F2_fn_right[0] = (rho_U_2_right + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
+                               (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+                             rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_right) /
+                               std::sqrt(pi * lambda[node_cells[l]]);
+
+            F3_fn_right[0] = 0.5 * rho_U_n[node_cells[l]][0] * (U_2_right + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
+                               (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
+                             0.5 * rho_n[node_cells[l]] * (U_2_right + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
+                               std::exp(-lambda[node_cells[l]] * U_2_right) / std::sqrt(pi * lambda[node_cells[l]]);
+          }
+        }
+
+        F2fn[node_id] = 0.5 * (F2_fn_left + F2_fn_right);
+        F3fn[node_id] = 0.5 * (F3_fn_left + F3_fn_right);
+      });
+
+    parallel_for(
+      mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
+        const auto& cell_nodes = cell_to_node_matrix[cell_id];
+
+        for (size_t r = 0; r < cell_nodes.size(); r++) {
+          double rho_U_2_node = rho_U_node[cell_nodes[r]][0] * rho_U_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]];
+
+          rho_flux_G(cell_id, r) = rho_U_node[cell_nodes[r]][0] * njr(cell_id, r);
+          rho_U_flux_G(cell_id, r) =
+            (delta * rho_U_2_node / (1. + delta) + 2 * rho_E_node[cell_nodes[r]] / (1. + delta)) * njr(cell_id, r);
+          rho_E_flux_G(cell_id, r) =
+            rho_U_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]] *
+            ((3. + delta) * rho_E_node[cell_nodes[r]] / (1. + delta) - rho_U_2_node / (1. + delta)) * njr(cell_id, r);
+
+          rho_U_flux_Ffn(cell_id, r) = F2fn[cell_nodes[r]][0] * njr(cell_id, r);
+          rho_E_flux_Ffn(cell_id, r) = F3fn[cell_nodes[r]][0] * njr(cell_id, r);
+        }
+      });
+
+    // // for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
+    // // std::cout << "flux_E_Ffn : " << rho_E_flux_Ffn(cell_id, 0) << std::endl;
+    // // }
+
+    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
+      const size_t& nb_nodes = rho_flux_G.numberOfSubValues(cell_id);
+      for (size_t r = 0; r < nb_nodes; r++) {
+        rho[cell_id] -= dt / Vj[cell_id] * rho_flux_G(cell_id, r)[0];
+        // rho_U[cell_id] -= dt / Vj[cell_id] *
+        //                   ((1 - eta[cell_id]) * rho_U_flux_G(cell_id, r) + eta[cell_id] *
+        // (rho_U_flux_Ffn(cell_id,
+        //                   r)));
+        // rho_E[cell_id] -=
+        //   dt / Vj[cell_id] *
+        //   ((1 - eta[cell_id]) * rho_E_flux_G(cell_id, r)[0] + eta[cell_id] * (rho_E_flux_Ffn(cell_id, r)[0]));
+        rho_U[cell_id] -=
+          dt / Vj[cell_id] *
+          (rho_U_flux_G(cell_id, r) + eta[cell_id] * (rho_U_flux_Ffn(cell_id, r) - rho_U_flux_G(cell_id, r)));
+        rho_E[cell_id] -=
+          dt / Vj[cell_id] *
+          (rho_E_flux_G(cell_id, r)[0] + eta[cell_id] * (rho_E_flux_Ffn(cell_id, r)[0] - rho_E_flux_G(cell_id, r)[0]));
+      }
+    }
+
+    return std::make_tuple(std::make_shared<DiscreteFunctionVariant>(rho),
+                           std::make_shared<DiscreteFunctionVariant>(rho_U),
+                           std::make_shared<DiscreteFunctionVariant>(rho_E));
+  }
+
+  GKSNAVIER() = default;
+};
+
+std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
+           std::shared_ptr<const DiscreteFunctionVariant>,   // U
+           std::shared_ptr<const DiscreteFunctionVariant>>   // E
+gks2(std::shared_ptr<const DiscreteFunctionVariant> rho_v,
+     std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
+     std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
+     std::shared_ptr<const DiscreteFunctionVariant> tau,
+     const double delta,
+     const double dt)
+{
+  std::shared_ptr mesh_v = getCommonMesh({rho_v, rho_U_v, rho_E_v});
+  if (not mesh_v) {
+    throw NormalError("discrete functions are not defined on the same mesh");
+  }
+
+  if (not checkDiscretizationType({rho_v, rho_U_v, rho_E_v}, DiscreteFunctionType::P0)) {
+    throw NormalError("GKS solver expects P0 functions");
+  }
+
+  return std::visit(
+    [&](auto&& p_mesh)
+      -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>, std::shared_ptr<const DiscreteFunctionVariant>,
+                    std::shared_ptr<const DiscreteFunctionVariant>> {
+      using MeshType = std::decay_t<decltype(*p_mesh)>;
+      if constexpr (is_polygonal_mesh_v<MeshType>) {
+        if constexpr (MeshType::Dimension == 1) {
+          GKSNAVIER<MeshType> gksNavier;
+          return gksNavier.solve(p_mesh, rho_v, rho_U_v, rho_E_v, tau, delta, dt);
+
+        } else {
+          throw NormalError("dimension not treated");
+        }
+
+      } else {
+        throw NormalError("unexpected mesh type");
+      }
+    },
+    mesh_v->variant());
+}
diff --git a/src/scheme/GKSNavier.hpp b/src/scheme/GKSNavier.hpp
new file mode 100644
index 000000000..9f6135f1c
--- /dev/null
+++ b/src/scheme/GKSNavier.hpp
@@ -0,0 +1,18 @@
+#ifndef GKSNAVIER_HPP
+#define GKSNAVIER_HPP
+
+#include <scheme/DiscreteFunctionVariant.hpp>
+
+#include <tuple>
+
+std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
+           std::shared_ptr<const DiscreteFunctionVariant>,   // rhoU
+           std::shared_ptr<const DiscreteFunctionVariant>>   // rhoE
+gksNavier(std::shared_ptr<const DiscreteFunctionVariant> rho,
+          std::shared_ptr<const DiscreteFunctionVariant> rhoU,
+          std::shared_ptr<const DiscreteFunctionVariant> rhoE,
+          std::shared_ptr<const DiscreteFunctionVariant> tau,
+          const double delta,
+          const double dt);
+
+#endif   // GKSNAVIER_HPP
-- 
GitLab