diff --git a/src/language/modules/MathFunctionRegisterForVh.cpp b/src/language/modules/MathFunctionRegisterForVh.cpp index ca47bb9ac036ff607f4680163f4ac77f223a08ca..7617463648fdee030760ef65ed9dec50bdb21232 100644 --- a/src/language/modules/MathFunctionRegisterForVh.cpp +++ b/src/language/modules/MathFunctionRegisterForVh.cpp @@ -251,4 +251,51 @@ MathFunctionRegisterForVh::MathFunctionRegisterForVh(SchemeModule& scheme_module std::shared_ptr<const IDiscreteFunction>(std::shared_ptr<const IDiscreteFunction>, double)>>( [](std::shared_ptr<const IDiscreteFunction> a, double b) -> std::shared_ptr<const IDiscreteFunction> { return max(a, b); })); + + scheme_module + ._addBuiltinFunction("sum_to_R", + std::make_shared<BuiltinFunctionEmbedder<double(std::shared_ptr<const IDiscreteFunction>)>>( + [](std::shared_ptr<const IDiscreteFunction> a) -> double { return sum_to<double>(a); })); + + scheme_module._addBuiltinFunction("sum_to_R1", + std::make_shared< + BuiltinFunctionEmbedder<TinyVector<1>(std::shared_ptr<const IDiscreteFunction>)>>( + [](std::shared_ptr<const IDiscreteFunction> a) -> TinyVector<1> { + return sum_to<TinyVector<1>>(a); + })); + + scheme_module._addBuiltinFunction("sum_to_R2", + std::make_shared< + BuiltinFunctionEmbedder<TinyVector<2>(std::shared_ptr<const IDiscreteFunction>)>>( + [](std::shared_ptr<const IDiscreteFunction> a) -> TinyVector<2> { + return sum_to<TinyVector<2>>(a); + })); + + scheme_module._addBuiltinFunction("sum_to_R3", + std::make_shared< + BuiltinFunctionEmbedder<TinyVector<3>(std::shared_ptr<const IDiscreteFunction>)>>( + [](std::shared_ptr<const IDiscreteFunction> a) -> TinyVector<3> { + return sum_to<TinyVector<3>>(a); + })); + + scheme_module._addBuiltinFunction("sum_to_R1x1", + std::make_shared< + BuiltinFunctionEmbedder<TinyMatrix<1>(std::shared_ptr<const IDiscreteFunction>)>>( + [](std::shared_ptr<const IDiscreteFunction> a) -> TinyMatrix<1> { + return sum_to<TinyMatrix<1>>(a); + })); + + scheme_module._addBuiltinFunction("sum_to_R2x2", + std::make_shared< + BuiltinFunctionEmbedder<TinyMatrix<2>(std::shared_ptr<const IDiscreteFunction>)>>( + [](std::shared_ptr<const IDiscreteFunction> a) -> TinyMatrix<2> { + return sum_to<TinyMatrix<2>>(a); + })); + + scheme_module._addBuiltinFunction("sum_to_R3x3", + std::make_shared< + BuiltinFunctionEmbedder<TinyMatrix<3>(std::shared_ptr<const IDiscreteFunction>)>>( + [](std::shared_ptr<const IDiscreteFunction> a) -> TinyMatrix<3> { + return sum_to<TinyMatrix<3>>(a); + })); } diff --git a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp index e980c91faea9c303963f270c750025e7b1afe219..0c5f23fb63210a8f72ddda25688fc758f743c13c 100644 --- a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp +++ b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.cpp @@ -9,7 +9,7 @@ #define DISCRETE_FUNCTION_CALL(FUNCTION, ARG) \ if (ARG->dataType() == ASTNodeDataType::double_t and ARG->descriptor().type() == DiscreteFunctionType::P0) { \ - switch (f->mesh()->dimension()) { \ + switch (ARG->mesh()->dimension()) { \ case 1: { \ using DiscreteFunctionType = DiscreteFunctionP0<1, double>; \ return std::make_shared<const DiscreteFunctionType>(FUNCTION(dynamic_cast<const DiscreteFunctionType&>(*ARG))); \ @@ -719,3 +719,44 @@ max(const std::shared_ptr<const IDiscreteFunction>& f, const double a) throw NormalError(os.str()); } } + +template <typename ValueT> +ValueT +sum_to(const std::shared_ptr<const IDiscreteFunction>& f) +{ + if (f->dataType() == ast_node_data_type_from<ValueT> and f->descriptor().type() == DiscreteFunctionType::P0) { + switch (f->mesh()->dimension()) { + case 1: { + using DiscreteFunctionType = DiscreteFunctionP0<1, ValueT>; + return sum(dynamic_cast<const DiscreteFunctionType&>(*f)); + } + case 2: { + using DiscreteFunctionType = DiscreteFunctionP0<2, ValueT>; + return sum(dynamic_cast<const DiscreteFunctionType&>(*f)); + } + case 3: { + using DiscreteFunctionType = DiscreteFunctionP0<3, ValueT>; + return sum(dynamic_cast<const DiscreteFunctionType&>(*f)); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } + } else { + throw NormalError("invalid operand type " + operand_type_name(f)); + } +} + +template double sum_to<double>(const std::shared_ptr<const IDiscreteFunction>&); + +template TinyVector<1> sum_to<TinyVector<1>>(const std::shared_ptr<const IDiscreteFunction>&); + +template TinyVector<2> sum_to<TinyVector<2>>(const std::shared_ptr<const IDiscreteFunction>&); + +template TinyVector<3> sum_to<TinyVector<3>>(const std::shared_ptr<const IDiscreteFunction>&); + +template TinyMatrix<1> sum_to<TinyMatrix<1>>(const std::shared_ptr<const IDiscreteFunction>&); + +template TinyMatrix<2> sum_to<TinyMatrix<2>>(const std::shared_ptr<const IDiscreteFunction>&); + +template TinyMatrix<3> sum_to<TinyMatrix<3>>(const std::shared_ptr<const IDiscreteFunction>&); diff --git a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp index abcfc50b4bd98899d937ef2976b2d38a840125eb..6eef58101016b95411613677cef2d27e5e52eac1 100644 --- a/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp +++ b/src/language/utils/EmbeddedIDiscreteFunctionMathFunctions.hpp @@ -83,4 +83,7 @@ std::shared_ptr<const IDiscreteFunction> max(const double, const std::shared_ptr std::shared_ptr<const IDiscreteFunction> max(const std::shared_ptr<const IDiscreteFunction>&, const double); +template <typename ValueT> +ValueT sum_to(const std::shared_ptr<const IDiscreteFunction>&); + #endif // EMBEDDED_I_DISCRETE_FUNCTION_MATH_FUNCTIONS_HPP diff --git a/src/scheme/DiscreteFunctionP0.hpp b/src/scheme/DiscreteFunctionP0.hpp index 943b3e6bee19594e0311517a592d7ca7f5789119..ccedb957d9e1fdb9b0c6643d66618709a9717cad 100644 --- a/src/scheme/DiscreteFunctionP0.hpp +++ b/src/scheme/DiscreteFunctionP0.hpp @@ -652,6 +652,13 @@ class DiscreteFunctionP0 : public IDiscreteFunction return result; } + PUGS_INLINE friend DataType + sum(const DiscreteFunctionP0& f) + { + Assert(f.m_cell_values.isBuilt()); + return sum(f.m_cell_values); + } + DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh) : m_mesh{mesh}, m_cell_values{mesh->connectivity()} {} DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const CellValue<DataType>& cell_value) diff --git a/src/utils/Messenger.hpp b/src/utils/Messenger.hpp index 99de952b4d9e084a7981e0e0f43fe5feb4f8b4ab..7e8b30c1a8e92c05fc30178e325110617c73c823 100644 --- a/src/utils/Messenger.hpp +++ b/src/utils/Messenger.hpp @@ -510,7 +510,6 @@ class Messenger allReduceSum(const DataType& data) const { static_assert(not std::is_const_v<DataType>); - static_assert(std::is_arithmetic_v<DataType>); static_assert(not std::is_same_v<DataType, bool>); #ifdef PUGS_HAS_MPI