From c400ee55cf22b50fa2cdcd6fe98d6b55f8a31375 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Sat, 18 Feb 2023 19:31:33 +0100
Subject: [PATCH] Use DiscreteFunctionVariant for writers

---
 src/language/modules/SchemeModule.cpp |  76 ++++++------
 src/language/modules/WriterModule.cpp |   6 +-
 src/language/modules/WriterModule.hpp |   1 -
 src/output/NamedDiscreteFunction.hpp  |  15 +--
 src/output/WriterBase.cpp             | 162 +++-----------------------
 src/output/WriterBase.hpp             |   6 -
 6 files changed, 61 insertions(+), 205 deletions(-)

diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp
index cf1212e08..2e1efdbe2 100644
--- a/src/language/modules/SchemeModule.cpp
+++ b/src/language/modules/SchemeModule.cpp
@@ -419,46 +419,46 @@ SchemeModule::SchemeModule()
 
                               ));
 
-  this
-    ->_addBuiltinFunction("cell_volume",
-                          std::function(
-
-                            [](const std::shared_ptr<const IMesh>& i_mesh) -> std::shared_ptr<const IDiscreteFunction> {
-                              switch (i_mesh->dimension()) {
-                              case 1: {
-                                constexpr size_t Dimension = 1;
-                                using MeshType             = Mesh<Connectivity<Dimension>>;
-                                std::shared_ptr<const MeshType> mesh =
-                                  std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);
-
-                                return std::make_shared<const DiscreteFunctionP0<
-                                  Dimension, double>>(mesh, copy(MeshDataManager::instance().getMeshData(*mesh).Vj()));
-                              }
-                              case 2: {
-                                constexpr size_t Dimension = 2;
-                                using MeshType             = Mesh<Connectivity<Dimension>>;
-                                std::shared_ptr<const MeshType> mesh =
-                                  std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);
-
-                                return std::make_shared<const DiscreteFunctionP0<
-                                  Dimension, double>>(mesh, copy(MeshDataManager::instance().getMeshData(*mesh).Vj()));
-                              }
-                              case 3: {
-                                constexpr size_t Dimension = 3;
-                                using MeshType             = Mesh<Connectivity<Dimension>>;
-                                std::shared_ptr<const MeshType> mesh =
-                                  std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);
-
-                                return std::make_shared<const DiscreteFunctionP0<
-                                  Dimension, double>>(mesh, copy(MeshDataManager::instance().getMeshData(*mesh).Vj()));
-                              }
-                              default: {
-                                throw UnexpectedError("invalid mesh dimension");
-                              }
+  this->_addBuiltinFunction("cell_volume",
+                            std::function(
+
+                              [](const std::shared_ptr<const IMesh>& i_mesh)
+                                -> std::shared_ptr<const DiscreteFunctionVariant> {
+                                switch (i_mesh->dimension()) {
+                                case 1: {
+                                  constexpr size_t Dimension = 1;
+                                  using MeshType             = Mesh<Connectivity<Dimension>>;
+                                  std::shared_ptr<const MeshType> mesh =
+                                    std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);
+
+                                  return std::make_shared<DiscreteFunctionVariant>(
+                                    DiscreteFunctionP0(mesh, MeshDataManager::instance().getMeshData(*mesh).Vj()));
+                                }
+                                case 2: {
+                                  constexpr size_t Dimension = 2;
+                                  using MeshType             = Mesh<Connectivity<Dimension>>;
+                                  std::shared_ptr<const MeshType> mesh =
+                                    std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);
+
+                                  return std::make_shared<DiscreteFunctionVariant>(
+                                    DiscreteFunctionP0(mesh, MeshDataManager::instance().getMeshData(*mesh).Vj()));
+                                }
+                                case 3: {
+                                  constexpr size_t Dimension = 3;
+                                  using MeshType             = Mesh<Connectivity<Dimension>>;
+                                  std::shared_ptr<const MeshType> mesh =
+                                    std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);
+
+                                  return std::make_shared<DiscreteFunctionVariant>(
+                                    DiscreteFunctionP0(mesh, MeshDataManager::instance().getMeshData(*mesh).Vj()));
+                                }
+                                default: {
+                                  throw UnexpectedError("invalid mesh dimension");
+                                }
+                                }
                               }
-                            }
 
-                            ));
+                              ));
 
   MathFunctionRegisterForVh{*this};
 }
diff --git a/src/language/modules/WriterModule.cpp b/src/language/modules/WriterModule.cpp
index 6523d56a9..aa479c0eb 100644
--- a/src/language/modules/WriterModule.cpp
+++ b/src/language/modules/WriterModule.cpp
@@ -12,9 +12,7 @@
 #include <output/NamedDiscreteFunction.hpp>
 #include <output/NamedItemValueVariant.hpp>
 #include <output/VTKWriter.hpp>
-#include <scheme/DiscreteFunctionP0.hpp>
-#include <scheme/IDiscreteFunction.hpp>
-#include <scheme/IDiscreteFunctionDescriptor.hpp>
+#include <scheme/DiscreteFunctionVariant.hpp>
 
 WriterModule::WriterModule()
 {
@@ -75,7 +73,7 @@ WriterModule::WriterModule()
 
   this->_addBuiltinFunction("name_output", std::function(
 
-                                             [](std::shared_ptr<const IDiscreteFunction> discrete_function,
+                                             [](std::shared_ptr<const DiscreteFunctionVariant> discrete_function,
                                                 const std::string& name) -> std::shared_ptr<const INamedDiscreteData> {
                                                return std::make_shared<const NamedDiscreteFunction>(discrete_function,
                                                                                                     name);
diff --git a/src/language/modules/WriterModule.hpp b/src/language/modules/WriterModule.hpp
index f61eafb0e..97bef5a78 100644
--- a/src/language/modules/WriterModule.hpp
+++ b/src/language/modules/WriterModule.hpp
@@ -7,7 +7,6 @@
 
 class OutputNamedItemValueSet;
 class INamedDiscreteData;
-class IDiscreteFunction;
 
 #include <string>
 
diff --git a/src/output/NamedDiscreteFunction.hpp b/src/output/NamedDiscreteFunction.hpp
index d56d44369..305db7be6 100644
--- a/src/output/NamedDiscreteFunction.hpp
+++ b/src/output/NamedDiscreteFunction.hpp
@@ -6,12 +6,12 @@
 #include <memory>
 #include <string>
 
-class IDiscreteFunction;
+class DiscreteFunctionVariant;
 
 class NamedDiscreteFunction final : public INamedDiscreteData
 {
  private:
-  std::shared_ptr<const IDiscreteFunction> m_discrete_function;
+  std::shared_ptr<const DiscreteFunctionVariant> m_discrete_function_variant;
   std::string m_name;
 
  public:
@@ -27,14 +27,15 @@ class NamedDiscreteFunction final : public INamedDiscreteData
     return m_name;
   }
 
-  const std::shared_ptr<const IDiscreteFunction>
-  discreteFunction() const
+  const std::shared_ptr<const DiscreteFunctionVariant>
+  discreteFunctionVariant() const
   {
-    return m_discrete_function;
+    return m_discrete_function_variant;
   }
 
-  NamedDiscreteFunction(const std::shared_ptr<const IDiscreteFunction>& discrete_function, const std::string& name)
-    : m_discrete_function{discrete_function}, m_name{name}
+  NamedDiscreteFunction(const std::shared_ptr<const DiscreteFunctionVariant>& discrete_function,
+                        const std::string& name)
+    : m_discrete_function_variant{discrete_function}, m_name{name}
   {}
 
   NamedDiscreteFunction(const NamedDiscreteFunction&) = default;
diff --git a/src/output/WriterBase.cpp b/src/output/WriterBase.cpp
index fcb2cf78f..e49e69025 100644
--- a/src/output/WriterBase.cpp
+++ b/src/output/WriterBase.cpp
@@ -7,7 +7,7 @@
 #include <output/OutputNamedItemValueSet.hpp>
 #include <scheme/DiscreteFunctionP0.hpp>
 #include <scheme/DiscreteFunctionP0Vector.hpp>
-#include <scheme/IDiscreteFunction.hpp>
+#include <scheme/DiscreteFunctionVariant.hpp>
 #include <scheme/IDiscreteFunctionDescriptor.hpp>
 #include <utils/Exceptions.hpp>
 
@@ -24,135 +24,6 @@ WriterBase::_registerDiscreteFunction(const std::string& name,
   }
 }
 
-template <size_t Dimension, template <size_t DimensionT, typename DataTypeT> typename DiscreteFunctionType>
-void
-WriterBase::_registerDiscreteFunction(const std::string& name,
-                                      const IDiscreteFunction& i_discrete_function,
-                                      OutputNamedItemDataSet& named_item_data_set)
-{
-  const ASTNodeDataType& data_type = i_discrete_function.dataType();
-  switch (data_type) {
-  case ASTNodeDataType::bool_t: {
-    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, bool>&>(i_discrete_function),
-                              named_item_data_set);
-    break;
-  }
-  case ASTNodeDataType::unsigned_int_t: {
-    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, uint64_t>&>(i_discrete_function),
-                              named_item_data_set);
-    break;
-  }
-  case ASTNodeDataType::int_t: {
-    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, int64_t>&>(i_discrete_function),
-                              named_item_data_set);
-    break;
-  }
-  case ASTNodeDataType::double_t: {
-    _registerDiscreteFunction(name, dynamic_cast<const DiscreteFunctionType<Dimension, double>&>(i_discrete_function),
-                              named_item_data_set);
-    break;
-  }
-  case ASTNodeDataType::vector_t: {
-    if constexpr (DiscreteFunctionType<Dimension, double>::handled_data_type ==
-                  IDiscreteFunction::HandledItemDataType::vector) {
-      throw UnexpectedError("invalid data type for vector data");
-    } else {
-      switch (data_type.dimension()) {
-      case 1: {
-        _registerDiscreteFunction(name,
-                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyVector<1, double>>&>(
-                                    i_discrete_function),
-                                  named_item_data_set);
-        break;
-      }
-      case 2: {
-        _registerDiscreteFunction(name,
-                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyVector<2, double>>&>(
-                                    i_discrete_function),
-                                  named_item_data_set);
-        break;
-      }
-      case 3: {
-        _registerDiscreteFunction(name,
-                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyVector<3, double>>&>(
-                                    i_discrete_function),
-                                  named_item_data_set);
-        break;
-      }
-      default: {
-        throw UnexpectedError("invalid vector dimension");
-      }
-      }
-    }
-    break;
-  }
-  case ASTNodeDataType::matrix_t: {
-    if constexpr (DiscreteFunctionType<Dimension, double>::handled_data_type ==
-                  IDiscreteFunction::HandledItemDataType::vector) {
-      throw UnexpectedError("invalid data type for vector data");
-    } else {
-      Assert(data_type.numberOfRows() == data_type.numberOfColumns(), "invalid matrix dimensions");
-      switch (data_type.numberOfRows()) {
-      case 1: {
-        _registerDiscreteFunction(name,
-                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyMatrix<1, 1, double>>&>(
-                                    i_discrete_function),
-                                  named_item_data_set);
-        break;
-      }
-      case 2: {
-        _registerDiscreteFunction(name,
-                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyMatrix<2, 2, double>>&>(
-                                    i_discrete_function),
-                                  named_item_data_set);
-        break;
-      }
-      case 3: {
-        _registerDiscreteFunction(name,
-                                  dynamic_cast<const DiscreteFunctionType<Dimension, TinyMatrix<3, 3, double>>&>(
-                                    i_discrete_function),
-                                  named_item_data_set);
-        break;
-      }
-      default: {
-        throw UnexpectedError("invalid matrix dimension");
-      }
-      }
-    }
-    break;
-  }
-  default: {
-    throw UnexpectedError("invalid data type " + dataTypeName(data_type));
-  }
-  }
-}
-
-template <template <size_t Dimension, typename DataType> typename DiscreteFunctionType>
-void
-WriterBase::_registerDiscreteFunction(const NamedDiscreteFunction& named_discrete_function,
-                                      OutputNamedItemDataSet& named_item_data_set)
-{
-  const IDiscreteFunction& i_discrete_function = *named_discrete_function.discreteFunction();
-  const std::string& name                      = named_discrete_function.name();
-  switch (i_discrete_function.mesh()->dimension()) {
-  case 1: {
-    _registerDiscreteFunction<1, DiscreteFunctionType>(name, i_discrete_function, named_item_data_set);
-    break;
-  }
-  case 2: {
-    _registerDiscreteFunction<2, DiscreteFunctionType>(name, i_discrete_function, named_item_data_set);
-    break;
-  }
-  case 3: {
-    _registerDiscreteFunction<3, DiscreteFunctionType>(name, i_discrete_function, named_item_data_set);
-    break;
-  }
-  default: {
-    throw UnexpectedError("invalid mesh dimension");
-  }
-  }
-}
-
 void
 WriterBase::_checkConnectivity(
   const std::shared_ptr<const IMesh>& mesh,
@@ -211,7 +82,11 @@ WriterBase::_checkMesh(const std::shared_ptr<const IMesh>& mesh,
       const NamedDiscreteFunction& named_discrete_function =
         dynamic_cast<const NamedDiscreteFunction&>(*named_discrete_data);
 
-      if (mesh != named_discrete_function.discreteFunction()->mesh()) {
+      std::shared_ptr<const IMesh> discrete_function_mesh =
+        std::visit([](auto&& f) { return f.mesh(); },
+                   named_discrete_function.discreteFunctionVariant()->discreteFunction());
+
+      if (mesh != discrete_function_mesh) {
         std::ostringstream error_msg;
         error_msg << "The variable " << rang::fgB::yellow << named_discrete_function.name() << rang::fg::reset
                   << " is not defined on the provided mesh\n";
@@ -237,7 +112,8 @@ WriterBase::_getMesh(const std::vector<std::shared_ptr<const INamedDiscreteData>
       const NamedDiscreteFunction& named_discrete_function =
         dynamic_cast<const NamedDiscreteFunction&>(*named_discrete_data);
 
-      std::shared_ptr mesh = named_discrete_function.discreteFunction()->mesh();
+      std::shared_ptr mesh = std::visit([&](auto&& f) { return f.mesh(); },
+                                        named_discrete_function.discreteFunctionVariant()->discreteFunction());
       mesh_set[mesh]       = named_discrete_function.name();
 
       switch (mesh->dimension()) {
@@ -317,24 +193,12 @@ WriterBase::_getOutputNamedItemDataSet(
       const NamedDiscreteFunction& named_discrete_function =
         dynamic_cast<const NamedDiscreteFunction&>(*named_discrete_data);
 
-      const IDiscreteFunction& i_discrete_function = *named_discrete_function.discreteFunction();
+      const std::string& name = named_discrete_function.name();
 
-      switch (i_discrete_function.descriptor().type()) {
-      case DiscreteFunctionType::P0: {
-        WriterBase::_registerDiscreteFunction<DiscreteFunctionP0>(named_discrete_function, named_item_data_set);
-        break;
-      }
-      case DiscreteFunctionType::P0Vector: {
-        WriterBase::_registerDiscreteFunction<DiscreteFunctionP0Vector>(named_discrete_function, named_item_data_set);
-        break;
-      }
-      default: {
-        std::ostringstream error_msg;
-        error_msg << "the type of discrete function of " << rang::fgB::blue << named_discrete_data->name()
-                  << rang::style::reset << " is not supported";
-        throw NormalError(error_msg.str());
-      }
-      }
+      const DiscreteFunctionVariant& discrete_function_variant = *named_discrete_function.discreteFunctionVariant();
+
+      std::visit([&](auto&& f) { WriterBase::_registerDiscreteFunction(name, f, named_item_data_set); },
+                 discrete_function_variant.discreteFunction());
       break;
     }
     case INamedDiscreteData::Type::item_value: {
diff --git a/src/output/WriterBase.hpp b/src/output/WriterBase.hpp
index 339cd7645..ae57e9af2 100644
--- a/src/output/WriterBase.hpp
+++ b/src/output/WriterBase.hpp
@@ -90,12 +90,6 @@ class WriterBase : public IWriter
   template <typename DiscreteFunctionType>
   static void _registerDiscreteFunction(const std::string& name, const DiscreteFunctionType&, OutputNamedItemDataSet&);
 
-  template <size_t Dimension, template <size_t DimensionT, typename DataTypeT> typename DiscreteFunctionType>
-  static void _registerDiscreteFunction(const std::string& name, const IDiscreteFunction&, OutputNamedItemDataSet&);
-
-  template <template <size_t DimensionT, typename DataTypeT> typename DiscreteFunctionType>
-  static void _registerDiscreteFunction(const NamedDiscreteFunction&, OutputNamedItemDataSet&);
-
  protected:
   void _checkConnectivity(const std::shared_ptr<const IMesh>& mesh,
                           const std::vector<std::shared_ptr<const INamedDiscreteData>>& named_discrete_data_list) const;
-- 
GitLab