From f4418b15e1cc092622c1d678e90c3010af074ccd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Thu, 7 Mar 2024 17:26:24 +0100
Subject: [PATCH] Replace IMeshData by MeshDataVariant

---
 src/mesh/IMeshData.hpp       | 10 ------
 src/mesh/MeshData.hpp        |  3 +-
 src/mesh/MeshDataManager.cpp |  6 ++--
 src/mesh/MeshDataManager.hpp |  6 ++--
 src/mesh/MeshDataVariant.hpp | 62 ++++++++++++++++++++++++++++++++++++
 tests/test_MeshVariant.cpp   | 12 +++++--
 6 files changed, 78 insertions(+), 21 deletions(-)
 delete mode 100644 src/mesh/IMeshData.hpp
 create mode 100644 src/mesh/MeshDataVariant.hpp

diff --git a/src/mesh/IMeshData.hpp b/src/mesh/IMeshData.hpp
deleted file mode 100644
index 8cad9586d..000000000
--- a/src/mesh/IMeshData.hpp
+++ /dev/null
@@ -1,10 +0,0 @@
-#ifndef I_MESH_DATA_HPP
-#define I_MESH_DATA_HPP
-
-class IMeshData
-{
- public:
-  virtual ~IMeshData() = default;
-};
-
-#endif   // I_MESH_DATA_HPP
diff --git a/src/mesh/MeshData.hpp b/src/mesh/MeshData.hpp
index d5eba5f5d..4fb87bc5a 100644
--- a/src/mesh/MeshData.hpp
+++ b/src/mesh/MeshData.hpp
@@ -2,7 +2,6 @@
 #define MESH_DATA_HPP
 
 #include <algebra/TinyVector.hpp>
-#include <mesh/IMeshData.hpp>
 #include <mesh/ItemValue.hpp>
 #include <mesh/MeshTraits.hpp>
 #include <mesh/SubItemValuePerItem.hpp>
@@ -16,7 +15,7 @@ template <MeshConcept MeshType>
 class MeshData;
 
 template <size_t Dimension>
-class MeshData<Mesh<Dimension>> : public IMeshData
+class MeshData<Mesh<Dimension>>
 {
  public:
   using MeshType = Mesh<Dimension>;
diff --git a/src/mesh/MeshDataManager.cpp b/src/mesh/MeshDataManager.cpp
index 4eb8f5677..e95a717d8 100644
--- a/src/mesh/MeshDataManager.cpp
+++ b/src/mesh/MeshDataManager.cpp
@@ -4,6 +4,7 @@
 #include <mesh/Mesh.hpp>
 #include <mesh/MeshData.hpp>
 #include <mesh/MeshDataManager.hpp>
+#include <mesh/MeshDataVariant.hpp>
 #include <utils/Exceptions.hpp>
 
 #include <sstream>
@@ -45,12 +46,13 @@ MeshData<MeshType>&
 MeshDataManager::getMeshData(const MeshType& mesh)
 {
   if (auto i_mesh_data = m_mesh_id_mesh_data_map.find(mesh.id()); i_mesh_data != m_mesh_id_mesh_data_map.end()) {
-    return dynamic_cast<MeshData<MeshType>&>(*i_mesh_data->second);
+    const auto& mesh_data_v = *i_mesh_data->second;
+    return *mesh_data_v.template get<MeshType>();
   } else {
     // **cannot** use make_shared since MeshData constructor is **private**
     std::shared_ptr<MeshData<MeshType>> mesh_data{new MeshData<MeshType>(mesh)};
 
-    m_mesh_id_mesh_data_map[mesh.id()] = mesh_data;
+    m_mesh_id_mesh_data_map[mesh.id()] = std::make_shared<MeshDataVariant>(mesh_data);
     return *mesh_data;
   }
 }
diff --git a/src/mesh/MeshDataManager.hpp b/src/mesh/MeshDataManager.hpp
index b47756c01..4ec3bd7fe 100644
--- a/src/mesh/MeshDataManager.hpp
+++ b/src/mesh/MeshDataManager.hpp
@@ -1,7 +1,6 @@
 #ifndef MESH_DATA_MANAGER_HPP
 #define MESH_DATA_MANAGER_HPP
 
-#include <mesh/IMeshData.hpp>
 #include <mesh/MeshTraits.hpp>
 #include <utils/PugsAssert.hpp>
 #include <utils/PugsMacros.hpp>
@@ -9,8 +8,7 @@
 #include <memory>
 #include <unordered_map>
 
-template <size_t Dimension>
-class Mesh;
+class MeshDataVariant;
 
 template <MeshConcept MeshType>
 class MeshData;
@@ -18,7 +16,7 @@ class MeshData;
 class MeshDataManager
 {
  private:
-  std::unordered_map<size_t, std::shared_ptr<IMeshData>> m_mesh_id_mesh_data_map;
+  std::unordered_map<size_t, std::shared_ptr<MeshDataVariant>> m_mesh_id_mesh_data_map;
 
   static MeshDataManager* m_instance;
 
diff --git a/src/mesh/MeshDataVariant.hpp b/src/mesh/MeshDataVariant.hpp
new file mode 100644
index 000000000..1f5e77325
--- /dev/null
+++ b/src/mesh/MeshDataVariant.hpp
@@ -0,0 +1,62 @@
+#ifndef MESH_DATA_VARIANT_HPP
+#define MESH_DATA_VARIANT_HPP
+
+#include <mesh/MeshData.hpp>
+#include <mesh/MeshTraits.hpp>
+#include <utils/Demangle.hpp>
+#include <utils/Exceptions.hpp>
+#include <utils/PugsMacros.hpp>
+
+#include <rang.hpp>
+
+#include <memory>
+#include <sstream>
+#include <variant>
+
+template <size_t Dimension>
+class Mesh;
+
+class MeshDataVariant
+{
+ private:
+  using Variant = std::variant<std::shared_ptr<MeshData<Mesh<1>>>,   //
+                               std::shared_ptr<MeshData<Mesh<2>>>,   //
+                               std::shared_ptr<MeshData<Mesh<3>>>>;
+
+  Variant m_p_mesh_data_variant;
+
+ public:
+  template <MeshConcept MeshType>
+  PUGS_INLINE std::shared_ptr<MeshData<MeshType>>
+  get() const
+  {
+    if (not std::holds_alternative<std::shared_ptr<MeshData<MeshType>>>(this->m_p_mesh_data_variant)) {
+      std::ostringstream error_msg;
+      error_msg << "invalid mesh type type\n";
+      error_msg << "- required " << rang::fgB::red << demangle<MeshData<MeshType>>() << rang::fg::reset << '\n';
+      error_msg << "- contains " << rang::fgB::yellow
+                << std::visit(
+                     [](auto&& p_mesh_data) -> std::string {
+                       using FoundMeshDataType = typename std::decay_t<decltype(p_mesh_data)>::element_type;
+                       return demangle<FoundMeshDataType>();
+                     },
+                     this->m_p_mesh_data_variant)
+                << rang::fg::reset;
+      throw NormalError(error_msg.str());
+    }
+    return std::get<std::shared_ptr<MeshData<MeshType>>>(m_p_mesh_data_variant);
+  }
+
+  MeshDataVariant() = delete;
+
+  template <MeshConcept MeshType>
+  MeshDataVariant(const std::shared_ptr<MeshData<MeshType>>& p_mesh_data) : m_p_mesh_data_variant{p_mesh_data}
+  {}
+
+  MeshDataVariant(const MeshDataVariant&) = default;
+  MeshDataVariant(MeshDataVariant&&)      = default;
+
+  ~MeshDataVariant() = default;
+};
+
+#endif   // MESH_DATA_VARIANT_HPP
diff --git a/tests/test_MeshVariant.cpp b/tests/test_MeshVariant.cpp
index f03f70273..212bcecd0 100644
--- a/tests/test_MeshVariant.cpp
+++ b/tests/test_MeshVariant.cpp
@@ -9,7 +9,7 @@
 
 TEST_CASE("MeshVariant", "[mesh]")
 {
-  SECTION("1D")
+  SECTION("Polygonal 1D")
   {
     auto mesh_v = MeshDataBaseForTests::get().unordered1DMesh();
     auto mesh   = mesh_v->get<Mesh<1>>();
@@ -27,6 +27,8 @@ TEST_CASE("MeshVariant", "[mesh]")
     REQUIRE(mesh->numberOfEdges() == mesh_v->numberOfEdges());
     REQUIRE(mesh->numberOfNodes() == mesh_v->numberOfNodes());
 
+    REQUIRE(mesh->connectivity().id() == mesh_v->connectivity().id());
+
     {
       std::ostringstream os_v;
       os_v << *mesh_v;
@@ -38,7 +40,7 @@ TEST_CASE("MeshVariant", "[mesh]")
     }
   }
 
-  SECTION("2D")
+  SECTION("Polygonal 2D")
   {
     auto mesh_v = MeshDataBaseForTests::get().hybrid2DMesh();
     auto mesh   = mesh_v->get<Mesh<2>>();
@@ -56,6 +58,8 @@ TEST_CASE("MeshVariant", "[mesh]")
     REQUIRE(mesh->numberOfEdges() == mesh_v->numberOfEdges());
     REQUIRE(mesh->numberOfNodes() == mesh_v->numberOfNodes());
 
+    REQUIRE(mesh->connectivity().id() == mesh_v->connectivity().id());
+
     {
       std::ostringstream os_v;
       os_v << *mesh_v;
@@ -67,7 +71,7 @@ TEST_CASE("MeshVariant", "[mesh]")
     }
   }
 
-  SECTION("3D")
+  SECTION("Polygonal 3D")
   {
     auto mesh_v = MeshDataBaseForTests::get().hybrid3DMesh();
     auto mesh   = mesh_v->get<Mesh<3>>();
@@ -85,6 +89,8 @@ TEST_CASE("MeshVariant", "[mesh]")
     REQUIRE(mesh->numberOfEdges() == mesh_v->numberOfEdges());
     REQUIRE(mesh->numberOfNodes() == mesh_v->numberOfNodes());
 
+    REQUIRE(mesh->connectivity().id() == mesh_v->connectivity().id());
+
     {
       std::ostringstream os_v;
       os_v << *mesh_v;
-- 
GitLab