diff --git a/src/mesh/IMeshData.hpp b/src/mesh/IMeshData.hpp deleted file mode 100644 index 8cad9586d201de6b383bbbfaf94f09e5797078e5..0000000000000000000000000000000000000000 --- a/src/mesh/IMeshData.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#ifndef I_MESH_DATA_HPP -#define I_MESH_DATA_HPP - -class IMeshData -{ - public: - virtual ~IMeshData() = default; -}; - -#endif // I_MESH_DATA_HPP diff --git a/src/mesh/MeshData.hpp b/src/mesh/MeshData.hpp index d5eba5f5df0e29f388b3781bc8c40a84729ae46d..4fb87bc5ad4aee87c7dd92e9232ec4508145a14f 100644 --- a/src/mesh/MeshData.hpp +++ b/src/mesh/MeshData.hpp @@ -2,7 +2,6 @@ #define MESH_DATA_HPP #include <algebra/TinyVector.hpp> -#include <mesh/IMeshData.hpp> #include <mesh/ItemValue.hpp> #include <mesh/MeshTraits.hpp> #include <mesh/SubItemValuePerItem.hpp> @@ -16,7 +15,7 @@ template <MeshConcept MeshType> class MeshData; template <size_t Dimension> -class MeshData<Mesh<Dimension>> : public IMeshData +class MeshData<Mesh<Dimension>> { public: using MeshType = Mesh<Dimension>; diff --git a/src/mesh/MeshDataManager.cpp b/src/mesh/MeshDataManager.cpp index 4eb8f5677c6fae70fb3df70ade91e52a538691c1..e95a717d8f90b9aaba86f9c3702f15a06a0df54a 100644 --- a/src/mesh/MeshDataManager.cpp +++ b/src/mesh/MeshDataManager.cpp @@ -4,6 +4,7 @@ #include <mesh/Mesh.hpp> #include <mesh/MeshData.hpp> #include <mesh/MeshDataManager.hpp> +#include <mesh/MeshDataVariant.hpp> #include <utils/Exceptions.hpp> #include <sstream> @@ -45,12 +46,13 @@ MeshData<MeshType>& MeshDataManager::getMeshData(const MeshType& mesh) { if (auto i_mesh_data = m_mesh_id_mesh_data_map.find(mesh.id()); i_mesh_data != m_mesh_id_mesh_data_map.end()) { - return dynamic_cast<MeshData<MeshType>&>(*i_mesh_data->second); + const auto& mesh_data_v = *i_mesh_data->second; + return *mesh_data_v.template get<MeshType>(); } else { // **cannot** use make_shared since MeshData constructor is **private** std::shared_ptr<MeshData<MeshType>> mesh_data{new MeshData<MeshType>(mesh)}; - m_mesh_id_mesh_data_map[mesh.id()] = mesh_data; + m_mesh_id_mesh_data_map[mesh.id()] = std::make_shared<MeshDataVariant>(mesh_data); return *mesh_data; } } diff --git a/src/mesh/MeshDataManager.hpp b/src/mesh/MeshDataManager.hpp index b47756c01afe4a4ead97fba7e9f8543018989b31..4ec3bd7fe2064c726f0a86c53170f6d4e1f2dfff 100644 --- a/src/mesh/MeshDataManager.hpp +++ b/src/mesh/MeshDataManager.hpp @@ -1,7 +1,6 @@ #ifndef MESH_DATA_MANAGER_HPP #define MESH_DATA_MANAGER_HPP -#include <mesh/IMeshData.hpp> #include <mesh/MeshTraits.hpp> #include <utils/PugsAssert.hpp> #include <utils/PugsMacros.hpp> @@ -9,8 +8,7 @@ #include <memory> #include <unordered_map> -template <size_t Dimension> -class Mesh; +class MeshDataVariant; template <MeshConcept MeshType> class MeshData; @@ -18,7 +16,7 @@ class MeshData; class MeshDataManager { private: - std::unordered_map<size_t, std::shared_ptr<IMeshData>> m_mesh_id_mesh_data_map; + std::unordered_map<size_t, std::shared_ptr<MeshDataVariant>> m_mesh_id_mesh_data_map; static MeshDataManager* m_instance; diff --git a/src/mesh/MeshDataVariant.hpp b/src/mesh/MeshDataVariant.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f5e77325df2900b5fe0969f4d71347980470217 --- /dev/null +++ b/src/mesh/MeshDataVariant.hpp @@ -0,0 +1,62 @@ +#ifndef MESH_DATA_VARIANT_HPP +#define MESH_DATA_VARIANT_HPP + +#include <mesh/MeshData.hpp> +#include <mesh/MeshTraits.hpp> +#include <utils/Demangle.hpp> +#include <utils/Exceptions.hpp> +#include <utils/PugsMacros.hpp> + +#include <rang.hpp> + +#include <memory> +#include <sstream> +#include <variant> + +template <size_t Dimension> +class Mesh; + +class MeshDataVariant +{ + private: + using Variant = std::variant<std::shared_ptr<MeshData<Mesh<1>>>, // + std::shared_ptr<MeshData<Mesh<2>>>, // + std::shared_ptr<MeshData<Mesh<3>>>>; + + Variant m_p_mesh_data_variant; + + public: + template <MeshConcept MeshType> + PUGS_INLINE std::shared_ptr<MeshData<MeshType>> + get() const + { + if (not std::holds_alternative<std::shared_ptr<MeshData<MeshType>>>(this->m_p_mesh_data_variant)) { + std::ostringstream error_msg; + error_msg << "invalid mesh type type\n"; + error_msg << "- required " << rang::fgB::red << demangle<MeshData<MeshType>>() << rang::fg::reset << '\n'; + error_msg << "- contains " << rang::fgB::yellow + << std::visit( + [](auto&& p_mesh_data) -> std::string { + using FoundMeshDataType = typename std::decay_t<decltype(p_mesh_data)>::element_type; + return demangle<FoundMeshDataType>(); + }, + this->m_p_mesh_data_variant) + << rang::fg::reset; + throw NormalError(error_msg.str()); + } + return std::get<std::shared_ptr<MeshData<MeshType>>>(m_p_mesh_data_variant); + } + + MeshDataVariant() = delete; + + template <MeshConcept MeshType> + MeshDataVariant(const std::shared_ptr<MeshData<MeshType>>& p_mesh_data) : m_p_mesh_data_variant{p_mesh_data} + {} + + MeshDataVariant(const MeshDataVariant&) = default; + MeshDataVariant(MeshDataVariant&&) = default; + + ~MeshDataVariant() = default; +}; + +#endif // MESH_DATA_VARIANT_HPP diff --git a/tests/test_MeshVariant.cpp b/tests/test_MeshVariant.cpp index f03f70273b5e03709a49e18d58ce7c8536a563b1..212bcecd003f1842dd6e769451172b70d0871d36 100644 --- a/tests/test_MeshVariant.cpp +++ b/tests/test_MeshVariant.cpp @@ -9,7 +9,7 @@ TEST_CASE("MeshVariant", "[mesh]") { - SECTION("1D") + SECTION("Polygonal 1D") { auto mesh_v = MeshDataBaseForTests::get().unordered1DMesh(); auto mesh = mesh_v->get<Mesh<1>>(); @@ -27,6 +27,8 @@ TEST_CASE("MeshVariant", "[mesh]") REQUIRE(mesh->numberOfEdges() == mesh_v->numberOfEdges()); REQUIRE(mesh->numberOfNodes() == mesh_v->numberOfNodes()); + REQUIRE(mesh->connectivity().id() == mesh_v->connectivity().id()); + { std::ostringstream os_v; os_v << *mesh_v; @@ -38,7 +40,7 @@ TEST_CASE("MeshVariant", "[mesh]") } } - SECTION("2D") + SECTION("Polygonal 2D") { auto mesh_v = MeshDataBaseForTests::get().hybrid2DMesh(); auto mesh = mesh_v->get<Mesh<2>>(); @@ -56,6 +58,8 @@ TEST_CASE("MeshVariant", "[mesh]") REQUIRE(mesh->numberOfEdges() == mesh_v->numberOfEdges()); REQUIRE(mesh->numberOfNodes() == mesh_v->numberOfNodes()); + REQUIRE(mesh->connectivity().id() == mesh_v->connectivity().id()); + { std::ostringstream os_v; os_v << *mesh_v; @@ -67,7 +71,7 @@ TEST_CASE("MeshVariant", "[mesh]") } } - SECTION("3D") + SECTION("Polygonal 3D") { auto mesh_v = MeshDataBaseForTests::get().hybrid3DMesh(); auto mesh = mesh_v->get<Mesh<3>>(); @@ -85,6 +89,8 @@ TEST_CASE("MeshVariant", "[mesh]") REQUIRE(mesh->numberOfEdges() == mesh_v->numberOfEdges()); REQUIRE(mesh->numberOfNodes() == mesh_v->numberOfNodes()); + REQUIRE(mesh->connectivity().id() == mesh_v->connectivity().id()); + { std::ostringstream os_v; os_v << *mesh_v;