From 29a73e7fc8d7b455aa3275e04f66bdcf1744fdda Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Wed, 6 Mar 2024 08:32:01 +0100
Subject: [PATCH] Change MeshData template parameter

It now uses mesh type template argument instead of space dimension
---
 .../utils/ItemArrayVariantFunctionInterpoler.cpp |  2 +-
 .../utils/ItemValueVariantFunctionInterpoler.cpp |  2 +-
 src/mesh/DiamondDualMeshBuilder.cpp              |  2 +-
 src/mesh/Dual1DMeshBuilder.cpp                   |  2 +-
 src/mesh/MedianDualMeshBuilder.cpp               |  2 +-
 src/mesh/MeshData.cpp                            | 10 +++++-----
 src/mesh/MeshData.hpp                            | 10 +++++++---
 src/mesh/MeshDataManager.cpp                     | 16 ++++++++--------
 src/mesh/MeshDataManager.hpp                     |  7 ++++---
 src/scheme/AcousticSolver.cpp                    |  4 ++--
 src/scheme/DiscreteFunctionInterpoler.cpp        | 10 ++++------
 src/scheme/DiscreteFunctionVectorInterpoler.cpp  |  8 +++-----
 src/scheme/FluxingAdvectionSolver.cpp            |  7 +++----
 src/scheme/HyperelasticSolver.cpp                |  6 +++---
 14 files changed, 44 insertions(+), 44 deletions(-)

diff --git a/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp b/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp
index 1c4b2b9f0..bac8b5c54 100644
--- a/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp
+++ b/src/language/utils/ItemArrayVariantFunctionInterpoler.cpp
@@ -18,7 +18,7 @@ ItemArrayVariantFunctionInterpoler::_interpolate() const
 {
   std::shared_ptr p_mesh     = m_mesh_v->get<MeshType>();
   constexpr size_t Dimension = MeshType::Dimension;
-  using MeshDataType         = MeshData<Dimension>;
+  using MeshDataType         = MeshData<MeshType>;
 
   switch (m_item_type) {
   case ItemType::cell: {
diff --git a/src/language/utils/ItemValueVariantFunctionInterpoler.cpp b/src/language/utils/ItemValueVariantFunctionInterpoler.cpp
index 74d07b577..be073b907 100644
--- a/src/language/utils/ItemValueVariantFunctionInterpoler.cpp
+++ b/src/language/utils/ItemValueVariantFunctionInterpoler.cpp
@@ -18,7 +18,7 @@ ItemValueVariantFunctionInterpoler::_interpolate() const
 {
   std::shared_ptr p_mesh     = m_mesh_v->get<MeshType>();
   constexpr size_t Dimension = MeshType::Dimension;
-  using MeshDataType         = MeshData<Dimension>;
+  using MeshDataType         = MeshData<MeshType>;
 
   switch (m_item_type) {
   case ItemType::cell: {
diff --git a/src/mesh/DiamondDualMeshBuilder.cpp b/src/mesh/DiamondDualMeshBuilder.cpp
index f6ecba946..43f5528a2 100644
--- a/src/mesh/DiamondDualMeshBuilder.cpp
+++ b/src/mesh/DiamondDualMeshBuilder.cpp
@@ -29,7 +29,7 @@ DiamondDualMeshBuilder::_buildDualDiamondMeshFrom(const MeshType& primal_mesh)
 
   const NodeValue<const TinyVector<Dimension>> primal_xr = primal_mesh.xr();
 
-  MeshData<Dimension>& primal_mesh_data                  = MeshDataManager::instance().getMeshData(primal_mesh);
+  MeshData<MeshType>& primal_mesh_data                   = MeshDataManager::instance().getMeshData(primal_mesh);
   const CellValue<const TinyVector<Dimension>> primal_xj = primal_mesh_data.xj();
 
   std::shared_ptr primal_to_diamond_dual_connectivity_data_mapper =
diff --git a/src/mesh/Dual1DMeshBuilder.cpp b/src/mesh/Dual1DMeshBuilder.cpp
index c861a8bf9..b39cd000c 100644
--- a/src/mesh/Dual1DMeshBuilder.cpp
+++ b/src/mesh/Dual1DMeshBuilder.cpp
@@ -29,7 +29,7 @@ Dual1DMeshBuilder::_buildDual1DMeshFrom(const MeshType& primal_mesh)
 
   const NodeValue<const TinyVector<1>> primal_xr = primal_mesh.xr();
 
-  MeshData<1>& primal_mesh_data                  = MeshDataManager::instance().getMeshData(primal_mesh);
+  MeshData<MeshType>& primal_mesh_data           = MeshDataManager::instance().getMeshData(primal_mesh);
   const CellValue<const TinyVector<1>> primal_xj = primal_mesh_data.xj();
 
   std::shared_ptr primal_to_dual_1d_connectivity_data_mapper =
diff --git a/src/mesh/MedianDualMeshBuilder.cpp b/src/mesh/MedianDualMeshBuilder.cpp
index 5a22cf0b2..411005ec9 100644
--- a/src/mesh/MedianDualMeshBuilder.cpp
+++ b/src/mesh/MedianDualMeshBuilder.cpp
@@ -29,7 +29,7 @@ MedianDualMeshBuilder::_buildMedianDualMeshFrom(const MeshType& primal_mesh)
 
   const NodeValue<const TinyVector<Dimension>> primal_xr = primal_mesh.xr();
 
-  MeshData<Dimension>& primal_mesh_data                  = MeshDataManager::instance().getMeshData(primal_mesh);
+  MeshData<MeshType>& primal_mesh_data                   = MeshDataManager::instance().getMeshData(primal_mesh);
   const CellValue<const TinyVector<Dimension>> primal_xj = primal_mesh_data.xj();
   const FaceValue<const TinyVector<Dimension>> primal_xl = primal_mesh_data.xl();
 
diff --git a/src/mesh/MeshData.cpp b/src/mesh/MeshData.cpp
index 30ef6a015..40e387862 100644
--- a/src/mesh/MeshData.cpp
+++ b/src/mesh/MeshData.cpp
@@ -10,11 +10,11 @@
 
 template <size_t Dimension>
 void
-MeshData<Dimension>::_storeBadMesh()
+MeshData<Mesh<Dimension>>::_storeBadMesh()
 {
   VTKWriter writer("bad_mesh");
   writer.writeOnMesh(std::make_shared<MeshVariant>(
-                       std::make_shared<const MeshType>(m_mesh.shared_connectivity(), m_mesh.xr())),
+                       std::make_shared<const Mesh<Dimension>>(m_mesh.shared_connectivity(), m_mesh.xr())),
                      {std::make_shared<NamedItemValueVariant>(std::make_shared<ItemValueVariant>(m_Vj), "volume")});
   std::ostringstream error_msg;
   error_msg << "mesh contains cells of non-positive volume (see " << rang::fgB::yellow << "bad_mesh.pvd"
@@ -22,6 +22,6 @@ MeshData<Dimension>::_storeBadMesh()
   throw NormalError(error_msg.str());
 }
 
-template void MeshData<1>::_storeBadMesh();
-template void MeshData<2>::_storeBadMesh();
-template void MeshData<3>::_storeBadMesh();
+template void MeshData<Mesh<1>>::_storeBadMesh();
+template void MeshData<Mesh<2>>::_storeBadMesh();
+template void MeshData<Mesh<3>>::_storeBadMesh();
diff --git a/src/mesh/MeshData.hpp b/src/mesh/MeshData.hpp
index a198aab45..d5eba5f5d 100644
--- a/src/mesh/MeshData.hpp
+++ b/src/mesh/MeshData.hpp
@@ -4,6 +4,7 @@
 #include <algebra/TinyVector.hpp>
 #include <mesh/IMeshData.hpp>
 #include <mesh/ItemValue.hpp>
+#include <mesh/MeshTraits.hpp>
 #include <mesh/SubItemValuePerItem.hpp>
 #include <utils/Messenger.hpp>
 #include <utils/PugsUtils.hpp>
@@ -11,15 +12,18 @@
 template <size_t Dimension>
 class Mesh;
 
+template <MeshConcept MeshType>
+class MeshData;
+
 template <size_t Dimension>
-class MeshData : public IMeshData
+class MeshData<Mesh<Dimension>> : public IMeshData
 {
  public:
+  using MeshType = Mesh<Dimension>;
+
   static_assert(Dimension > 0, "dimension must be strictly positive");
   static_assert((Dimension <= 3), "only 1d, 2d and 3d are implemented");
 
-  using MeshType = Mesh<Dimension>;
-
   using Rd = TinyVector<Dimension>;
 
   static constexpr double inv_Dimension = 1. / Dimension;
diff --git a/src/mesh/MeshDataManager.cpp b/src/mesh/MeshDataManager.cpp
index 3607fcad6..4eb8f5677 100644
--- a/src/mesh/MeshDataManager.cpp
+++ b/src/mesh/MeshDataManager.cpp
@@ -40,21 +40,21 @@ MeshDataManager::deleteMeshData(const size_t mesh_id)
   m_mesh_id_mesh_data_map.erase(mesh_id);
 }
 
-template <size_t Dimension>
-MeshData<Dimension>&
-MeshDataManager::getMeshData(const Mesh<Dimension>& mesh)
+template <MeshConcept MeshType>
+MeshData<MeshType>&
+MeshDataManager::getMeshData(const MeshType& mesh)
 {
   if (auto i_mesh_data = m_mesh_id_mesh_data_map.find(mesh.id()); i_mesh_data != m_mesh_id_mesh_data_map.end()) {
-    return dynamic_cast<MeshData<Dimension>&>(*i_mesh_data->second);
+    return dynamic_cast<MeshData<MeshType>&>(*i_mesh_data->second);
   } else {
     // **cannot** use make_shared since MeshData constructor is **private**
-    std::shared_ptr<MeshData<Dimension>> mesh_data{new MeshData<Dimension>(mesh)};
+    std::shared_ptr<MeshData<MeshType>> mesh_data{new MeshData<MeshType>(mesh)};
 
     m_mesh_id_mesh_data_map[mesh.id()] = mesh_data;
     return *mesh_data;
   }
 }
 
-template MeshData<1>& MeshDataManager::getMeshData(const Mesh<1>&);
-template MeshData<2>& MeshDataManager::getMeshData(const Mesh<2>&);
-template MeshData<3>& MeshDataManager::getMeshData(const Mesh<3>&);
+template MeshData<Mesh<1>>& MeshDataManager::getMeshData(const Mesh<1>&);
+template MeshData<Mesh<2>>& MeshDataManager::getMeshData(const Mesh<2>&);
+template MeshData<Mesh<3>>& MeshDataManager::getMeshData(const Mesh<3>&);
diff --git a/src/mesh/MeshDataManager.hpp b/src/mesh/MeshDataManager.hpp
index 0d2bf964b..b47756c01 100644
--- a/src/mesh/MeshDataManager.hpp
+++ b/src/mesh/MeshDataManager.hpp
@@ -2,6 +2,7 @@
 #define MESH_DATA_MANAGER_HPP
 
 #include <mesh/IMeshData.hpp>
+#include <mesh/MeshTraits.hpp>
 #include <utils/PugsAssert.hpp>
 #include <utils/PugsMacros.hpp>
 
@@ -11,7 +12,7 @@
 template <size_t Dimension>
 class Mesh;
 
-template <size_t Dimension>
+template <MeshConcept MeshType>
 class MeshData;
 
 class MeshDataManager
@@ -41,8 +42,8 @@ class MeshDataManager
 
   void deleteMeshData(const size_t mesh_id);
 
-  template <size_t Dimension>
-  MeshData<Dimension>& getMeshData(const Mesh<Dimension>&);
+  template <MeshConcept MeshType>
+  MeshData<MeshType>& getMeshData(const MeshType&);
 };
 
 #endif   // MESH_DATA_MANAGER_HPP
diff --git a/src/scheme/AcousticSolver.cpp b/src/scheme/AcousticSolver.cpp
index 9f00bf12b..43476df57 100644
--- a/src/scheme/AcousticSolver.cpp
+++ b/src/scheme/AcousticSolver.cpp
@@ -57,7 +57,7 @@ class AcousticSolverHandler::AcousticSolver final : public AcousticSolverHandler
   using Rdxd = TinyMatrix<Dimension>;
   using Rd   = TinyVector<Dimension>;
 
-  using MeshDataType = MeshData<Dimension>;
+  using MeshDataType = MeshData<MeshType>;
 
   using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
   using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;
@@ -525,7 +525,7 @@ AcousticSolverHandler::AcousticSolver<MeshType>::_applyPressureBC(const Boundary
       [&](auto&& bc) {
         using T = std::decay_t<decltype(bc)>;
         if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
-          MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(mesh);
+          MeshData<MeshType>& mesh_data = MeshDataManager::instance().getMeshData(mesh);
           if constexpr (Dimension == 1) {
             const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();
 
diff --git a/src/scheme/DiscreteFunctionInterpoler.cpp b/src/scheme/DiscreteFunctionInterpoler.cpp
index 6c5228a91..f8e5a58ca 100644
--- a/src/scheme/DiscreteFunctionInterpoler.cpp
+++ b/src/scheme/DiscreteFunctionInterpoler.cpp
@@ -17,9 +17,8 @@ DiscreteFunctionInterpoler::_interpolateOnZoneList() const
 
   constexpr size_t Dimension = MeshType::Dimension;
 
-  std::shared_ptr p_mesh  = m_mesh->get<MeshType>();
-  using MeshDataType      = MeshData<Dimension>;
-  MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
+  std::shared_ptr p_mesh        = m_mesh->get<MeshType>();
+  MeshData<MeshType>& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
 
   CellValue<bool> is_in_zone{p_mesh->connectivity()};
   is_in_zone.fill(false);
@@ -76,9 +75,8 @@ DiscreteFunctionInterpoler::_interpolateGlobally() const
   Assert(m_zone_list.size() == 0, "invalid call when zones are defined");
   constexpr size_t Dimension = MeshType::Dimension;
 
-  std::shared_ptr p_mesh  = m_mesh->get<MeshType>();
-  using MeshDataType      = MeshData<Dimension>;
-  MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
+  std::shared_ptr p_mesh        = m_mesh->get<MeshType>();
+  MeshData<MeshType>& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
 
   if constexpr (std::is_same_v<DataType, ValueType>) {
     return DiscreteFunctionP0<ValueType>(m_mesh,
diff --git a/src/scheme/DiscreteFunctionVectorInterpoler.cpp b/src/scheme/DiscreteFunctionVectorInterpoler.cpp
index 547709c45..fffead074 100644
--- a/src/scheme/DiscreteFunctionVectorInterpoler.cpp
+++ b/src/scheme/DiscreteFunctionVectorInterpoler.cpp
@@ -14,9 +14,8 @@ DiscreteFunctionVectorInterpoler::_interpolateOnZoneList() const
   Assert(m_zone_list.size() > 0, "no zone list provided");
   constexpr size_t Dimension = MeshType::Dimension;
 
-  std::shared_ptr p_mesh  = m_mesh->get<MeshType>();
-  using MeshDataType      = MeshData<Dimension>;
-  MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
+  std::shared_ptr p_mesh        = m_mesh->get<MeshType>();
+  MeshData<MeshType>& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
 
   CellValue<bool> is_in_zone{p_mesh->connectivity()};
   is_in_zone.fill(false);
@@ -75,8 +74,7 @@ DiscreteFunctionVectorInterpoler::_interpolateGlobally() const
 
   std::shared_ptr p_mesh = m_mesh->get<MeshType>();
 
-  using MeshDataType      = MeshData<Dimension>;
-  MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
+  MeshData<MeshType>& mesh_data = MeshDataManager::instance().getMeshData(*p_mesh);
 
   return DiscreteFunctionP0Vector<DataType>(m_mesh,
                                             InterpolateItemArray<DataType(TinyVector<Dimension>)>::template interpolate<
diff --git a/src/scheme/FluxingAdvectionSolver.cpp b/src/scheme/FluxingAdvectionSolver.cpp
index e1bf11a8b..509156b1e 100644
--- a/src/scheme/FluxingAdvectionSolver.cpp
+++ b/src/scheme/FluxingAdvectionSolver.cpp
@@ -33,8 +33,7 @@ class FluxingAdvectionSolver
  private:
   static constexpr size_t Dimension = MeshType::Dimension;
 
-  using Rd           = TinyVector<Dimension>;
-  using MeshDataType = MeshData<Dimension>;
+  using Rd = TinyVector<Dimension>;
 
   const std::shared_ptr<const MeshType> m_old_mesh;
   const std::shared_ptr<const MeshType> m_new_mesh;
@@ -421,7 +420,7 @@ FluxingAdvectionSolver<MeshType>::_computeCycleNumber(FaceValue<double> fluxing_
       }
     });
 
-  MeshData<Dimension>& mesh_data   = MeshDataManager::instance().getMeshData(*m_old_mesh);
+  MeshData<MeshType>& mesh_data    = MeshDataManager::instance().getMeshData(*m_old_mesh);
   const CellValue<const double> Vj = mesh_data.Vj();
   CellValue<size_t> ratio(m_old_mesh->connectivity());
 
@@ -641,7 +640,7 @@ FluxingAdvectionSolver<MeshType>::_remapAllQuantities()
   const auto cell_to_face_matrix              = m_new_mesh->connectivity().cellToFaceMatrix();
   const auto face_local_number_in_their_cells = m_new_mesh->connectivity().faceLocalNumbersInTheirCells();
 
-  MeshData<Dimension>& old_mesh_data = MeshDataManager::instance().getMeshData(*m_old_mesh);
+  MeshData<MeshType>& old_mesh_data = MeshDataManager::instance().getMeshData(*m_old_mesh);
 
   const CellValue<const double> old_Vj = old_mesh_data.Vj();
   const CellValue<double> step_Vj      = copy(old_Vj);
diff --git a/src/scheme/HyperelasticSolver.cpp b/src/scheme/HyperelasticSolver.cpp
index 4fc391885..5309385ec 100644
--- a/src/scheme/HyperelasticSolver.cpp
+++ b/src/scheme/HyperelasticSolver.cpp
@@ -57,7 +57,7 @@ class HyperelasticSolverHandler::HyperelasticSolver final : public HyperelasticS
   using Rdxd = TinyMatrix<Dimension>;
   using Rd   = TinyVector<Dimension>;
 
-  using MeshDataType = MeshData<Dimension>;
+  using MeshDataType = MeshData<MeshType>;
 
   using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
   using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;
@@ -560,7 +560,7 @@ HyperelasticSolverHandler::HyperelasticSolver<MeshType>::_applyPressureBC(const
       [&](auto&& bc) {
         using T = std::decay_t<decltype(bc)>;
         if constexpr (std::is_same_v<PressureBoundaryCondition, T>) {
-          MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(mesh);
+          MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(mesh);
           if constexpr (Dimension == 1) {
             const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();
 
@@ -625,7 +625,7 @@ HyperelasticSolverHandler::HyperelasticSolver<MeshType>::_applyNormalStressBC(co
       [&](auto&& bc) {
         using T = std::decay_t<decltype(bc)>;
         if constexpr (std::is_same_v<NormalStressBoundaryCondition, T>) {
-          MeshData<Dimension>& mesh_data = MeshDataManager::instance().getMeshData(mesh);
+          MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(mesh);
           if constexpr (Dimension == 1) {
             const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();
 
-- 
GitLab