From aa67362deac78c798cd933ccbf7ad5b125986b91 Mon Sep 17 00:00:00 2001
From: Clovis <clovis.schoeck@etudiant.univ-rennes.fr>
Date: Fri, 9 Aug 2024 16:42:22 +0200
Subject: [PATCH] Add base of GKS to solve Navier Stokes, structure inspired by
 AcousticSolver - Clovis Schoeck

---
 src/language/modules/SchemeModule.cpp |  33 +-
 src/scheme/GKS.cpp                    |  16 +-
 src/scheme/GKS2.cpp                   |  32 +-
 src/scheme/GKSNavier.cpp              | 442 ++++++++++++++++----------
 src/scheme/GKSNavier.hpp              |  84 ++++-
 5 files changed, 390 insertions(+), 217 deletions(-)

diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp
index 46c417d61..342a440db 100644
--- a/src/language/modules/SchemeModule.cpp
+++ b/src/language/modules/SchemeModule.cpp
@@ -453,20 +453,20 @@ SchemeModule::SchemeModule()
 
                               ));
 
-  this->_addBuiltinFunction("gksNavier",
-                            std::function(
+  this->_addBuiltinFunction(
+    "gksNavier",
+    std::function(
 
-                              [](const std::shared_ptr<const DiscreteFunctionVariant>& rho,
-                                 const std::shared_ptr<const DiscreteFunctionVariant>& rho_U,
-                                 const std::shared_ptr<const DiscreteFunctionVariant>& rho_E,
-                                 const std::shared_ptr<const DiscreteFunctionVariant>& tau, const double& delta,
-                                 const double& dt) -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
-                                                                 std::shared_ptr<const DiscreteFunctionVariant>,
-                                                                 std::shared_ptr<const DiscreteFunctionVariant>> {
-                                return gksNavier(rho, rho_U, rho_E, tau, delta, dt);
-                              }
+      [](const std::shared_ptr<const DiscreteFunctionVariant>& rho,
+         const std::shared_ptr<const DiscreteFunctionVariant>& rhoU,
+         const std::shared_ptr<const DiscreteFunctionVariant>& rhoE,
+         const std::shared_ptr<const DiscreteFunctionVariant>& tau, const double& delta, const double& dt)
+        -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>, std::shared_ptr<const DiscreteFunctionVariant>,
+                      std::shared_ptr<const DiscreteFunctionVariant>> {
+        return GKSHandler{getCommonMesh({rho, rhoU, rhoE})}.solver().gksNavier(rho, rhoU, rhoE, tau, delta, dt);
+      }
 
-                              ));
+      ));
 
   this->_addBuiltinFunction("glace_solver",
                             std::function(
@@ -678,6 +678,15 @@ SchemeModule::SchemeModule()
 
                                              ));
 
+  this->_addBuiltinFunction("gks_inv_dt", std::function(
+
+                                            [](const std::shared_ptr<const DiscreteFunctionVariant>& c,
+                                               const std::shared_ptr<const DiscreteFunctionVariant>& U) -> double {
+                                              return gks_inv_dt(c, U);
+                                            }
+
+                                            ));
+
   this->_addBuiltinFunction("cell_volume",
                             std::function(
 
diff --git a/src/scheme/GKS.cpp b/src/scheme/GKS.cpp
index dad7b84a7..e69cf6c88 100644
--- a/src/scheme/GKS.cpp
+++ b/src/scheme/GKS.cpp
@@ -54,21 +54,21 @@ class GKS
     NodeValue<Rd> rho_U_flux_Euler(mesh.connectivity());
     NodeValue<double> rho_E_flux_Euler(mesh.connectivity());
     rho_flux_Euler.fill(0);
-    rho_U_flux_Euler.fill(TinyVector<1>(0));
+    rho_U_flux_Euler.fill(Rd(0));
     rho_E_flux_Euler.fill(0);
 
     NodeValue<double> rho_flux_Navier(mesh.connectivity());
     NodeValue<Rd> rho_U_flux_Navier(mesh.connectivity());
     NodeValue<double> rho_E_flux_Navier(mesh.connectivity());
     rho_flux_Navier.fill(0);
-    rho_U_flux_Navier.fill(TinyVector<1>(0));
+    rho_U_flux_Navier.fill(Rd(0));
     rho_E_flux_Navier.fill(0);
 
     NodeValue<double> rho_node(mesh.connectivity());
     NodeValue<Rd> rho_U_node(mesh.connectivity());
     NodeValue<double> rho_E_node(mesh.connectivity());
     rho_node.fill(0);
-    rho_U_node.fill(TinyVector<1>(0));
+    rho_U_node.fill(Rd(0));
     rho_E_node.fill(0);
     CellValue<double> lambda{p_mesh->connectivity()};
     // lambda.fill(0);
@@ -153,8 +153,8 @@ class GKS
 
       auto node_list = cell_to_node_matrix[cell_id];
 
-      rho_U_flux_Navier[node_list[1]][0] += 0.5 * F2_fn_left[0];
-      rho_E_flux_Navier[node_list[1]] += 0.5 * F3_fn_left;
+      rho_U_flux_Navier[node_list[1]][0] = 0.5 * F2_fn_left[0];
+      rho_E_flux_Navier[node_list[1]]    = 0.5 * F3_fn_left;
     }
 
     for (CellId cell_id = 1; cell_id < mesh.numberOfCells(); ++cell_id) {
@@ -176,7 +176,7 @@ class GKS
       rho_E_flux_Navier[node_list[0]] += 0.5 * F3_fn_right;
     }
 
-    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
+    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
       auto node_list = cell_to_node_matrix[cell_id];
 
       const double rho_flux_Euler_sum   = (rho_flux_Euler[node_list[1]] - rho_flux_Euler[node_list[0]]);
@@ -187,10 +187,10 @@ class GKS
       const double rho_E_flux_Navier_sum = (rho_E_flux_Navier[node_list[1]] - rho_E_flux_Navier[node_list[0]]);
       rho[cell_id] -= dt / Vj[cell_id] * (rho_flux_Euler_sum);
       rho_U[cell_id][0] -=
-        0 * dt / Vj[cell_id] *
+        dt / Vj[cell_id] *
         (rho_U_flux_Euler_sum[0] + eta[cell_id] * (rho_U_flux_Navier_sum[0] - rho_U_flux_Euler_sum[0]));
       rho_E[cell_id] -=
-        0 * dt / Vj[cell_id] * (rho_E_flux_Euler_sum + eta[cell_id] * (rho_E_flux_Navier_sum - rho_E_flux_Euler_sum));
+        dt / Vj[cell_id] * (rho_E_flux_Euler_sum + eta[cell_id] * (rho_E_flux_Navier_sum - rho_E_flux_Euler_sum));
     }
     return std::make_tuple(std::make_shared<DiscreteFunctionVariant>(rho),
                            std::make_shared<DiscreteFunctionVariant>(rho_U),
diff --git a/src/scheme/GKS2.cpp b/src/scheme/GKS2.cpp
index 02f61d736..6ba5583e0 100644
--- a/src/scheme/GKS2.cpp
+++ b/src/scheme/GKS2.cpp
@@ -56,32 +56,32 @@ class GKS2
     NodeValuePerCell<Rd> rho_flux_G(mesh.connectivity());
     NodeValuePerCell<Rd> rho_U_flux_G(mesh.connectivity());
     NodeValuePerCell<Rd> rho_E_flux_G(mesh.connectivity());
-    rho_flux_G.fill(TinyVector<1>(0));
-    rho_U_flux_G.fill(TinyVector<1>(0));
-    rho_E_flux_G.fill(TinyVector<1>(0));
+    rho_flux_G.fill(Rd(0));
+    rho_U_flux_G.fill(Rd(0));
+    rho_E_flux_G.fill(Rd(0));
 
     NodeValuePerCell<Rd> rho_flux_Ffn(mesh.connectivity());
     NodeValuePerCell<Rd> rho_U_flux_Ffn(mesh.connectivity());
     NodeValuePerCell<Rd> rho_E_flux_Ffn(mesh.connectivity());
-    rho_flux_Ffn.fill(TinyVector<1>(0));
-    rho_U_flux_Ffn.fill(TinyVector<1>(0));
-    rho_E_flux_Ffn.fill(TinyVector<1>(0));
+    rho_flux_Ffn.fill(Rd(0));
+    rho_U_flux_Ffn.fill(Rd(0));
+    rho_E_flux_Ffn.fill(Rd(0));
 
     NodeValue<Rd> F2fn(mesh.connectivity());
     NodeValue<Rd> F3fn(mesh.connectivity());
-    F2fn.fill(TinyVector<1>(0));
-    F3fn.fill(TinyVector<1>(0));
+    F2fn.fill(Rd(0));
+    F3fn.fill(Rd(0));
 
     NodeValue<double> rho_node(mesh.connectivity());
     NodeValue<Rd> rho_U_node(mesh.connectivity());
     NodeValue<double> rho_E_node(mesh.connectivity());
     rho_node.fill(0);
-    rho_U_node.fill(TinyVector<1>(0));
+    rho_U_node.fill(zero);
     rho_E_node.fill(0);
     CellValue<double> lambda{p_mesh->connectivity()};
     // lambda.fill(0);
     CellValue<Rd> U{p_mesh->connectivity()};
-    // U.fill(TinyVector<1>(0));
+    // U.fill(Rd(0));
 
     parallel_for(
       mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
@@ -95,10 +95,10 @@ class GKS2
         const auto& node_cells = node_to_cell_matrix[node_id];
 
         double rho_cell_left    = 1;
-        Rd rho_U_cell_left      = TinyVector<1>(1);
+        Rd rho_U_cell_left      = Rd(1);
         double rho_E_cell_left  = 1;
         double rho_cell_right   = 1;
-        Rd rho_U_cell_right     = TinyVector<1>(1);
+        Rd rho_U_cell_right     = Rd(1);
         double rho_E_cell_right = 1;
 
         for (size_t l = 0; l < node_cells.size(); l++) {
@@ -146,10 +146,10 @@ class GKS2
       mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
         const auto& node_cells = node_to_cell_matrix[node_id];
 
-        Rd F2_fn_left  = TinyVector<1>(0);
-        Rd F3_fn_left  = TinyVector<1>(0);
-        Rd F2_fn_right = TinyVector<1>(0);
-        Rd F3_fn_right = TinyVector<1>(0);
+        Rd F2_fn_left  = Rd(0);
+        Rd F3_fn_left  = Rd(0);
+        Rd F2_fn_right = Rd(0);
+        Rd F3_fn_right = Rd(0);
 
         for (size_t l = 0; l < node_cells.size(); l++) {
           if (node_cells.size() == 1)
diff --git a/src/scheme/GKSNavier.cpp b/src/scheme/GKSNavier.cpp
index 9b8d4a615..1bfbdace6 100644
--- a/src/scheme/GKSNavier.cpp
+++ b/src/scheme/GKSNavier.cpp
@@ -1,33 +1,76 @@
 #include <scheme/GKSNavier.hpp>
 
-#include <mesh/Mesh.hpp>
-#include <mesh/MeshData.hpp>
-#include <mesh/MeshDataManager.hpp>
+#include <language/utils/InterpolateItemValue.hpp>
+#include <mesh/ItemValueUtils.hpp>
+#include <mesh/ItemValueVariant.hpp>
 #include <mesh/MeshTraits.hpp>
+#include <mesh/MeshVariant.hpp>
+#include <mesh/SubItemValuePerItemVariant.hpp>
+#include <scheme/DiscreteFunctionP0.hpp>
 #include <scheme/DiscreteFunctionUtils.hpp>
+#include <utils/Socket.hpp>
+
+#include <variant>
+#include <vector>
+
+double
+gks_inv_dt(const std::shared_ptr<const DiscreteFunctionVariant>& c_v,
+           const std::shared_ptr<const DiscreteFunctionVariant>& U_v)
+{
+  const auto& c = c_v->get<DiscreteFunctionP0<const double>>();
+  const auto& U = U_v->get<DiscreteFunctionP0<const double>>();
+
+  return std::visit(
+    [&](auto&& p_mesh) -> double {
+      const auto& mesh = *p_mesh;
+
+      using MeshType = decltype(mesh);
+      if constexpr (is_polygonal_mesh_v<MeshType>) {
+        const auto Vj = MeshDataManager::instance().getMeshData(mesh).Vj();
+        const auto Sj = MeshDataManager::instance().getMeshData(mesh).sumOverRLjr();
+
+        CellValue<double> local_inv_dt{mesh.connectivity()};
+        parallel_for(
+          mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { local_inv_dt[j] = (c[j] + abs(U[j])) / Vj[j]; });
+
+        return max(local_inv_dt);
+      } else {
+        throw NormalError("unexpected mesh type");
+      }
+    },
+    c.meshVariant()->variant());
+}
 
 template <MeshConcept MeshType>
-class GKSNAVIER
+class GKSHandler::GKSNAVIER final : public GKSHandler::IGKSNAVIER
 {
- public:
-  std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
-             std::shared_ptr<const DiscreteFunctionVariant>,
-             std::shared_ptr<const DiscreteFunctionVariant>>
-  solve(std::shared_ptr<const MeshType> p_mesh,
-        std::shared_ptr<const DiscreteFunctionVariant> rho_v,
-        std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
-        std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
-        std::shared_ptr<const DiscreteFunctionVariant> tau,
-        const double delta,
-        double dt)
-  {
-    using Rd = TinyVector<MeshType::Dimension>;
+ private:
+  static constexpr size_t Dimension = MeshType::Dimension;
+
+  using Rd   = TinyVector<Dimension>;
+  using Rdxd = TinyMatrix<Dimension>;
 
-    const MeshType& mesh = *p_mesh;
+  using MeshDataType = MeshData<MeshType>;
 
-    const double pi = std::acos(-1);
+  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
+  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;
 
-    DiscreteFunctionP0<const double> tau_n = tau->get<DiscreteFunctionP0<const double>>();
+  const double pi = std::acos(-1);
+
+ public:
+  std::tuple<const std::shared_ptr<const ItemValueVariant>,
+             const std::shared_ptr<const ItemValueVariant>,
+             const std::shared_ptr<const ItemValueVariant>>
+  compute_fluxes(const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
+                 const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
+                 const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
+                 const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
+                 const double& delta) const
+  {
+    auto mesh_v          = getCommonMesh({rho_v, rhoU_v, rhoE_v});
+    const MeshType& mesh = *mesh_v->get<MeshType>();
+
+    DiscreteFunctionP0<const double> tau_n = tau_v->get<DiscreteScalarFunction>();
     CellValue<double> eta(mesh.connectivity());
     for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
       if (tau_n[cell_id] == 0)
@@ -35,18 +78,16 @@ class GKSNAVIER
       else
         eta[cell_id] = tau_n[cell_id];   //(tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
     }
-    // std::cout << "eta = " << eta << std::endl;
 
-    DiscreteFunctionP0<const double> rho_n   = rho_v->get<DiscreteFunctionP0<const double>>();
-    DiscreteFunctionP0<const Rd> rho_U_n     = rho_U_v->get<DiscreteFunctionP0<const Rd>>();
-    DiscreteFunctionP0<const double> rho_E_n = rho_E_v->get<DiscreteFunctionP0<const double>>();
+    DiscreteScalarFunction rho_n  = rho_v->get<DiscreteScalarFunction>();
+    DiscreteVectorFunction rhoU_n = rhoU_v->get<DiscreteVectorFunction>();
+    DiscreteScalarFunction rhoE_n = rhoE_v->get<DiscreteScalarFunction>();
 
-    DiscreteFunctionP0<double> rho   = copy(rho_n);
-    DiscreteFunctionP0<Rd> rho_U     = copy(rho_U_n);
-    DiscreteFunctionP0<double> rho_E = copy(rho_E_n);
+    DiscreteFunctionP0<double> rho  = copy(rho_n);
+    DiscreteFunctionP0<Rd> rhoU     = copy(rhoU_n);
+    DiscreteFunctionP0<double> rhoE = copy(rhoE_n);
 
     auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);
-    auto Vj         = mesh_data.Vj();
 
     auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
     auto node_to_cell_matrix = mesh.connectivity().nodeToCellMatrix();
@@ -54,138 +95,148 @@ class GKSNAVIER
     const NodeValuePerCell<const Rd> njr = mesh_data.njr();
 
     NodeValuePerCell<Rd> rho_flux_G(mesh.connectivity());
-    NodeValuePerCell<Rd> rho_U_flux_G(mesh.connectivity());
-    NodeValuePerCell<Rd> rho_E_flux_G(mesh.connectivity());
-    rho_flux_G.fill(TinyVector<1>(0));
-    rho_U_flux_G.fill(TinyVector<1>(0));
-    rho_E_flux_G.fill(TinyVector<1>(0));
+    NodeValuePerCell<Rd> rhoU_flux_G(mesh.connectivity());
+    NodeValuePerCell<Rd> rhoE_flux_G(mesh.connectivity());
+    rho_flux_G.fill(Rd());
+    rhoU_flux_G.fill(Rd());
+    rhoE_flux_G.fill(Rd());
 
     NodeValuePerCell<Rd> rho_flux_Ffn(mesh.connectivity());
-    NodeValuePerCell<Rd> rho_U_flux_Ffn(mesh.connectivity());
-    NodeValuePerCell<Rd> rho_E_flux_Ffn(mesh.connectivity());
-    rho_flux_Ffn.fill(TinyVector<1>(0));
-    rho_U_flux_Ffn.fill(TinyVector<1>(0));
-    rho_E_flux_Ffn.fill(TinyVector<1>(0));
+    NodeValuePerCell<Rd> rhoU_flux_Ffn(mesh.connectivity());
+    NodeValuePerCell<Rd> rhoE_flux_Ffn(mesh.connectivity());
+    rho_flux_Ffn.fill(Rd());
+    rhoU_flux_Ffn.fill(Rd());
+    rhoE_flux_Ffn.fill(Rd());
+
+    CellValue<double> rho_fluxes(mesh.connectivity());
+    CellValue<Rd> rhoU_fluxes(mesh.connectivity());
+    CellValue<double> rhoE_fluxes(mesh.connectivity());
+    rho_fluxes.fill(0);
+    rhoU_fluxes.fill(Rd());
+    rhoE_fluxes.fill(0);
 
     NodeValue<Rd> F2fn(mesh.connectivity());
     NodeValue<Rd> F3fn(mesh.connectivity());
-    F2fn.fill(TinyVector<1>(0));
-    F3fn.fill(TinyVector<1>(0));
+    F2fn.fill(Rd());
+    F3fn.fill(Rd());
 
     NodeValue<double> rho_node(mesh.connectivity());
-    NodeValue<Rd> rho_U_node(mesh.connectivity());
-    NodeValue<double> rho_E_node(mesh.connectivity());
+    NodeValue<Rd> rhoU_node(mesh.connectivity());
+    NodeValue<double> rhoE_node(mesh.connectivity());
     rho_node.fill(0);
-    rho_U_node.fill(TinyVector<1>(0));
-    rho_E_node.fill(0);
-    CellValue<double> lambda{p_mesh->connectivity()};
-    // lambda.fill(0);
-    CellValue<Rd> U{p_mesh->connectivity()};
-    // U.fill(TinyVector<1>(0));
+    rhoU_node.fill(Rd());
+    rhoE_node.fill(0);
+    CellValue<double> lambda{mesh.connectivity()};
+    lambda.fill(0);
+    CellValue<Rd> U{mesh.connectivity()};
+    U.fill(Rd());
+
+    CellValue<double> err_function_term(mesh.connectivity());
+    CellValue<double> exp_term(mesh.connectivity());
 
     parallel_for(
       mesh.numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
-        U[cell_id][0]   = rho_U_n[cell_id][0] / rho_n[cell_id];
-        double rho_U_2  = rho_U_n[cell_id][0] * U[cell_id][0];
-        lambda[cell_id] = 0.5 * (1. + delta) * rho_n[cell_id] / (2 * rho_E_n[cell_id] - rho_U_2);
+        U[cell_id][0]   = rhoU_n[cell_id][0] / rho_n[cell_id];
+        double rhoU_2   = rhoU_n[cell_id][0] * U[cell_id][0];
+        lambda[cell_id] = 0.5 * (1. + delta) * rho_n[cell_id] / (2 * rhoE_n[cell_id] - rhoU_2);
+      });
+
+    parallel_for(
+      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
+        const auto& node_cells = node_to_cell_matrix[node_id];
+        for (size_t l = 0; l < node_cells.size(); l++) {
+          if (node_cells.size() == 1)
+            continue;
+
+          if (l == 0) {
+            double U_2_left                  = U[node_cells[l]][0] * U[node_cells[l]][0];
+            err_function_term[node_cells[l]] = std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]);
+            exp_term[node_cells[l]] =
+              std::exp(-lambda[node_cells[l]] * U_2_left) / std::sqrt(pi * lambda[node_cells[l]]);
+          }
+        }
       });
 
     parallel_for(
       mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
         const auto& node_cells = node_to_cell_matrix[node_id];
 
-        double rho_cell_left    = 1;
-        Rd rho_U_cell_left      = TinyVector<1>(1);
-        double rho_E_cell_left  = 1;
-        double rho_cell_right   = 1;
-        Rd rho_U_cell_right     = TinyVector<1>(1);
-        double rho_E_cell_right = 1;
+        double rho_cell_left   = 1;
+        double rho_cell_right  = 1;
+        double rhoU_cell_left  = 1;
+        double rhoU_cell_right = 1;
+        double rhoE_cell_left  = 1;
+        double rhoE_cell_right = 1;
 
         for (size_t l = 0; l < node_cells.size(); l++) {
           if (node_cells.size() == 1)
             continue;
 
           if (l == 0) {
-            double U_2_left = U[node_cells[l]][0] * U[node_cells[l]][0];
+            rho_cell_left = rho_n[node_cells[l]] * (1 + err_function_term[node_cells[l]]);
 
-            rho_cell_left =
-              rho_n[node_cells[l]] * (1 + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]));
+            rhoU_cell_left = rhoU_n[node_cells[l]][0] * (1. + err_function_term[node_cells[l]]) +
+                             rho_n[node_cells[l]] * exp_term[node_cells[l]];
 
-            rho_U_cell_left[0] =
-              rho_U_n[node_cells[l]][0] * (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
-              rho_n[node_cells[l]] * std::exp(-lambda[node_cells[l]] * U_2_left) /
-                std::sqrt(pi * lambda[node_cells[l]]);
-
-            rho_E_cell_left =
-              rho_E_n[node_cells[l]] * (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
-              0.5 * rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_left) /
-                std::sqrt(pi * lambda[node_cells[l]]);
+            rhoE_cell_left = rhoE_n[node_cells[l]] * (1. + err_function_term[node_cells[l]]) +
+                             0.5 * rhoU_n[node_cells[l]][0] * exp_term[node_cells[l]];
           } else {
-            double U_2_right = U[node_cells[l]][0] * U[node_cells[l]][0];
-
-            rho_cell_right =
-              rho_n[node_cells[l]] * (1 - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0]));
+            rho_cell_right = rho_n[node_cells[l]] * (1 - err_function_term[node_cells[l]]);
 
-            rho_U_cell_right[0] =
-              rho_U_n[node_cells[l]][0] * (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
-              rho_n[node_cells[l]] * std::exp(-lambda[node_cells[l]] * U_2_right) /
-                std::sqrt(pi * lambda[node_cells[l]]);
+            rhoU_cell_right = rhoU_n[node_cells[l]][0] * (1. - err_function_term[node_cells[l]]) -
+                              rho_n[node_cells[l]] * exp_term[node_cells[l]];
 
-            rho_E_cell_right =
-              rho_E_n[node_cells[l]] * (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
-              0.5 * rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_right) /
-                std::sqrt(pi * lambda[node_cells[l]]);
+            rhoE_cell_right = rhoE_n[node_cells[l]] * (1. - err_function_term[node_cells[l]]) -
+                              0.5 * rhoU_n[node_cells[l]][0] * exp_term[node_cells[l]];
           }
         }
-        rho_node[node_id]   = 0.5 * (rho_cell_left + rho_cell_right);
-        rho_U_node[node_id] = 0.5 * (rho_U_cell_left + rho_U_cell_right);
-        rho_E_node[node_id] = 0.5 * (rho_E_cell_left + rho_E_cell_right);
+        rho_node[node_id]     = 0.5 * (rho_cell_left + rho_cell_right);
+        rhoU_node[node_id][0] = 0.5 * (rhoU_cell_left + rhoU_cell_right);
+        rhoE_node[node_id]    = 0.5 * (rhoE_cell_left + rhoE_cell_right);
       });
 
     parallel_for(
       mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
         const auto& node_cells = node_to_cell_matrix[node_id];
 
-        Rd F2_fn_left  = TinyVector<1>(0);
-        Rd F3_fn_left  = TinyVector<1>(0);
-        Rd F2_fn_right = TinyVector<1>(0);
-        Rd F3_fn_right = TinyVector<1>(0);
+        double F2_fn_left  = 0;
+        double F3_fn_left  = 0;
+        double F2_fn_right = 0;
+        double F3_fn_right = 0;
 
         for (size_t l = 0; l < node_cells.size(); l++) {
           if (node_cells.size() == 1)
             continue;
 
           if (l == 0) {
-            double U_2_left     = U[node_cells[l]][0] * U[node_cells[l]][0];
-            double rho_U_2_left = rho_U_n[node_cells[l]][0] * U[node_cells[l]][0];
-
-            F2_fn_left[0] = (rho_U_2_left + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
-                              (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
-                            rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_left) /
-                              std::sqrt(pi * lambda[node_cells[l]]);
-
-            F3_fn_left[0] = 0.5 * rho_U_n[node_cells[l]][0] * (U_2_left + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
-                              (1. + std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) +
-                            0.5 * rho_n[node_cells[l]] * (U_2_left + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
-                              std::exp(-lambda[node_cells[l]] * U_2_left) / std::sqrt(pi * lambda[node_cells[l]]);
+            double U_2_left    = U[node_cells[l]][0] * U[node_cells[l]][0];
+            double rhoU_2_left = rhoU_n[node_cells[l]][0] * U[node_cells[l]][0];
+
+            F2_fn_left = (rhoU_2_left + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
+                           (1. + err_function_term[node_cells[l]]) +
+                         rhoU_n[node_cells[l]][0] * exp_term[node_cells[l]];
+
+            F3_fn_left = 0.5 * rhoU_n[node_cells[l]][0] * (U_2_left + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
+                           (1. + err_function_term[node_cells[l]]) +
+                         0.5 * rho_n[node_cells[l]] * (U_2_left + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
+                           exp_term[node_cells[l]];
           } else {
-            double U_2_right     = U[node_cells[l]][0] * U[node_cells[l]][0];
-            double rho_U_2_right = rho_U_n[node_cells[l]][0] * U[node_cells[l]][0];
-
-            F2_fn_right[0] = (rho_U_2_right + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
-                               (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
-                             rho_U_n[node_cells[l]][0] * std::exp(-lambda[node_cells[l]] * U_2_right) /
-                               std::sqrt(pi * lambda[node_cells[l]]);
-
-            F3_fn_right[0] = 0.5 * rho_U_n[node_cells[l]][0] * (U_2_right + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
-                               (1. - std::erf(std::sqrt(lambda[node_cells[l]]) * U[node_cells[l]][0])) -
-                             0.5 * rho_n[node_cells[l]] * (U_2_right + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
-                               std::exp(-lambda[node_cells[l]] * U_2_right) / std::sqrt(pi * lambda[node_cells[l]]);
+            double U_2_right    = U[node_cells[l]][0] * U[node_cells[l]][0];
+            double rhoU_2_right = rhoU_n[node_cells[l]][0] * U[node_cells[l]][0];
+
+            F2_fn_right = (rhoU_2_right + 0.5 * rho_n[node_cells[l]] / lambda[node_cells[l]]) *
+                            (1. - err_function_term[node_cells[l]]) -
+                          rhoU_n[node_cells[l]][0] * exp_term[node_cells[l]];
+
+            F3_fn_right = 0.5 * rhoU_n[node_cells[l]][0] * (U_2_right + 0.5 * (delta + 3) / lambda[node_cells[l]]) *
+                            (1. - err_function_term[node_cells[l]]) -
+                          0.5 * rho_n[node_cells[l]] * (U_2_right + 0.5 * (delta + 2) / lambda[node_cells[l]]) *
+                            exp_term[node_cells[l]];
           }
         }
 
-        F2fn[node_id] = 0.5 * (F2_fn_left + F2_fn_right);
-        F3fn[node_id] = 0.5 * (F3_fn_left + F3_fn_right);
+        F2fn[node_id][0] = 0.5 * (F2_fn_left + F2_fn_right);
+        F3fn[node_id][0] = 0.5 * (F3_fn_left + F3_fn_right);
       });
 
     parallel_for(
@@ -193,85 +244,132 @@ class GKSNAVIER
         const auto& cell_nodes = cell_to_node_matrix[cell_id];
 
         for (size_t r = 0; r < cell_nodes.size(); r++) {
-          double rho_U_2_node = rho_U_node[cell_nodes[r]][0] * rho_U_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]];
+          double rhoU_2_node = rhoU_node[cell_nodes[r]][0] * rhoU_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]];
 
-          rho_flux_G(cell_id, r) = rho_U_node[cell_nodes[r]][0] * njr(cell_id, r);
-          rho_U_flux_G(cell_id, r) =
-            (delta * rho_U_2_node / (1. + delta) + 2 * rho_E_node[cell_nodes[r]] / (1. + delta)) * njr(cell_id, r);
-          rho_E_flux_G(cell_id, r) =
-            rho_U_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]] *
-            ((3. + delta) * rho_E_node[cell_nodes[r]] / (1. + delta) - rho_U_2_node / (1. + delta)) * njr(cell_id, r);
+          rho_flux_G(cell_id, r) = rhoU_node[cell_nodes[r]][0] * njr(cell_id, r);
+          rhoU_flux_G(cell_id, r) =
+            (delta * rhoU_2_node / (1. + delta) + 2 * rhoE_node[cell_nodes[r]] / (1. + delta)) * njr(cell_id, r);
+          rhoE_flux_G(cell_id, r) =
+            rhoU_node[cell_nodes[r]][0] / rho_node[cell_nodes[r]] *
+            ((3. + delta) * rhoE_node[cell_nodes[r]] / (1. + delta) - rhoU_2_node / (1. + delta)) * njr(cell_id, r);
 
-          rho_U_flux_Ffn(cell_id, r) = F2fn[cell_nodes[r]][0] * njr(cell_id, r);
-          rho_E_flux_Ffn(cell_id, r) = F3fn[cell_nodes[r]][0] * njr(cell_id, r);
+          rhoU_flux_Ffn(cell_id, r) = F2fn[cell_nodes[r]][0] * njr(cell_id, r);
+          rhoE_flux_Ffn(cell_id, r) = F3fn[cell_nodes[r]][0] * njr(cell_id, r);
         }
       });
 
-    // // for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
-    // // std::cout << "flux_E_Ffn : " << rho_E_flux_Ffn(cell_id, 0) << std::endl;
-    // // }
-
     for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
       const size_t& nb_nodes = rho_flux_G.numberOfSubValues(cell_id);
       for (size_t r = 0; r < nb_nodes; r++) {
-        rho[cell_id] -= dt / Vj[cell_id] * rho_flux_G(cell_id, r)[0];
-        // rho_U[cell_id] -= dt / Vj[cell_id] *
-        //                   ((1 - eta[cell_id]) * rho_U_flux_G(cell_id, r) + eta[cell_id] *
-        // (rho_U_flux_Ffn(cell_id,
-        //                   r)));
-        // rho_E[cell_id] -=
-        //   dt / Vj[cell_id] *
-        //   ((1 - eta[cell_id]) * rho_E_flux_G(cell_id, r)[0] + eta[cell_id] * (rho_E_flux_Ffn(cell_id, r)[0]));
-        rho_U[cell_id] -=
-          dt / Vj[cell_id] *
-          (rho_U_flux_G(cell_id, r) + eta[cell_id] * (rho_U_flux_Ffn(cell_id, r) - rho_U_flux_G(cell_id, r)));
-        rho_E[cell_id] -=
-          dt / Vj[cell_id] *
-          (rho_E_flux_G(cell_id, r)[0] + eta[cell_id] * (rho_E_flux_Ffn(cell_id, r)[0] - rho_E_flux_G(cell_id, r)[0]));
+        rho_fluxes[cell_id] += rho_flux_G(cell_id, r)[0];
+        rhoU_fluxes[cell_id] +=
+          ((1 - eta[cell_id]) * rhoU_flux_G(cell_id, r) + eta[cell_id] * (rhoU_flux_Ffn(cell_id, r)));
+        rhoE_fluxes[cell_id] +=
+          ((1 - eta[cell_id]) * rhoE_flux_G(cell_id, r)[0] + eta[cell_id] * (rhoE_flux_Ffn(cell_id, r)[0]));
       }
     }
 
-    return std::make_tuple(std::make_shared<DiscreteFunctionVariant>(rho),
-                           std::make_shared<DiscreteFunctionVariant>(rho_U),
-                           std::make_shared<DiscreteFunctionVariant>(rho_E));
+    return std::make_tuple(std::make_shared<const ItemValueVariant>(rho_fluxes),
+                           std::make_shared<const ItemValueVariant>(rhoU_fluxes),
+                           std::make_shared<const ItemValueVariant>(rhoE_fluxes));
+  }
+
+  std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>>
+  apply_fluxes(const MeshType& mesh,
+               const DiscreteScalarFunction& rho,
+               const DiscreteVectorFunction& rhoU,
+               const DiscreteScalarFunction& rhoE,
+               const CellValue<const double>& rho_fluxes,
+               const CellValue<const Rd>& rhoU_fluxes,
+               const CellValue<const double>& rhoE_fluxes,
+               const double& dt) const
+  {
+    CellValue<double> new_rho  = copy(rho.cellValues());
+    CellValue<Rd> new_rhoU     = copy(rhoU.cellValues());
+    CellValue<double> new_rhoE = copy(rhoE.cellValues());
+
+    NodeValue<Rd> xr = copy(mesh.xr());
+
+    std::shared_ptr<const MeshType> mesh_v = std::make_shared<MeshType>(mesh.shared_connectivity(), xr);
+
+    auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);
+    auto Vj         = mesh_data.Vj();
+
+    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
+      new_rho[cell_id] -= dt / Vj[cell_id] * rho_fluxes[cell_id];
+      new_rhoU[cell_id] -= dt / Vj[cell_id] * rhoU_fluxes[cell_id];
+      new_rhoE[cell_id] -= dt / Vj[cell_id] * rhoE_fluxes[cell_id];
+    }
+
+    return {std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction(mesh_v, new_rho)),
+            std::make_shared<DiscreteFunctionVariant>(DiscreteVectorFunction(mesh_v, new_rhoU)),
+            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction(mesh_v, new_rhoE))};
   }
 
-  GKSNAVIER() = default;
+  std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>,
+             std::shared_ptr<const DiscreteFunctionVariant>>
+  apply_fluxes(const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
+               const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
+               const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
+               const std::shared_ptr<const ItemValueVariant>& rho_fluxes,
+               const std::shared_ptr<const ItemValueVariant>& rhoU_fluxes,
+               const std::shared_ptr<const ItemValueVariant>& rhoE_fluxes,
+               const double& dt) const
+  {
+    std::shared_ptr mesh_v = getCommonMesh({rho_v, rhoU_v, rhoE_v});
+    if (not mesh_v) {
+      throw NormalError("discrete functions are not defined on the same mesh");
+    }
+
+    if (not checkDiscretizationType({rho_v, rhoU_v, rhoE_v}, DiscreteFunctionType::P0)) {
+      throw NormalError("GKS solver expects P0 functions");
+    }
+
+    return this->apply_fluxes(*mesh_v->get<MeshType>(),                      //
+                              rho_v->get<DiscreteScalarFunction>(),          //
+                              rhoU_v->get<DiscreteVectorFunction>(),         //
+                              rhoE_v->get<DiscreteScalarFunction>(),         //
+                              rho_fluxes->get<CellValue<const double>>(),    //
+                              rhoU_fluxes->get<CellValue<const Rd>>(),       //
+                              rhoE_fluxes->get<CellValue<const double>>(),   //
+                              dt);
+  }
+
+  std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
+             std::shared_ptr<const DiscreteFunctionVariant>,   // U
+             std::shared_ptr<const DiscreteFunctionVariant>>   // E
+  gksNavier(const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
+            const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
+            const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
+            const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
+            const double& delta,
+            const double& dt) const
+  {
+    std::shared_ptr mesh_v = getCommonMesh({rho_v, rhoU_v, rhoE_v});
+
+    auto [rho_fluxes, rhoU_fluxes, rhoE_fluxes] = compute_fluxes(rho_v, rhoU_v, rhoE_v, tau_v, delta);
+    return apply_fluxes(rho_v, rhoU_v, rhoE_v, rho_fluxes, rhoU_fluxes, rhoE_fluxes, dt);
+  }
+
+  GKSNAVIER()            = default;
+  GKSNAVIER(GKSNAVIER&&) = default;
+  ~GKSNAVIER()           = default;
 };
 
-std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
-           std::shared_ptr<const DiscreteFunctionVariant>,   // U
-           std::shared_ptr<const DiscreteFunctionVariant>>   // E
-gksNavier(std::shared_ptr<const DiscreteFunctionVariant> rho_v,
-          std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
-          std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
-          std::shared_ptr<const DiscreteFunctionVariant> tau,
-          const double delta,
-          const double dt)
+GKSHandler::GKSHandler(const std::shared_ptr<const MeshVariant>& mesh_v)
 {
-  std::shared_ptr mesh_v = getCommonMesh({rho_v, rho_U_v, rho_E_v});
   if (not mesh_v) {
     throw NormalError("discrete functions are not defined on the same mesh");
   }
 
-  if (not checkDiscretizationType({rho_v, rho_U_v, rho_E_v}, DiscreteFunctionType::P0)) {
-    throw NormalError("GKS solver expects P0 functions");
-  }
-
-  return std::visit(
-    [&](auto&& p_mesh)
-      -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>, std::shared_ptr<const DiscreteFunctionVariant>,
-                    std::shared_ptr<const DiscreteFunctionVariant>> {
-      using MeshType = std::decay_t<decltype(*p_mesh)>;
+  std::visit(
+    [&](auto&& mesh) {
+      using MeshType = mesh_type_t<decltype(mesh)>;
       if constexpr (is_polygonal_mesh_v<MeshType>) {
-        if constexpr (MeshType::Dimension == 1) {
-          GKSNAVIER<MeshType> gksNavier;
-          return gksNavier.solve(p_mesh, rho_v, rho_U_v, rho_E_v, tau, delta, dt);
-
-        } else {
-          throw NormalError("dimension not treated");
-        }
-
+        m_gks_navier = std::make_unique<GKSNAVIER<MeshType>>();
       } else {
         throw NormalError("unexpected mesh type");
       }
diff --git a/src/scheme/GKSNavier.hpp b/src/scheme/GKSNavier.hpp
index 9f6135f1c..06103195a 100644
--- a/src/scheme/GKSNavier.hpp
+++ b/src/scheme/GKSNavier.hpp
@@ -1,18 +1,84 @@
 #ifndef GKSNAVIER_HPP
 #define GKSNAVIER_HPP
 
+#include <mesh/MeshTraits.hpp>
 #include <scheme/DiscreteFunctionVariant.hpp>
 
+#include <memory>
 #include <tuple>
+#include <vector>
 
-std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
-           std::shared_ptr<const DiscreteFunctionVariant>,   // rhoU
-           std::shared_ptr<const DiscreteFunctionVariant>>   // rhoE
-gksNavier(std::shared_ptr<const DiscreteFunctionVariant> rho,
-          std::shared_ptr<const DiscreteFunctionVariant> rhoU,
-          std::shared_ptr<const DiscreteFunctionVariant> rhoE,
-          std::shared_ptr<const DiscreteFunctionVariant> tau,
-          const double delta,
-          const double dt);
+class DiscreteFunctionVariant;
+class IBoundaryConditionDescriptor;
+class MeshVariant;
+class ItemValueVariant;
+class SubItemValuePerItemVariant;
+
+double gks_inv_dt(const std::shared_ptr<const DiscreteFunctionVariant>& c,
+                  const std::shared_ptr<const DiscreteFunctionVariant>& U);
+
+class GKSHandler
+{
+ public:
+  enum class SolverType
+  {
+    Glace,
+    Eucclhyd
+  };
+
+ private:
+  struct IGKSNAVIER
+  {
+    virtual std::tuple<const std::shared_ptr<const ItemValueVariant>,
+                       const std::shared_ptr<const ItemValueVariant>,
+                       const std::shared_ptr<const ItemValueVariant>>
+    compute_fluxes(const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
+                   const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
+                   const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
+                   const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
+                   const double& delta) const = 0;
+
+    virtual std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
+                       std::shared_ptr<const DiscreteFunctionVariant>,
+                       std::shared_ptr<const DiscreteFunctionVariant>>
+    apply_fluxes(const std::shared_ptr<const DiscreteFunctionVariant>& rho,
+                 const std::shared_ptr<const DiscreteFunctionVariant>& rhoU,
+                 const std::shared_ptr<const DiscreteFunctionVariant>& rhoE,
+                 const std::shared_ptr<const ItemValueVariant>& rho_fluxes,
+                 const std::shared_ptr<const ItemValueVariant>& rhoU_fluxes,
+                 const std::shared_ptr<const ItemValueVariant>& rhoE_fluxes,
+                 const double& dt) const = 0;
+
+    virtual std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
+                       std::shared_ptr<const DiscreteFunctionVariant>,   // rhoU
+                       std::shared_ptr<const DiscreteFunctionVariant>>   // rhoE
+    gksNavier(const std::shared_ptr<const DiscreteFunctionVariant>& rho_v,
+              const std::shared_ptr<const DiscreteFunctionVariant>& rhoU_v,
+              const std::shared_ptr<const DiscreteFunctionVariant>& rhoE_v,
+              const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
+              const double& delta,
+              const double& dt) const = 0;
+
+    IGKSNAVIER()                        = default;
+    IGKSNAVIER(IGKSNAVIER&&)            = default;
+    IGKSNAVIER& operator=(IGKSNAVIER&&) = default;
+
+    virtual ~IGKSNAVIER() = default;
+  };
+
+  template <MeshConcept MeshType>
+  class GKSNAVIER;
+
+  std::unique_ptr<IGKSNAVIER> m_gks_navier;
+
+ public:
+  const IGKSNAVIER&
+  solver() const
+  {
+    return *m_gks_navier;
+  }
+
+  GKSHandler(const std::shared_ptr<const MeshVariant>& mesh_v);
+};
 
 #endif   // GKSNAVIER_HPP
-- 
GitLab