From 79beaa051e8b00d39aacec7fca030fe0435bf738 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Tue, 19 Nov 2024 15:13:00 +0100
Subject: [PATCH] Allow TinyVector and TinyMatrix as data type for
 DiscreteFunctionP0Vector

---
 .../EmbeddedDiscreteFunctionMathFunctions.cpp | 12 ++--
 src/output/OutputNamedItemValueSet.hpp        | 28 +++++++++-
 src/output/VTKWriter.cpp                      | 56 +++++++++++++++----
 src/scheme/DiscreteFunctionP0Vector.hpp       | 16 ++++--
 src/scheme/DiscreteFunctionVariant.hpp        | 16 +++++-
 src/scheme/FluxingAdvectionSolver.cpp         | 17 ++++--
 src/scheme/PolynomialReconstruction.cpp       | 51 ++++++++++-------
 7 files changed, 145 insertions(+), 51 deletions(-)

diff --git a/src/language/utils/EmbeddedDiscreteFunctionMathFunctions.cpp b/src/language/utils/EmbeddedDiscreteFunctionMathFunctions.cpp
index 5d000813a..892037219 100644
--- a/src/language/utils/EmbeddedDiscreteFunctionMathFunctions.cpp
+++ b/src/language/utils/EmbeddedDiscreteFunctionMathFunctions.cpp
@@ -234,10 +234,14 @@ dot(const std::shared_ptr<const DiscreteFunctionVariant>& f_v,
             throw NormalError(EmbeddedDiscreteFunctionUtils::invalidOperandType(f));
           }
         } else if constexpr (is_discrete_function_P0_vector_v<TypeOfF>) {
-          if (f.size() == g.size()) {
-            return std::make_shared<DiscreteFunctionVariant>(dot(f, g));
+          if constexpr (std::is_arithmetic_v<DataType>) {
+            if (f.size() == g.size()) {
+              return std::make_shared<DiscreteFunctionVariant>(dot(f, g));
+            } else {
+              throw NormalError("operands have different dimension");
+            }
           } else {
-            throw NormalError("operands have different dimension");
+            throw NormalError(EmbeddedDiscreteFunctionUtils::invalidOperandType(f));
           }
         } else {
           throw NormalError(EmbeddedDiscreteFunctionUtils::invalidOperandType(f));
@@ -691,8 +695,6 @@ sum_of_Vh_components(const std::shared_ptr<const DiscreteFunctionVariant>& f)
     [&](auto&& discrete_function) -> std::shared_ptr<const DiscreteFunctionVariant> {
       using DiscreteFunctionT = std::decay_t<decltype(discrete_function)>;
       if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
-        using DataType = std::decay_t<typename DiscreteFunctionT::data_type>;
-        static_assert(std::is_same_v<DataType, double>);
         return std::make_shared<DiscreteFunctionVariant>(sumOfComponents(discrete_function));
       } else {
         throw NormalError(EmbeddedDiscreteFunctionUtils::invalidOperandType(f));
diff --git a/src/output/OutputNamedItemValueSet.hpp b/src/output/OutputNamedItemValueSet.hpp
index aa2c5e2ed..8c4333042 100644
--- a/src/output/OutputNamedItemValueSet.hpp
+++ b/src/output/OutputNamedItemValueSet.hpp
@@ -38,7 +38,7 @@ class NamedItemData
   }
 
   NamedItemData& operator=(const NamedItemData&) = default;
-  NamedItemData& operator=(NamedItemData&&) = default;
+  NamedItemData& operator=(NamedItemData&&)      = default;
 
   NamedItemData(const std::string& name, const ItemDataT<DataType, item_type, ConnectivityPtr>& item_data)
     : m_name(name), m_item_data(item_data)
@@ -113,24 +113,48 @@ class OutputNamedItemDataSet
                                        NodeArray<const long int>,
                                        NodeArray<const unsigned long int>,
                                        NodeArray<const double>,
+                                       NodeArray<const TinyVector<1, double>>,
+                                       NodeArray<const TinyVector<2, double>>,
+                                       NodeArray<const TinyVector<3, double>>,
+                                       NodeArray<const TinyMatrix<1, 1, double>>,
+                                       NodeArray<const TinyMatrix<2, 2, double>>,
+                                       NodeArray<const TinyMatrix<3, 3, double>>,
 
                                        EdgeArray<const bool>,
                                        EdgeArray<const int>,
                                        EdgeArray<const long int>,
                                        EdgeArray<const unsigned long int>,
                                        EdgeArray<const double>,
+                                       EdgeArray<const TinyVector<1, double>>,
+                                       EdgeArray<const TinyVector<2, double>>,
+                                       EdgeArray<const TinyVector<3, double>>,
+                                       EdgeArray<const TinyMatrix<1, 1, double>>,
+                                       EdgeArray<const TinyMatrix<2, 2, double>>,
+                                       EdgeArray<const TinyMatrix<3, 3, double>>,
 
                                        FaceArray<const bool>,
                                        FaceArray<const int>,
                                        FaceArray<const long int>,
                                        FaceArray<const unsigned long int>,
                                        FaceArray<const double>,
+                                       FaceArray<const TinyVector<1, double>>,
+                                       FaceArray<const TinyVector<2, double>>,
+                                       FaceArray<const TinyVector<3, double>>,
+                                       FaceArray<const TinyMatrix<1, 1, double>>,
+                                       FaceArray<const TinyMatrix<2, 2, double>>,
+                                       FaceArray<const TinyMatrix<3, 3, double>>,
 
                                        CellArray<const bool>,
                                        CellArray<const int>,
                                        CellArray<const long int>,
                                        CellArray<const unsigned long int>,
-                                       CellArray<const double>>;
+                                       CellArray<const double>,
+                                       CellArray<const TinyVector<1, double>>,
+                                       CellArray<const TinyVector<2, double>>,
+                                       CellArray<const TinyVector<3, double>>,
+                                       CellArray<const TinyMatrix<1, 1, double>>,
+                                       CellArray<const TinyMatrix<2, 2, double>>,
+                                       CellArray<const TinyMatrix<3, 3, double>>>;
 
  private:
   // We do not use a map, we want variables to be written in the
diff --git a/src/output/VTKWriter.cpp b/src/output/VTKWriter.cpp
index 64d354344..2aedce3b1 100644
--- a/src/output/VTKWriter.cpp
+++ b/src/output/VTKWriter.cpp
@@ -396,10 +396,17 @@ VTKWriter::_write(const MeshType& mesh,
          << "\">\n";
     fout << "<CellData>\n";
     for (const auto& [name, item_value_variant] : output_named_item_data_set) {
-      std::visit([&, var_name = name](
-                   auto&&
-                     item_value) { return this->_write_cell_data(fout, var_name, item_value, serialize_data_list); },
-                 item_value_variant);
+      std::visit(
+        [&, var_name = name](auto&& item_value) {
+          using IVType   = std::decay_t<decltype(item_value)>;
+          using DataType = typename IVType::data_type;
+          if constexpr (is_item_array_v<IVType> and not std::is_arithmetic_v<DataType>) {
+            throw NotImplementedError("DiscreteFunctionP0Vector of non arithmetic type");
+          } else {
+            return this->_write_cell_data(fout, var_name, item_value, serialize_data_list);
+          }
+        },
+        item_value_variant);
     }
     if (parallel::size() > 1) {
       CellValue<uint8_t> vtk_ghost_type{mesh.connectivity()};
@@ -413,10 +420,17 @@ VTKWriter::_write(const MeshType& mesh,
     fout << "</CellData>\n";
     fout << "<PointData>\n";
     for (const auto& [name, item_value_variant] : output_named_item_data_set) {
-      std::visit([&, var_name = name](
-                   auto&&
-                     item_value) { return this->_write_node_data(fout, var_name, item_value, serialize_data_list); },
-                 item_value_variant);
+      std::visit(
+        [&, var_name = name](auto&& item_value) {
+          using IVType   = std::decay_t<decltype(item_value)>;
+          using DataType = typename IVType::data_type;
+          if constexpr (is_item_array_v<IVType> and not std::is_arithmetic_v<DataType>) {
+            throw NotImplementedError("DiscreteFunctionP0Vector of non arithmetic type");
+          } else {
+            return this->_write_node_data(fout, var_name, item_value, serialize_data_list);
+          }
+        },
+        item_value_variant);
     }
     fout << "</PointData>\n";
     fout << "<Points>\n";
@@ -644,15 +658,33 @@ VTKWriter::_write(const MeshType& mesh,
 
     fout << "<PPointData>\n";
     for (const auto& [name, item_value_variant] : output_named_item_data_set) {
-      std::visit([&, var_name = name](auto&& item_value) { return this->_write_node_pvtu(fout, var_name, item_value); },
-                 item_value_variant);
+      std::visit(
+        [&, var_name = name](auto&& item_value) {
+          using IVType   = std::decay_t<decltype(item_value)>;
+          using DataType = typename IVType::data_type;
+          if constexpr (is_item_array_v<IVType> and not std::is_arithmetic_v<DataType>) {
+            throw NotImplementedError("DiscreteFunctionP0Vector of non arithmetic type");
+          } else {
+            return this->_write_node_pvtu(fout, var_name, item_value);
+          }
+        },
+        item_value_variant);
     }
     fout << "</PPointData>\n";
 
     fout << "<PCellData>\n";
     for (const auto& [name, item_value_variant] : output_named_item_data_set) {
-      std::visit([&, var_name = name](auto&& item_value) { return this->_write_cell_pvtu(fout, var_name, item_value); },
-                 item_value_variant);
+      std::visit(
+        [&, var_name = name](auto&& item_value) {
+          using IVType   = std::decay_t<decltype(item_value)>;
+          using DataType = typename IVType::data_type;
+          if constexpr (is_item_array_v<IVType> and not std::is_arithmetic_v<DataType>) {
+            throw NotImplementedError("DiscreteFunctionP0Vector of non arithmetic type");
+          } else {
+            return this->_write_cell_pvtu(fout, var_name, item_value);
+          }
+        },
+        item_value_variant);
     }
     if (parallel::size() > 1) {
       fout << "<PDataArray type=\"UInt8\" Name=\"vtkGhostType\" NumberOfComponents=\"1\"/>\n";
diff --git a/src/scheme/DiscreteFunctionP0Vector.hpp b/src/scheme/DiscreteFunctionP0Vector.hpp
index da3904b2a..bc08cb2a7 100644
--- a/src/scheme/DiscreteFunctionP0Vector.hpp
+++ b/src/scheme/DiscreteFunctionP0Vector.hpp
@@ -19,8 +19,6 @@ class DiscreteFunctionP0Vector
   friend class DiscreteFunctionP0Vector<std::add_const_t<DataType>>;
   friend class DiscreteFunctionP0Vector<std::remove_const_t<DataType>>;
 
-  static_assert(std::is_arithmetic_v<DataType>, "DiscreteFunctionP0Vector are only defined for arithmetic data type");
-
  private:
   std::shared_ptr<const MeshVariant> m_mesh;
   CellArray<DataType> m_cell_arrays;
@@ -188,16 +186,23 @@ class DiscreteFunctionP0Vector
     return product;
   }
 
-  PUGS_INLINE friend DiscreteFunctionP0<double>
+  PUGS_INLINE friend DiscreteFunctionP0<std::remove_const_t<DataType>>
   sumOfComponents(const DiscreteFunctionP0Vector& f)
   {
-    DiscreteFunctionP0<double> result{f.m_mesh};
+    DiscreteFunctionP0<std::remove_const_t<DataType>> result{f.m_mesh};
 
     parallel_for(
       f.m_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
         const auto& f_cell_id = f[cell_id];
 
-        double sum = 0;
+        std::remove_const_t<DataType> sum = [] {
+          if constexpr (std::is_arithmetic_v<DataType>) {
+            return 0;
+          } else {
+            return zero;
+          }
+        }();
+
         for (size_t i = 0; i < f.size(); ++i) {
           sum += f_cell_id[i];
         }
@@ -213,6 +218,7 @@ class DiscreteFunctionP0Vector
   {
     Assert(f.meshVariant()->id() == g.meshVariant()->id(), "functions are nor defined on the same mesh");
     Assert(f.size() == g.size());
+    static_assert(std::is_arithmetic_v<std::decay_t<DataType>>);
     DiscreteFunctionP0<double> result{f.m_mesh};
     parallel_for(
       f.m_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
diff --git a/src/scheme/DiscreteFunctionVariant.hpp b/src/scheme/DiscreteFunctionVariant.hpp
index 901b79890..616184298 100644
--- a/src/scheme/DiscreteFunctionVariant.hpp
+++ b/src/scheme/DiscreteFunctionVariant.hpp
@@ -19,7 +19,13 @@ class DiscreteFunctionVariant
                                DiscreteFunctionP0<const TinyMatrix<2>>,
                                DiscreteFunctionP0<const TinyMatrix<3>>,
 
-                               DiscreteFunctionP0Vector<const double>>;
+                               DiscreteFunctionP0Vector<const double>,
+                               DiscreteFunctionP0Vector<const TinyVector<1>>,
+                               DiscreteFunctionP0Vector<const TinyVector<2>>,
+                               DiscreteFunctionP0Vector<const TinyVector<3>>,
+                               DiscreteFunctionP0Vector<const TinyMatrix<1>>,
+                               DiscreteFunctionP0Vector<const TinyMatrix<2>>,
+                               DiscreteFunctionP0Vector<const TinyMatrix<3>>>;
 
   Variant m_discrete_function;
 
@@ -70,7 +76,13 @@ class DiscreteFunctionVariant
   DiscreteFunctionVariant(const DiscreteFunctionP0Vector<DataType>& discrete_function)
     : m_discrete_function{DiscreteFunctionP0Vector<const DataType>{discrete_function}}
   {
-    static_assert(std::is_same_v<std::remove_const_t<DataType>, double>,
+    static_assert(std::is_same_v<std::remove_const_t<DataType>, double> or                       //
+                    std::is_same_v<std::remove_const_t<DataType>, TinyVector<1, double>> or      //
+                    std::is_same_v<std::remove_const_t<DataType>, TinyVector<2, double>> or      //
+                    std::is_same_v<std::remove_const_t<DataType>, TinyVector<3, double>> or      //
+                    std::is_same_v<std::remove_const_t<DataType>, TinyMatrix<1, 1, double>> or   //
+                    std::is_same_v<std::remove_const_t<DataType>, TinyMatrix<2, 2, double>> or   //
+                    std::is_same_v<std::remove_const_t<DataType>, TinyMatrix<3, 3, double>>,
                   "DiscreteFunctionP0Vector with this DataType is not allowed in variant");
   }
 
diff --git a/src/scheme/FluxingAdvectionSolver.cpp b/src/scheme/FluxingAdvectionSolver.cpp
index d57eae2c4..6f85610e0 100644
--- a/src/scheme/FluxingAdvectionSolver.cpp
+++ b/src/scheme/FluxingAdvectionSolver.cpp
@@ -86,10 +86,15 @@ class FluxingAdvectionSolver
     m_remapped_list.emplace_back(copy(old_q.cellValues()));
   }
 
+  template <typename DataType>
   void
-  _storeValues(const DiscreteFunctionP0Vector<const double>& old_q)
+  _storeValues(const DiscreteFunctionP0Vector<const DataType>& old_q)
   {
-    m_remapped_list.emplace_back(copy(old_q.cellArrays()));
+    if constexpr (std::is_arithmetic_v<DataType>) {
+      m_remapped_list.emplace_back(copy(old_q.cellArrays()));
+    } else {
+      throw NormalError("remapping DiscreteFunctionP0Vector of non arithmetic data type is not supported");
+    }
   }
 
   template <typename DataType>
@@ -741,8 +746,12 @@ FluxingAdvectionSolver<MeshType>::remap(
           new_variables.push_back(std::make_shared<DiscreteFunctionVariant>(
             DiscreteFunctionT(m_new_mesh, std::get<CellValue<DataType>>(m_remapped_list[i]))));
         } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
-          new_variables.push_back(std::make_shared<DiscreteFunctionVariant>(
-            DiscreteFunctionT(m_new_mesh, std::get<CellArray<DataType>>(m_remapped_list[i]))));
+          if constexpr (std::is_arithmetic_v<DataType>) {
+            new_variables.push_back(std::make_shared<DiscreteFunctionVariant>(
+              DiscreteFunctionT(m_new_mesh, std::get<CellArray<DataType>>(m_remapped_list[i]))));
+          } else {
+            throw NormalError("remapping DiscreteFunctionP0Vector of non arithmetic data type is not supported");
+          }
         } else {
           throw UnexpectedError("invalid discrete function type");
         }
diff --git a/src/scheme/PolynomialReconstruction.cpp b/src/scheme/PolynomialReconstruction.cpp
index 2af32e157..7d2c58ac5 100644
--- a/src/scheme/PolynomialReconstruction.cpp
+++ b/src/scheme/PolynomialReconstruction.cpp
@@ -740,9 +740,13 @@ PolynomialReconstruction::_createMutableDiscreteFunctionDPKVariantList(
             DiscreteFunctionDPk<MeshType::Dimension, DataType>(p_mesh, m_descriptor.degree()));
         } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
           using DataType = std::remove_const_t<std::decay_t<typename DiscreteFunctionT::data_type>>;
-          mutable_discrete_function_dpk_variant_list.push_back(
-            DiscreteFunctionDPkVector<MeshType::Dimension, DataType>(p_mesh, m_descriptor.degree(),
-                                                                     discrete_function.size()));
+          if constexpr (std::is_arithmetic_v<DataType>) {
+            mutable_discrete_function_dpk_variant_list.push_back(
+              DiscreteFunctionDPkVector<MeshType::Dimension, DataType>(p_mesh, m_descriptor.degree(),
+                                                                       discrete_function.size()));
+          } else {
+            throw NotImplementedError("reconstruction of DiscreteFunctionP0Vector of non arithmetic data type");
+          }
         } else {
           // LCOV_EXCL_START
           throw UnexpectedError("unexpected discrete function type");
@@ -945,32 +949,37 @@ PolynomialReconstruction::_build(
                   column_begin += DataType::Dimension;
                 }
               } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
-                using DataType       = std::decay_t<typename DiscreteFunctionT::data_type>;
-                const auto qj_vector = discrete_function[cell_j_id];
-                size_t index         = 0;
-                for (size_t i = 0; i < stencil_cell_list.size(); ++i, ++index) {
-                  const CellId cell_i_id = stencil_cell_list[i];
-                  for (size_t l = 0; l < qj_vector.size(); ++l) {
-                    const DataType& qj         = qj_vector[l];
-                    const DataType& qi_qj      = discrete_function[cell_i_id][l] - qj;
-                    B(index, column_begin + l) = qi_qj;
-                  }
-                }
+                using DataType = std::decay_t<typename DiscreteFunctionT::data_type>;
 
-                for (size_t i_symmetry = 0; i_symmetry < stencil_array.symmetryBoundaryStencilArrayList().size();
-                     ++i_symmetry) {
-                  auto& ghost_stencil  = stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
-                  auto ghost_cell_list = ghost_stencil[cell_j_id];
-                  for (size_t i = 0; i < ghost_cell_list.size(); ++i, ++index) {
-                    const CellId cell_i_id = ghost_cell_list[i];
+                if constexpr (std::is_arithmetic_v<DataType>) {
+                  const auto qj_vector = discrete_function[cell_j_id];
+                  size_t index         = 0;
+                  for (size_t i = 0; i < stencil_cell_list.size(); ++i, ++index) {
+                    const CellId cell_i_id = stencil_cell_list[i];
                     for (size_t l = 0; l < qj_vector.size(); ++l) {
                       const DataType& qj         = qj_vector[l];
                       const DataType& qi_qj      = discrete_function[cell_i_id][l] - qj;
                       B(index, column_begin + l) = qi_qj;
                     }
                   }
+
+                  for (size_t i_symmetry = 0; i_symmetry < stencil_array.symmetryBoundaryStencilArrayList().size();
+                       ++i_symmetry) {
+                    auto& ghost_stencil  = stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
+                    auto ghost_cell_list = ghost_stencil[cell_j_id];
+                    for (size_t i = 0; i < ghost_cell_list.size(); ++i, ++index) {
+                      const CellId cell_i_id = ghost_cell_list[i];
+                      for (size_t l = 0; l < qj_vector.size(); ++l) {
+                        const DataType& qj         = qj_vector[l];
+                        const DataType& qi_qj      = discrete_function[cell_i_id][l] - qj;
+                        B(index, column_begin + l) = qi_qj;
+                      }
+                    }
+                  }
+                  column_begin += qj_vector.size();
+                } else {
+                  throw NotImplementedError("reconstruction of DiscreteFunctionP0Vector of non arithmetic data type");
                 }
-                column_begin += qj_vector.size();
               } else {
                 // LCOV_EXCL_START
                 throw UnexpectedError("invalid discrete function type");
-- 
GitLab