From a2683e800c434175703bd45e6b64109a22b34a52 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?St=C3=A9phane=20Del=20Pino?= <stephane.delpino44@gmail.com>
Date: Tue, 25 May 2021 22:46:45 +0200
Subject: [PATCH] Add sum_to_* functions

These are `sum_to_R`, `sum_to_R1`, `sum_to_R2`, `sum_to_R3`,
`sum_to_R1x1`, `sum_to_R2x2` and `sum_to_R3x3`.

In C++, it is just `sum`, since the argument is a concrete
`DiscreteFunctionP0<Dimension, DataType>`, there is no ambiguity on
the return type (`DataType`)
---
 .../modules/MathFunctionRegisterForVh.cpp     | 47 +++++++++++++++++++
 ...EmbeddedIDiscreteFunctionMathFunctions.cpp | 43 ++++++++++++++++-
 ...EmbeddedIDiscreteFunctionMathFunctions.hpp |  3 ++
 src/scheme/DiscreteFunctionP0.hpp             |  7 +++
 src/utils/Messenger.hpp                       |  1 -
 5 files changed, 99 insertions(+), 2 deletions(-)

diff --git a/src/language/modules/MathFunctionRegisterForVh.cpp b/src/language/modules/MathFunctionRegisterForVh.cpp
index ca47bb9ac..761746364 100644
--- a/src/language/modules/MathFunctionRegisterForVh.cpp
+++ b/src/language/modules/MathFunctionRegisterForVh.cpp
@@ -251,4 +251,51 @@ MathFunctionRegisterForVh::MathFunctionRegisterForVh(SchemeModule& scheme_module
                            std::shared_ptr<const IDiscreteFunction>(std::shared_ptr<const IDiscreteFunction>, double)>>(
                            [](std::shared_ptr<const IDiscreteFunction> a,
                               double b) -> std::shared_ptr<const IDiscreteFunction> { return max(a, b); }));
+
+  scheme_module
+    ._addBuiltinFunction("sum_to_R",
+                         std::make_shared<BuiltinFunctionEmbedder<double(std::shared_ptr<const IDiscreteFunction>)>>(
+                           [](std::shared_ptr<const IDiscreteFunction> a) -> double { return sum_to<double>(a); }));
+
+  scheme_module._addBuiltinFunction("sum_to_R1",
+                                    std::make_shared<
+                                      BuiltinFunctionEmbedder<TinyVector<1>(std::shared_ptr<const IDiscreteFunction>)>>(
+                                      [](std::shared_ptr<const IDiscreteFunction> a) -> TinyVector<1> {
+                                        return sum_to<TinyVector<1>>(a);
+                                      }));
+
+  scheme_module._addBuiltinFunction("sum_to_R2",
+                                    std::make_shared<
+                                      BuiltinFunctionEmbedder<TinyVector<2>(std::shared_ptr<const IDiscreteFunction>)>>(
+                                      [](std::shared_ptr<const IDiscreteFunction> a) -> TinyVector<2> {
+                                        return sum_to<TinyVector<2>>(a);
+                                      }));
+
+  scheme_module._addBuiltinFunction("sum_to_R3",
+                                    std::make_shared<
+                                      BuiltinFunctionEmbedder<TinyVector<3>(std::shared_ptr<const IDiscreteFunction>)>>(
+                                      [](std::shared_ptr<const IDiscreteFunction> a) -> TinyVector<3> {
+                                        return sum_to<TinyVector<3>>(a);
+                                      }));
+
+  scheme_module._addBuiltinFunction("sum_to_R1x1",
+                                    std::make_shared<
+                                      BuiltinFunctionEmbedder<TinyMatrix<1>(std::shared_ptr<const IDiscreteFunction>)>>(
+                                      [](std::shared_ptr<const IDiscreteFunction> a) -> TinyMatrix<1> {
+                                        return sum_to<TinyMatrix<1>>(a);
+                                      }));
+
+  scheme_module._addBuiltinFunction("sum_to_R2x2",
+                                    std::make_shared<
+                                      BuiltinFunctionEmbedder<TinyMatrix<2>(std::shared_ptr<const IDiscreteFunction>)>>(
+                                      [](std::shared_ptr<const IDiscreteFunction> a) -> TinyMatrix<2> {
+                                        return sum_to<TinyMatrix<2>>(a);
+                                      }));
+
+  scheme_module._addBuiltinFunction("sum_to_R3x3",
+                                    std::make_shared<
+                                      BuiltinFunctionEmbedder<TinyMatrix<3>(std::shared_ptr<const IDiscreteFunction>)>>(
+                                      [](std::shared_ptr<const IDiscreteFunction> a) -> TinyMatrix<3> {
+                                        return sum_to<TinyMatrix<3>>(a);
+                                      }));
 }
diff --git a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp
index e980c91fa..0c5f23fb6 100644
--- a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp
+++ b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp
@@ -9,7 +9,7 @@
 
 #define DISCRETE_FUNCTION_CALL(FUNCTION, ARG)                                                                         \
   if (ARG->dataType() == ASTNodeDataType::double_t and ARG->descriptor().type() == DiscreteFunctionType::P0) {        \
-    switch (f->mesh()->dimension()) {                                                                                 \
+    switch (ARG->mesh()->dimension()) {                                                                               \
     case 1: {                                                                                                         \
       using DiscreteFunctionType = DiscreteFunctionP0<1, double>;                                                     \
       return std::make_shared<const DiscreteFunctionType>(FUNCTION(dynamic_cast<const DiscreteFunctionType&>(*ARG))); \
@@ -719,3 +719,44 @@ max(const std::shared_ptr<const IDiscreteFunction>& f, const double a)
     throw NormalError(os.str());
   }
 }
+
+template <typename ValueT>
+ValueT
+sum_to(const std::shared_ptr<const IDiscreteFunction>& f)
+{
+  if (f->dataType() == ast_node_data_type_from<ValueT> and f->descriptor().type() == DiscreteFunctionType::P0) {
+    switch (f->mesh()->dimension()) {
+    case 1: {
+      using DiscreteFunctionType = DiscreteFunctionP0<1, ValueT>;
+      return sum(dynamic_cast<const DiscreteFunctionType&>(*f));
+    }
+    case 2: {
+      using DiscreteFunctionType = DiscreteFunctionP0<2, ValueT>;
+      return sum(dynamic_cast<const DiscreteFunctionType&>(*f));
+    }
+    case 3: {
+      using DiscreteFunctionType = DiscreteFunctionP0<3, ValueT>;
+      return sum(dynamic_cast<const DiscreteFunctionType&>(*f));
+    }
+    default: {
+      throw UnexpectedError("invalid mesh dimension");
+    }
+    }
+  } else {
+    throw NormalError("invalid operand type " + operand_type_name(f));
+  }
+}
+
+template double sum_to<double>(const std::shared_ptr<const IDiscreteFunction>&);
+
+template TinyVector<1> sum_to<TinyVector<1>>(const std::shared_ptr<const IDiscreteFunction>&);
+
+template TinyVector<2> sum_to<TinyVector<2>>(const std::shared_ptr<const IDiscreteFunction>&);
+
+template TinyVector<3> sum_to<TinyVector<3>>(const std::shared_ptr<const IDiscreteFunction>&);
+
+template TinyMatrix<1> sum_to<TinyMatrix<1>>(const std::shared_ptr<const IDiscreteFunction>&);
+
+template TinyMatrix<2> sum_to<TinyMatrix<2>>(const std::shared_ptr<const IDiscreteFunction>&);
+
+template TinyMatrix<3> sum_to<TinyMatrix<3>>(const std::shared_ptr<const IDiscreteFunction>&);
diff --git a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp
index abcfc50b4..6eef58101 100644
--- a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp
+++ b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp
@@ -83,4 +83,7 @@ std::shared_ptr<const IDiscreteFunction> max(const double, const std::shared_ptr
 
 std::shared_ptr<const IDiscreteFunction> max(const std::shared_ptr<const IDiscreteFunction>&, const double);
 
+template <typename ValueT>
+ValueT sum_to(const std::shared_ptr<const IDiscreteFunction>&);
+
 #endif   // EMBEDDED_I_DISCRETE_FUNCTION_MATH_FUNCTIONS_HPP
diff --git a/src/scheme/DiscreteFunctionP0.hpp b/src/scheme/DiscreteFunctionP0.hpp
index 943b3e6be..ccedb957d 100644
--- a/src/scheme/DiscreteFunctionP0.hpp
+++ b/src/scheme/DiscreteFunctionP0.hpp
@@ -652,6 +652,13 @@ class DiscreteFunctionP0 : public IDiscreteFunction
     return result;
   }
 
+  PUGS_INLINE friend DataType
+  sum(const DiscreteFunctionP0& f)
+  {
+    Assert(f.m_cell_values.isBuilt());
+    return sum(f.m_cell_values);
+  }
+
   DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh) : m_mesh{mesh}, m_cell_values{mesh->connectivity()} {}
 
   DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const CellValue<DataType>& cell_value)
diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp
index 99de952b4..7e8b30c1a 100644
--- a/src/utils/Messenger.hpp
+++ b/src/utils/Messenger.hpp
@@ -510,7 +510,6 @@ class Messenger
   allReduceSum(const DataType& data) const
   {
     static_assert(not std::is_const_v<DataType>);
-    static_assert(std::is_arithmetic_v<DataType>);
     static_assert(not std::is_same_v<DataType, bool>);
 
 #ifdef PUGS_HAS_MPI
-- 
GitLab