From a5303c097e8efb6b87f26bcf1a811813acd21201 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Mon, 19 Apr 2021 22:37:22 +0200
Subject: [PATCH] Improve binary operators handling for DiscreteFunctionP0

- code is more generic
- handles properly mixture of const/non-const data types
---
 src/scheme/DiscreteFunctionP0.hpp | 193 +++++++++++++++++-------------
 1 file changed, 110 insertions(+), 83 deletions(-)

diff --git a/src/scheme/DiscreteFunctionP0.hpp b/src/scheme/DiscreteFunctionP0.hpp
index a6ffc4cdd..e9c6559e5 100644
--- a/src/scheme/DiscreteFunctionP0.hpp
+++ b/src/scheme/DiscreteFunctionP0.hpp
@@ -50,6 +50,11 @@ class DiscreteFunctionP0 : public IDiscreteFunction
     return m_discrete_function_descriptor;
   }
 
+  operator DiscreteFunctionP0<Dimension, const DataType>() const
+  {
+    return DiscreteFunctionP0<Dimension, const DataType>(m_mesh, m_cell_values);
+  }
+
   PUGS_INLINE
   void
   fill(const DataType& data) const noexcept
@@ -65,72 +70,154 @@ class DiscreteFunctionP0 : public IDiscreteFunction
     return m_cell_values[cell_id];
   }
 
-  friend DiscreteFunctionP0
-  operator+(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
+  template <typename DataType2T>
+  PUGS_INLINE DiscreteFunctionP0<Dimension, decltype(DataType{} + DataType2T{})>
+  operator+(const DiscreteFunctionP0<Dimension, DataType2T>& g) const
   {
+    const DiscreteFunctionP0& f = *this;
     Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
     std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-    DiscreteFunctionP0 sum(mesh);
+    DiscreteFunctionP0<Dimension, decltype(DataType{} + DataType2T{})> sum(mesh);
     parallel_for(
       mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { sum[cell_id] = f[cell_id] + g[cell_id]; });
     return sum;
   }
 
-  friend DiscreteFunctionP0
-  operator-(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
+  template <typename LHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(LHSDataType{} + DataType{})>
+  operator+(const LHSDataType& a, const DiscreteFunctionP0& g)
+  {
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(g.mesh());
+    using ProductDataType = decltype(LHSDataType{} + DataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> sum(mesh);
+    parallel_for(
+      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { sum[cell_id] = a + g[cell_id]; });
+    return sum;
+  }
+
+  template <typename RHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(DataType{} + RHSDataType{})>
+  operator+(const DiscreteFunctionP0& f, const RHSDataType& b)
+  {
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(f.mesh());
+    using ProductDataType = decltype(DataType{} + RHSDataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> sum(mesh);
+    parallel_for(
+      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { sum[cell_id] = f[cell_id] + b; });
+    return sum;
+  }
+
+  template <typename DataType2T>
+  PUGS_INLINE DiscreteFunctionP0<Dimension, decltype(DataType{} - DataType2T{})>
+  operator-(const DiscreteFunctionP0<Dimension, DataType2T>& g) const
   {
+    const DiscreteFunctionP0& f = *this;
     Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
     std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-    DiscreteFunctionP0 difference(mesh);
+    DiscreteFunctionP0<Dimension, decltype(DataType{} - DataType2T{})> difference(mesh);
     parallel_for(
       mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { difference[cell_id] = f[cell_id] - g[cell_id]; });
     return difference;
   }
 
-  friend DiscreteFunctionP0
-  operator*(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
+  template <typename LHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(LHSDataType{} - DataType{})>
+  operator-(const LHSDataType& a, const DiscreteFunctionP0& g)
   {
-    Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
-    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-    DiscreteFunctionP0 product(mesh);
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(g.mesh());
+    using ProductDataType = decltype(LHSDataType{} - DataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> difference(mesh);
     parallel_for(
-      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; });
-    return product;
+      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { difference[cell_id] = a - g[cell_id]; });
+    return difference;
+  }
+
+  template <typename RHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(DataType{} - RHSDataType{})>
+  operator-(const DiscreteFunctionP0& f, const RHSDataType& b)
+  {
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(f.mesh());
+    using ProductDataType = decltype(DataType{} - RHSDataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> difference(mesh);
+    parallel_for(
+      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { difference[cell_id] = f[cell_id] - b; });
+    return difference;
   }
 
   template <typename DataType2T>
-  friend DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})>
-  operator*(const DiscreteFunctionP0<Dimension, DataType2T>& f, const DiscreteFunctionP0& g)
+  PUGS_INLINE DiscreteFunctionP0<Dimension, decltype(DataType{} * DataType2T{})>
+  operator*(const DiscreteFunctionP0<Dimension, DataType2T>& g) const
   {
+    const DiscreteFunctionP0& f = *this;
     Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
     std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-    DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> product(mesh);
+    DiscreteFunctionP0<Dimension, decltype(DataType{} * DataType2T{})> product(mesh);
     parallel_for(
       mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; });
     return product;
   }
 
-  friend DiscreteFunctionP0
-  operator*(const double& a, const DiscreteFunctionP0& f)
+  template <typename LHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(LHSDataType{} * DataType{})>
+  operator*(const LHSDataType& a, const DiscreteFunctionP0& f)
   {
-    std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-    DiscreteFunctionP0 product(mesh);
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(f.mesh());
+    using ProductDataType = decltype(LHSDataType{} * DataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> product(mesh);
     parallel_for(
       mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = a * f[cell_id]; });
     return product;
   }
 
-  friend DiscreteFunctionP0
-  operator/(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g)
+  template <typename RHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(DataType{} * RHSDataType{})>
+  operator*(const DiscreteFunctionP0& f, const RHSDataType& b)
+  {
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(f.mesh());
+    using ProductDataType = decltype(DataType{} * RHSDataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> product(mesh);
+    parallel_for(
+      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * b; });
+    return product;
+  }
+
+  template <typename DataType2T>
+  PUGS_INLINE DiscreteFunctionP0<Dimension, decltype(DataType{} / DataType2T{})>
+  operator/(const DiscreteFunctionP0<Dimension, DataType2T>& g) const
   {
+    const DiscreteFunctionP0& f = *this;
     Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh");
     std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-    DiscreteFunctionP0 ratio(mesh);
+    DiscreteFunctionP0<Dimension, decltype(DataType{} / DataType2T{})> ratio(mesh);
     parallel_for(
       mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { ratio[cell_id] = f[cell_id] / g[cell_id]; });
     return ratio;
   }
 
+  template <typename LHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(LHSDataType{} / DataType{})>
+  operator/(const LHSDataType& a, const DiscreteFunctionP0& f)
+  {
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(f.mesh());
+    using ProductDataType = decltype(LHSDataType{} / DataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> ratio(mesh);
+    parallel_for(
+      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { ratio[cell_id] = a / f[cell_id]; });
+    return ratio;
+  }
+
+  template <typename RHSDataType>
+  PUGS_INLINE friend DiscreteFunctionP0<Dimension, decltype(DataType{} / RHSDataType{})>
+  operator/(const DiscreteFunctionP0& f, const RHSDataType& b)
+  {
+    std::shared_ptr mesh  = std::dynamic_pointer_cast<const MeshType>(f.mesh());
+    using ProductDataType = decltype(DataType{} / RHSDataType{});
+    DiscreteFunctionP0<Dimension, ProductDataType> ratio(mesh);
+    parallel_for(
+      mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { ratio[cell_id] = f[cell_id] / b; });
+    return ratio;
+  }
+
   DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const FunctionSymbolId& function_id) : m_mesh(mesh)
   {
     using MeshDataType      = MeshData<Dimension>;
@@ -155,64 +242,4 @@ class DiscreteFunctionP0 : public IDiscreteFunction
   ~DiscreteFunctionP0() = default;
 };
 
-template <size_t Dimension, size_t ValueDimension>
-DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>
-operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>& f)
-{
-  using MeshType       = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType;
-  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-  DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh);
-  parallel_for(
-    mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; });
-  return product;
-}
-
-template <size_t Dimension, size_t ValueDimension>
-DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>
-operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f)
-{
-  using MeshType       = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType;
-  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-  DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh);
-  parallel_for(
-    mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; });
-  return product;
-}
-
-template <size_t Dimension, size_t ValueDimension>
-DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>
-operator*(const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f, const TinyMatrix<ValueDimension>& A)
-{
-  using MeshType       = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType;
-  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-  DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh);
-  parallel_for(
-    mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; });
-  return product;
-}
-
-template <size_t Dimension, size_t ValueDimension>
-DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>
-operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyMatrix<ValueDimension>& A)
-{
-  using MeshType       = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType;
-  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-  DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh);
-  parallel_for(
-    mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; });
-  return product;
-}
-
-template <size_t Dimension, size_t ValueDimension>
-DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>
-operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyVector<ValueDimension>& A)
-{
-  using MeshType       = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType;
-  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh());
-  DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh);
-  parallel_for(
-    mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; });
-  return product;
-}
-
 #endif   // DISCRETE_FUNCTION_P0_HPP
-- 
GitLab