From 9d3382b53b8093cf4b25890850069a06fa20125f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Wed, 8 Jan 2025 18:59:07 +0100
Subject: [PATCH] Add DirichletVectorBoundaryConditionDescriptor

This kind of boundary condition type is useful for PN solvers for instance.
---
 ...chletVectorBoundaryConditionDescriptor.hpp | 58 +++++++++++++++++++
 src/scheme/IBoundaryConditionDescriptor.hpp   |  1 +
 .../IBoundaryConditionDescriptorHFType.hpp    |  1 +
 .../ReadIBoundaryConditionDescriptor.cpp      | 16 +++++
 .../WriteIBoundaryConditionDescriptor.cpp     | 13 +++++
 ...kpointing_IBoundaryConditionDescriptor.cpp | 33 +++++++++++
 6 files changed, 122 insertions(+)
 create mode 100644 src/scheme/DirichletVectorBoundaryConditionDescriptor.hpp

diff --git a/src/scheme/DirichletVectorBoundaryConditionDescriptor.hpp b/src/scheme/DirichletVectorBoundaryConditionDescriptor.hpp
new file mode 100644
index 000000000..c1f1dcc3f
--- /dev/null
+++ b/src/scheme/DirichletVectorBoundaryConditionDescriptor.hpp
@@ -0,0 +1,58 @@
+#ifndef DIRICHLET_VECTOR_BOUNDARY_CONDITION_DESCRIPTOR_HPP
+#define DIRICHLET_VECTOR_BOUNDARY_CONDITION_DESCRIPTOR_HPP
+
+#include <language/utils/FunctionSymbolId.hpp>
+#include <mesh/IBoundaryDescriptor.hpp>
+#include <scheme/BoundaryConditionDescriptorBase.hpp>
+
+#include <memory>
+#include <vector>
+
+class DirichletVectorBoundaryConditionDescriptor : public BoundaryConditionDescriptorBase
+{
+ private:
+  std::ostream&
+  _write(std::ostream& os) const final
+  {
+    os << "dirichlet_vector(" << m_name << ',' << *m_boundary_descriptor << ")";
+    return os;
+  }
+
+  const std::string m_name;
+
+  const std::vector<FunctionSymbolId> m_rhs_symbol_id_list;
+
+ public:
+  const std::string&
+  name() const
+  {
+    return m_name;
+  }
+
+  const std::vector<FunctionSymbolId>&
+  rhsSymbolIdList() const
+  {
+    return m_rhs_symbol_id_list;
+  }
+
+  Type
+  type() const final
+  {
+    return Type::dirichlet_vector;
+  }
+
+  DirichletVectorBoundaryConditionDescriptor(const std::string_view name,
+                                             std::shared_ptr<const IBoundaryDescriptor> boundary_descriptor,
+                                             const std::vector<FunctionSymbolId>& rhs_symbol_id_list)
+    : BoundaryConditionDescriptorBase{boundary_descriptor}, m_name{name}, m_rhs_symbol_id_list{rhs_symbol_id_list}
+  {
+    ;
+  }
+
+  DirichletVectorBoundaryConditionDescriptor(const DirichletVectorBoundaryConditionDescriptor&) = delete;
+  DirichletVectorBoundaryConditionDescriptor(DirichletVectorBoundaryConditionDescriptor&&)      = delete;
+
+  ~DirichletVectorBoundaryConditionDescriptor() = default;
+};
+
+#endif   // DIRICHLET_VECTOR_BOUNDARY_CONDITION_DESCRIPTOR_HPP
diff --git a/src/scheme/IBoundaryConditionDescriptor.hpp b/src/scheme/IBoundaryConditionDescriptor.hpp
index fc9ff381b..63f698e16 100644
--- a/src/scheme/IBoundaryConditionDescriptor.hpp
+++ b/src/scheme/IBoundaryConditionDescriptor.hpp
@@ -13,6 +13,7 @@ class IBoundaryConditionDescriptor
   {
     axis,
     dirichlet,
+    dirichlet_vector,
     external,
     fourier,
     fixed,
diff --git a/src/utils/checkpointing/IBoundaryConditionDescriptorHFType.hpp b/src/utils/checkpointing/IBoundaryConditionDescriptorHFType.hpp
index 210b52116..4a6e5c92c 100644
--- a/src/utils/checkpointing/IBoundaryConditionDescriptorHFType.hpp
+++ b/src/utils/checkpointing/IBoundaryConditionDescriptorHFType.hpp
@@ -10,6 +10,7 @@ create_enum_i_boundary_condition_descriptor_type()
 {
   return {{"axis", IBoundaryConditionDescriptor::Type::axis},
           {"dirichlet", IBoundaryConditionDescriptor::Type::dirichlet},
+          {"dirichlet_vector", IBoundaryConditionDescriptor::Type::dirichlet_vector},
           {"external", IBoundaryConditionDescriptor::Type::external},
           {"fixed", IBoundaryConditionDescriptor::Type::fixed},
           {"fourier", IBoundaryConditionDescriptor::Type::fourier},
diff --git a/src/utils/checkpointing/ReadIBoundaryConditionDescriptor.cpp b/src/utils/checkpointing/ReadIBoundaryConditionDescriptor.cpp
index 55b43c9f0..1f405b4d8 100644
--- a/src/utils/checkpointing/ReadIBoundaryConditionDescriptor.cpp
+++ b/src/utils/checkpointing/ReadIBoundaryConditionDescriptor.cpp
@@ -5,6 +5,7 @@
 #include <language/utils/EmbeddedData.hpp>
 #include <scheme/AxisBoundaryConditionDescriptor.hpp>
 #include <scheme/DirichletBoundaryConditionDescriptor.hpp>
+#include <scheme/DirichletVectorBoundaryConditionDescriptor.hpp>
 #include <scheme/ExternalBoundaryConditionDescriptor.hpp>
 #include <scheme/FixedBoundaryConditionDescriptor.hpp>
 #include <scheme/FourierBoundaryConditionDescriptor.hpp>
@@ -48,6 +49,21 @@ readIBoundaryConditionDescriptor(const HighFive::Group& iboundaryconditiondecrip
       std::make_shared<const DirichletBoundaryConditionDescriptor>(name, i_boundary_descriptor,
                                                                    *ResumingData::instance().functionSymbolId(rhs_id));
     break;
+  }
+  case IBoundaryConditionDescriptor::Type::dirichlet_vector: {
+    const std::string name = iboundaryconditiondecriptor_group.getAttribute("name").read<std::string>();
+
+    const std::vector function_id_list =
+      iboundaryconditiondecriptor_group.getAttribute("function_id_list").read<std::vector<size_t>>();
+
+    std::vector<FunctionSymbolId> function_symbol_id_list;
+    for (auto function_id : function_id_list) {
+      function_symbol_id_list.push_back(*ResumingData::instance().functionSymbolId(function_id));
+    }
+
+    bc_descriptor = std::make_shared<const DirichletVectorBoundaryConditionDescriptor>(name, i_boundary_descriptor,
+                                                                                       function_symbol_id_list);
+    break;
   }
     // LCOV_EXCL_START
   case IBoundaryConditionDescriptor::Type::external: {
diff --git a/src/utils/checkpointing/WriteIBoundaryConditionDescriptor.cpp b/src/utils/checkpointing/WriteIBoundaryConditionDescriptor.cpp
index 66a867058..8a3c5dd06 100644
--- a/src/utils/checkpointing/WriteIBoundaryConditionDescriptor.cpp
+++ b/src/utils/checkpointing/WriteIBoundaryConditionDescriptor.cpp
@@ -5,6 +5,7 @@
 #include <language/utils/DataHandler.hpp>
 #include <scheme/AxisBoundaryConditionDescriptor.hpp>
 #include <scheme/DirichletBoundaryConditionDescriptor.hpp>
+#include <scheme/DirichletVectorBoundaryConditionDescriptor.hpp>
 #include <scheme/FixedBoundaryConditionDescriptor.hpp>
 #include <scheme/FourierBoundaryConditionDescriptor.hpp>
 #include <scheme/FreeBoundaryConditionDescriptor.hpp>
@@ -47,6 +48,18 @@ writeIBoundaryConditionDescriptor(HighFive::Group& variable_group,
     variable_group.createAttribute("name", dirichlet_bc_descriptor.name());
     variable_group.createAttribute("rhs_function_id", dirichlet_bc_descriptor.rhsSymbolId().id());
     break;
+  }
+  case IBoundaryConditionDescriptor::Type::dirichlet_vector: {
+    const DirichletVectorBoundaryConditionDescriptor& dirichlet_vector_bc_descriptor =
+      dynamic_cast<const DirichletVectorBoundaryConditionDescriptor&>(iboundary_condition_descriptor);
+    variable_group.createAttribute("name", dirichlet_vector_bc_descriptor.name());
+    writeIBoundaryDescriptor(boundary_group, dirichlet_vector_bc_descriptor.boundaryDescriptor());
+    std::vector<size_t> function_id_list;
+    for (auto&& function_symbol_id : dirichlet_vector_bc_descriptor.rhsSymbolIdList()) {
+      function_id_list.push_back(function_symbol_id.id());
+    }
+    variable_group.createAttribute("function_id_list", function_id_list);
+    break;
   }
     // LCOV_EXCL_START
   case IBoundaryConditionDescriptor::Type::external: {
diff --git a/tests/test_checkpointing_IBoundaryConditionDescriptor.cpp b/tests/test_checkpointing_IBoundaryConditionDescriptor.cpp
index 34505db62..ab192f387 100644
--- a/tests/test_checkpointing_IBoundaryConditionDescriptor.cpp
+++ b/tests/test_checkpointing_IBoundaryConditionDescriptor.cpp
@@ -9,6 +9,7 @@
 #include <mesh/NumberedBoundaryDescriptor.hpp>
 #include <scheme/AxisBoundaryConditionDescriptor.hpp>
 #include <scheme/DirichletBoundaryConditionDescriptor.hpp>
+#include <scheme/DirichletVectorBoundaryConditionDescriptor.hpp>
 #include <scheme/ExternalBoundaryConditionDescriptor.hpp>
 #include <scheme/FixedBoundaryConditionDescriptor.hpp>
 #include <scheme/FourierBoundaryConditionDescriptor.hpp>
@@ -179,6 +180,17 @@ let i: R -> R, x -> x+3;
                                                          p_dirichlet_bc_descriptor)},
                                                        file, useless_group, symbol_table_group);
 
+      const std::vector<FunctionSymbolId> dirichlet_vector_function_id_list{FunctionSymbolId{2, symbol_table},
+                                                                            FunctionSymbolId{3, symbol_table}};
+      auto p_dirichlet_vector_bc_descriptor =
+        std::make_shared<const DirichletVectorBoundaryConditionDescriptor>("dirichlet_vector_name", p_boundary_1,
+                                                                           dirichlet_vector_function_id_list);
+      checkpointing::writeIBoundaryConditionDescriptor("dirichlet_vector_bc_descriptor",
+                                                       EmbeddedData{std::make_shared<
+                                                         DataHandler<const IBoundaryConditionDescriptor>>(
+                                                         p_dirichlet_vector_bc_descriptor)},
+                                                       file, useless_group, symbol_table_group);
+
       const FunctionSymbolId neumann_function_id{1, symbol_table};
       auto p_neumann_bc_descriptor =
         std::make_shared<const NeumannBoundaryConditionDescriptor>("neumann_name", p_boundary_1, neumann_function_id);
@@ -245,6 +257,9 @@ let i: R -> R, x -> x+3;
       EmbeddedData read_dirichlet_bc_descriptor =
         checkpointing::readIBoundaryConditionDescriptor("dirichlet_bc_descriptor", symbol_table_group);
 
+      EmbeddedData read_dirichlet_vector_bc_descriptor =
+        checkpointing::readIBoundaryConditionDescriptor("dirichlet_vector_bc_descriptor", symbol_table_group);
+
       EmbeddedData read_neumann_bc_descriptor =
         checkpointing::readIBoundaryConditionDescriptor("neumann_bc_descriptor", symbol_table_group);
 
@@ -268,6 +283,7 @@ let i: R -> R, x -> x+3;
       REQUIRE_NOTHROW(get_value(read_free_bc_descriptor));
       REQUIRE_NOTHROW(get_value(read_fixed_bc_descriptor));
       REQUIRE_NOTHROW(get_value(read_dirichlet_bc_descriptor));
+      REQUIRE_NOTHROW(get_value(read_dirichlet_vector_bc_descriptor));
       REQUIRE_NOTHROW(get_value(read_neumann_bc_descriptor));
       REQUIRE_NOTHROW(get_value(read_fourier_bc_descriptor));
       REQUIRE_NOTHROW(get_value(read_inflow_bc_descriptor));
@@ -280,6 +296,8 @@ let i: R -> R, x -> x+3;
       REQUIRE(get_value(read_free_bc_descriptor).type() == IBoundaryConditionDescriptor::Type::free);
       REQUIRE(get_value(read_fixed_bc_descriptor).type() == IBoundaryConditionDescriptor::Type::fixed);
       REQUIRE(get_value(read_dirichlet_bc_descriptor).type() == IBoundaryConditionDescriptor::Type::dirichlet);
+      REQUIRE(get_value(read_dirichlet_vector_bc_descriptor).type() ==
+              IBoundaryConditionDescriptor::Type::dirichlet_vector);
       REQUIRE(get_value(read_neumann_bc_descriptor).type() == IBoundaryConditionDescriptor::Type::neumann);
       REQUIRE(get_value(read_fourier_bc_descriptor).type() == IBoundaryConditionDescriptor::Type::fourier);
       REQUIRE(get_value(read_inflow_bc_descriptor).type() == IBoundaryConditionDescriptor::Type::inflow);
@@ -293,6 +311,8 @@ let i: R -> R, x -> x+3;
       REQUIRE_NOTHROW(dynamic_cast<const FixedBoundaryConditionDescriptor&>(get_value(read_fixed_bc_descriptor)));
       REQUIRE_NOTHROW(
         dynamic_cast<const DirichletBoundaryConditionDescriptor&>(get_value(read_dirichlet_bc_descriptor)));
+      REQUIRE_NOTHROW(dynamic_cast<const DirichletVectorBoundaryConditionDescriptor&>(
+        get_value(read_dirichlet_vector_bc_descriptor)));
       REQUIRE_NOTHROW(dynamic_cast<const NeumannBoundaryConditionDescriptor&>(get_value(read_neumann_bc_descriptor)));
       REQUIRE_NOTHROW(dynamic_cast<const FourierBoundaryConditionDescriptor&>(get_value(read_fourier_bc_descriptor)));
       REQUIRE_NOTHROW(dynamic_cast<const InflowBoundaryConditionDescriptor&>(get_value(read_inflow_bc_descriptor)));
@@ -308,6 +328,8 @@ let i: R -> R, x -> x+3;
       auto& read_fixed_bc = dynamic_cast<const FixedBoundaryConditionDescriptor&>(get_value(read_fixed_bc_descriptor));
       auto& read_dirichlet_bc =
         dynamic_cast<const DirichletBoundaryConditionDescriptor&>(get_value(read_dirichlet_bc_descriptor));
+      auto& read_dirichlet_vector_bc =
+        dynamic_cast<const DirichletVectorBoundaryConditionDescriptor&>(get_value(read_dirichlet_vector_bc_descriptor));
       auto& read_neumann_bc =
         dynamic_cast<const NeumannBoundaryConditionDescriptor&>(get_value(read_neumann_bc_descriptor));
       auto& read_fourier_bc =
@@ -326,6 +348,15 @@ let i: R -> R, x -> x+3;
       REQUIRE(read_dirichlet_bc.boundaryDescriptor().type() == p_dirichlet_bc_descriptor->boundaryDescriptor().type());
       REQUIRE(read_dirichlet_bc.name() == p_dirichlet_bc_descriptor->name());
       REQUIRE(read_dirichlet_bc.rhsSymbolId().id() == p_dirichlet_bc_descriptor->rhsSymbolId().id());
+      REQUIRE(read_dirichlet_vector_bc.boundaryDescriptor().type() ==
+              p_dirichlet_vector_bc_descriptor->boundaryDescriptor().type());
+      REQUIRE(read_dirichlet_vector_bc.name() == p_dirichlet_vector_bc_descriptor->name());
+      REQUIRE(read_dirichlet_vector_bc.rhsSymbolIdList().size() ==
+              p_dirichlet_vector_bc_descriptor->rhsSymbolIdList().size());
+      for (size_t i = 0; i < read_dirichlet_vector_bc.rhsSymbolIdList().size(); ++i) {
+        REQUIRE(read_dirichlet_vector_bc.rhsSymbolIdList()[i].id() ==
+                p_dirichlet_vector_bc_descriptor->rhsSymbolIdList()[i].id());
+      }
       REQUIRE(read_neumann_bc.boundaryDescriptor().type() == p_neumann_bc_descriptor->boundaryDescriptor().type());
       REQUIRE(read_neumann_bc.name() == p_neumann_bc_descriptor->name());
       REQUIRE(read_neumann_bc.rhsSymbolId().id() == p_neumann_bc_descriptor->rhsSymbolId().id());
@@ -337,6 +368,8 @@ let i: R -> R, x -> x+3;
       REQUIRE(read_inflow_bc.functionSymbolId().id() == p_inflow_bc_descriptor->functionSymbolId().id());
       REQUIRE(read_inflow_list_bc.boundaryDescriptor().type() ==
               p_inflow_list_bc_descriptor->boundaryDescriptor().type());
+      REQUIRE(read_inflow_list_bc.functionSymbolIdList().size() ==
+              p_inflow_list_bc_descriptor->functionSymbolIdList().size());
       for (size_t i = 0; i < read_inflow_list_bc.functionSymbolIdList().size(); ++i) {
         REQUIRE(read_inflow_list_bc.functionSymbolIdList()[i].id() ==
                 p_inflow_list_bc_descriptor->functionSymbolIdList()[i].id());
-- 
GitLab