diff --git a/src/language/modules/MeshModule.cpp b/src/language/modules/MeshModule.cpp index 5d2c577cbff51ec66e5c92c41ec1e757b94a0313..61767078aeaa0f8362393196a01ce9cde167a064 100644 --- a/src/language/modules/MeshModule.cpp +++ b/src/language/modules/MeshModule.cpp @@ -13,52 +13,11 @@ #include <mesh/GmshReader.hpp> #include <mesh/Mesh.hpp> #include <mesh/MeshInterpoler.hpp> +#include <mesh/MeshTransformer.hpp> #include <utils/Exceptions.hpp> #include <Kokkos_Core.hpp> -template <typename T> -class MeshTransformation; -template <typename OutputType, typename InputType> -class MeshTransformation<OutputType(InputType)> : public PugsFunctionAdapter<OutputType(InputType)> -{ - static constexpr size_t Dimension = OutputType::Dimension; - using Adapter = PugsFunctionAdapter<OutputType(InputType)>; - - public: - static inline std::shared_ptr<Mesh<Connectivity<Dimension>>> - transform(const FunctionSymbolId& function_symbol_id, std::shared_ptr<const IMesh> p_mesh) - { - using MeshType = Mesh<Connectivity<Dimension>>; - const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); - - auto& expression = Adapter::getFunctionExpression(function_symbol_id); - auto convert_result = Adapter::getResultConverter(expression.m_data_type); - - Array<ExecutionPolicy> context_list = Adapter::getContextList(expression); - - NodeValue<const InputType> given_xr = given_mesh.xr(); - NodeValue<OutputType> xr(given_mesh.connectivity()); - - using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; - Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens; - - parallel_for(given_mesh.numberOfNodes(), [=, &expression, &tokens](NodeId r) { - const int32_t t = tokens.acquire(); - - auto& execution_policy = context_list[t]; - - Adapter::convertArgs(execution_policy.currentContext(), given_xr[r]); - auto result = expression.execute(execution_policy); - xr[r] = convert_result(std::move(result)); - - tokens.release(t); - }); - - return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr); - } -}; - MeshModule::MeshModule() { this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IMesh>>); @@ -79,23 +38,7 @@ MeshModule::MeshModule() [](std::shared_ptr<const IMesh> p_mesh, const FunctionSymbolId& function_id) -> std::shared_ptr<const IMesh> { - switch (p_mesh->dimension()) { - case 1: { - using TransformT = TinyVector<1>(TinyVector<1>); - return MeshTransformation<TransformT>::transform(function_id, p_mesh); - } - case 2: { - using TransformT = TinyVector<2>(TinyVector<2>); - return MeshTransformation<TransformT>::transform(function_id, p_mesh); - } - case 3: { - using TransformT = TinyVector<3>(TinyVector<3>); - return MeshTransformation<TransformT>::transform(function_id, p_mesh); - } - default: { - throw UnexpectedError("invalid mesh dimension"); - } - } + return MeshTransformer{}.transform(function_id, p_mesh); } )); diff --git a/src/mesh/CMakeLists.txt b/src/mesh/CMakeLists.txt index a6ededf56bbb4b0ad28fed7cf0d09eab97490350..f536bae7afe8b6d1bc979e4abd9feccaf7c86a7d 100644 --- a/src/mesh/CMakeLists.txt +++ b/src/mesh/CMakeLists.txt @@ -24,6 +24,7 @@ add_library( MeshLineNodeBoundary.cpp MeshNodeBoundary.cpp MeshRandomizer.cpp + MeshTransformer.cpp SynchronizerManager.cpp) # Additional dependencies diff --git a/src/mesh/MeshTransformer.cpp b/src/mesh/MeshTransformer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..67b3688926dcfb765c44799cff4674955515eb8b --- /dev/null +++ b/src/mesh/MeshTransformer.cpp @@ -0,0 +1,71 @@ +#include <mesh/MeshTransformer.hpp> + +#include <mesh/Connectivity.hpp> +#include <mesh/Mesh.hpp> + +#include <language/utils/FunctionTable.hpp> +#include <language/utils/PugsFunctionAdapter.hpp> +#include <language/utils/SymbolTable.hpp> + +template <typename OutputType, typename InputType> +class MeshTransformer::MeshTransformation<OutputType(InputType)> : public PugsFunctionAdapter<OutputType(InputType)> +{ + static constexpr size_t Dimension = OutputType::Dimension; + using Adapter = PugsFunctionAdapter<OutputType(InputType)>; + + public: + static inline std::shared_ptr<Mesh<Connectivity<Dimension>>> + transform(const FunctionSymbolId& function_symbol_id, std::shared_ptr<const IMesh> p_mesh) + { + using MeshType = Mesh<Connectivity<Dimension>>; + const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); + + auto& expression = Adapter::getFunctionExpression(function_symbol_id); + auto convert_result = Adapter::getResultConverter(expression.m_data_type); + + Array<ExecutionPolicy> context_list = Adapter::getContextList(expression); + + NodeValue<const InputType> given_xr = given_mesh.xr(); + NodeValue<OutputType> xr(given_mesh.connectivity()); + + using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; + Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens; + + parallel_for(given_mesh.numberOfNodes(), [=, &expression, &tokens](NodeId r) { + const int32_t t = tokens.acquire(); + + auto& execution_policy = context_list[t]; + + Adapter::convertArgs(execution_policy.currentContext(), given_xr[r]); + auto result = expression.execute(execution_policy); + xr[r] = convert_result(std::move(result)); + + tokens.release(t); + }); + + return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr); + } +}; + +std::shared_ptr<const IMesh> +MeshTransformer::transform(const FunctionSymbolId& function_id, std::shared_ptr<const IMesh> p_mesh) + +{ + switch (p_mesh->dimension()) { + case 1: { + using TransformT = TinyVector<1>(TinyVector<1>); + return MeshTransformation<TransformT>::transform(function_id, p_mesh); + } + case 2: { + using TransformT = TinyVector<2>(TinyVector<2>); + return MeshTransformation<TransformT>::transform(function_id, p_mesh); + } + case 3: { + using TransformT = TinyVector<3>(TinyVector<3>); + return MeshTransformation<TransformT>::transform(function_id, p_mesh); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} diff --git a/src/mesh/MeshTransformer.hpp b/src/mesh/MeshTransformer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ce0a4d47eab0557004234d22ea5331a0d2b0f793 --- /dev/null +++ b/src/mesh/MeshTransformer.hpp @@ -0,0 +1,26 @@ +#ifndef MESH_TRANSFORMER_HPP +#define MESH_TRANSFORMER_HPP + +class IMesh; + +template <typename ConnectivityType> +class Mesh; + +class FunctionSymbolId; + +#include <memory> + +class MeshTransformer +{ + template <typename T> + class MeshTransformation; + + public: + std::shared_ptr<const IMesh> transform(const FunctionSymbolId& function_symbol_id, + std::shared_ptr<const IMesh> p_mesh); + + MeshTransformer() = default; + ~MeshTransformer() = default; +}; + +#endif // MESH_TRANSFORMER_HPP