From a61e7780802bbb2e66769ce2028490c09e417e70 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Tue, 22 Oct 2024 11:16:38 +0200
Subject: [PATCH] Generalize managed stencil types and begin new stencil
 building

---
 src/mesh/StencilBuilder.cpp | 37 ++++++++++++++++++---------
 src/mesh/StencilBuilder.hpp | 47 +++++++++++++++++++++++++++++-----
 src/mesh/StencilManager.cpp | 50 +++++++++++++++++++++++++++++++------
 src/mesh/StencilManager.hpp |  7 ++++++
 4 files changed, 115 insertions(+), 26 deletions(-)

diff --git a/src/mesh/StencilBuilder.cpp b/src/mesh/StencilBuilder.cpp
index 8af6a76ae..d70ea9534 100644
--- a/src/mesh/StencilBuilder.cpp
+++ b/src/mesh/StencilBuilder.cpp
@@ -110,9 +110,9 @@ StencilBuilder::_getColumnIndices(const ConnectivityType& connectivity, const Ar
 
 template <typename ConnectivityType>
 CellToCellStencilArray
-StencilBuilder::_build(const ConnectivityType& connectivity,
-                       size_t number_of_layers,
-                       const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const
+StencilBuilder::_buildC2C(const ConnectivityType& connectivity,
+                          size_t number_of_layers,
+                          const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const
 {
   if ((parallel::size() > 1) and (number_of_layers > GlobalVariableManager::instance().getNumberOfGhostLayers())) {
     std::ostringstream error_msg;
@@ -122,6 +122,7 @@ StencilBuilder::_build(const ConnectivityType& connectivity,
     error_msg << "Increase the number of ghost layers (using the '--number-of-ghost-layers' option).";
     throw NormalError(error_msg.str());
   }
+
   if (number_of_layers > 2) {
     throw NotImplementedError("number of layers too large");
   }
@@ -340,25 +341,37 @@ StencilBuilder::_build(const ConnectivityType& connectivity,
 }
 
 CellToCellStencilArray
-StencilBuilder::build(const IConnectivity& connectivity,
-                      const StencilDescriptor& stencil_descriptor,
-                      const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const
+StencilBuilder::buildC2C(const IConnectivity& connectivity,
+                         const StencilDescriptor& stencil_descriptor,
+                         const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const
 {
   switch (connectivity.dimension()) {
   case 1: {
-    return StencilBuilder::_build(dynamic_cast<const Connectivity<1>&>(connectivity),
-                                  stencil_descriptor.numberOfLayers(), symmetry_boundary_descriptor_list);
+    return StencilBuilder::_buildC2C(dynamic_cast<const Connectivity<1>&>(connectivity),
+                                     stencil_descriptor.numberOfLayers(), symmetry_boundary_descriptor_list);
   }
   case 2: {
-    return StencilBuilder::_build(dynamic_cast<const Connectivity<2>&>(connectivity),
-                                  stencil_descriptor.numberOfLayers(), symmetry_boundary_descriptor_list);
+    return StencilBuilder::_buildC2C(dynamic_cast<const Connectivity<2>&>(connectivity),
+                                     stencil_descriptor.numberOfLayers(), symmetry_boundary_descriptor_list);
   }
   case 3: {
-    return StencilBuilder::_build(dynamic_cast<const Connectivity<3>&>(connectivity),
-                                  stencil_descriptor.numberOfLayers(), symmetry_boundary_descriptor_list);
+    return StencilBuilder::_buildC2C(dynamic_cast<const Connectivity<3>&>(connectivity),
+                                     stencil_descriptor.numberOfLayers(), symmetry_boundary_descriptor_list);
   }
   default: {
     throw UnexpectedError("invalid connectivity dimension");
   }
   }
 }
+
+CellToFaceStencilArray
+StencilBuilder::buildC2F(const IConnectivity&, const StencilDescriptor&, const BoundaryDescriptorList&) const
+{
+  throw NotImplementedError("cell to face stencil");
+}
+
+NodeToCellStencilArray
+StencilBuilder::buildN2C(const IConnectivity&, const StencilDescriptor&, const BoundaryDescriptorList&) const
+{
+  throw NotImplementedError("node to cell stencil");
+}
diff --git a/src/mesh/StencilBuilder.hpp b/src/mesh/StencilBuilder.hpp
index 8b0cc09d1..b930ca4cc 100644
--- a/src/mesh/StencilBuilder.hpp
+++ b/src/mesh/StencilBuilder.hpp
@@ -23,14 +23,49 @@ class StencilBuilder
                                           const Array<const uint32_t>& row_map) const;
 
   template <typename ConnectivityType>
-  CellToCellStencilArray _build(const ConnectivityType& connectivity,
-                                size_t number_of_layers,
-                                const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
+  CellToCellStencilArray _buildC2C(const ConnectivityType& connectivity,
+                                   size_t number_of_layers,
+                                   const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
+
+  template <typename ConnectivityType>
+  CellToFaceStencilArray _buildC2F(const ConnectivityType& connectivity,
+                                   size_t number_of_layers,
+                                   const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
+
+  template <typename ConnectivityType>
+  NodeToCellStencilArray _buildN2C(const ConnectivityType& connectivity,
+                                   size_t number_of_layers,
+                                   const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
+
+  CellToCellStencilArray buildC2C(const IConnectivity& connectivity,
+                                  const StencilDescriptor& stencil_descriptor,
+                                  const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
+
+  CellToFaceStencilArray buildC2F(const IConnectivity& connectivity,
+                                  const StencilDescriptor& stencil_descriptor,
+                                  const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
+
+  NodeToCellStencilArray buildN2C(const IConnectivity& connectivity,
+                                  const StencilDescriptor& stencil_descriptor,
+                                  const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
 
   friend class StencilManager;
-  CellToCellStencilArray build(const IConnectivity& connectivity,
-                               const StencilDescriptor& stencil_descriptor,
-                               const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const;
+  template <ItemType source_item_type, ItemType target_item_type>
+  StencilArray<source_item_type, target_item_type>
+  build(const IConnectivity& connectivity,
+        const StencilDescriptor& stencil_descriptor,
+        const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const
+  {
+    if constexpr ((source_item_type == ItemType::cell) and (target_item_type == ItemType::cell)) {
+      return buildC2C(connectivity, stencil_descriptor, symmetry_boundary_descriptor_list);
+    } else if constexpr ((source_item_type == ItemType::cell) and (target_item_type == ItemType::face)) {
+      return buildC2F(connectivity, stencil_descriptor, symmetry_boundary_descriptor_list);
+    } else if constexpr ((source_item_type == ItemType::node) and (target_item_type == ItemType::cell)) {
+      return buildN2C(connectivity, stencil_descriptor, symmetry_boundary_descriptor_list);
+    } else {
+      static_assert(is_false_item_type_v<source_item_type>, "invalid stencil type");
+    }
+  }
 
  public:
   StencilBuilder()                      = default;
diff --git a/src/mesh/StencilManager.cpp b/src/mesh/StencilManager.cpp
index 6de72c08a..ef1d9ef1f 100644
--- a/src/mesh/StencilManager.cpp
+++ b/src/mesh/StencilManager.cpp
@@ -38,20 +38,54 @@ StencilManager::destroy()
   m_instance = nullptr;
 }
 
+template <ItemType source_item_type, ItemType target_item_type>
+const StencilArray<source_item_type, target_item_type>&
+StencilManager::_getStencilArray(
+  const IConnectivity& connectivity,
+  const StencilDescriptor& stencil_descriptor,
+  const BoundaryDescriptorList& symmetry_boundary_descriptor_list,
+  StoredStencilTMap<source_item_type, target_item_type>& stored_source_to_target_stencil_map)
+{
+  if (not stored_source_to_target_stencil_map.contains(
+        Key{connectivity.id(), stencil_descriptor, symmetry_boundary_descriptor_list})) {
+    stored_source_to_target_stencil_map[Key{connectivity.id(), stencil_descriptor, symmetry_boundary_descriptor_list}] =
+      std::make_shared<StencilArray<source_item_type, target_item_type>>(
+        StencilBuilder{}.template build<source_item_type, target_item_type>(connectivity, stencil_descriptor,
+                                                                            symmetry_boundary_descriptor_list));
+  }
+
+  return *stored_source_to_target_stencil_map.at(
+    Key{connectivity.id(), stencil_descriptor, symmetry_boundary_descriptor_list});
+}
+
 const CellToCellStencilArray&
 StencilManager::getCellToCellStencilArray(const IConnectivity& connectivity,
                                           const StencilDescriptor& stencil_descriptor,
                                           const BoundaryDescriptorList& symmetry_boundary_descriptor_list)
 {
-  if (not m_stored_cell_to_cell_stencil_map.contains(
-        Key{connectivity.id(), stencil_descriptor, symmetry_boundary_descriptor_list})) {
-    m_stored_cell_to_cell_stencil_map[Key{connectivity.id(), stencil_descriptor, symmetry_boundary_descriptor_list}] =
-      std::make_shared<CellToCellStencilArray>(
-        StencilBuilder{}.build(connectivity, stencil_descriptor, symmetry_boundary_descriptor_list));
-  }
+  return this->_getStencilArray<ItemType::cell, ItemType::cell>(connectivity, stencil_descriptor,
+                                                                symmetry_boundary_descriptor_list,
+                                                                m_stored_cell_to_cell_stencil_map);
+}
 
-  return *m_stored_cell_to_cell_stencil_map.at(
-    Key{connectivity.id(), stencil_descriptor, symmetry_boundary_descriptor_list});
+const CellToFaceStencilArray&
+StencilManager::getCellToFaceStencilArray(const IConnectivity& connectivity,
+                                          const StencilDescriptor& stencil_descriptor,
+                                          const BoundaryDescriptorList& symmetry_boundary_descriptor_list)
+{
+  return this->_getStencilArray<ItemType::cell, ItemType::face>(connectivity, stencil_descriptor,
+                                                                symmetry_boundary_descriptor_list,
+                                                                m_stored_cell_to_face_stencil_map);
+}
+
+const NodeToCellStencilArray&
+StencilManager::getNodeToCellStencilArray(const IConnectivity& connectivity,
+                                          const StencilDescriptor& stencil_descriptor,
+                                          const BoundaryDescriptorList& symmetry_boundary_descriptor_list)
+{
+  return this->_getStencilArray<ItemType::node, ItemType::cell>(connectivity, stencil_descriptor,
+                                                                symmetry_boundary_descriptor_list,
+                                                                m_stored_node_to_cell_stencil_map);
 }
 
 void
diff --git a/src/mesh/StencilManager.hpp b/src/mesh/StencilManager.hpp
index 6deb86536..c68e7fe06 100644
--- a/src/mesh/StencilManager.hpp
+++ b/src/mesh/StencilManager.hpp
@@ -73,6 +73,13 @@ class StencilManager
   StoredStencilTMap<ItemType::cell, ItemType::face> m_stored_cell_to_face_stencil_map;
   StoredStencilTMap<ItemType::node, ItemType::cell> m_stored_node_to_cell_stencil_map;
 
+  template <ItemType source_item_type, ItemType target_item_type>
+  const StencilArray<source_item_type, target_item_type>& _getStencilArray(
+    const IConnectivity& connectivity,
+    const StencilDescriptor& stencil_descriptor,
+    const BoundaryDescriptorList& symmetry_boundary_descriptor_list,
+    StoredStencilTMap<source_item_type, target_item_type>& stored_source_to_target_stencil_map);
+
  public:
   static void create();
   static void destroy();
-- 
GitLab