From caab4b21c6c7fa8df897dbdbc7f395459b587391 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 6 May 2020 19:16:21 +0200 Subject: [PATCH] Improve the code (design, genericity and readability) - Rename FunctionAdapter -> PugsFunctionAdapter to improve readability - PugsFunctionAdapter is now defined in its own file --- src/language/MeshModule.cpp | 105 ++++++--------------------- src/language/PugsFunctionAdapter.hpp | 90 +++++++++++++++++++++++ 2 files changed, 114 insertions(+), 81 deletions(-) create mode 100644 src/language/PugsFunctionAdapter.hpp diff --git a/src/language/MeshModule.cpp b/src/language/MeshModule.cpp index 29d1ca111..cd987e244 100644 --- a/src/language/MeshModule.cpp +++ b/src/language/MeshModule.cpp @@ -2,6 +2,7 @@ #include <language/BuiltinFunctionEmbedder.hpp> #include <language/FunctionTable.hpp> +#include <language/PugsFunctionAdapter.hpp> #include <language/SymbolTable.hpp> #include <language/TypeDescriptor.hpp> #include <language/node_processor/ExecutionPolicy.hpp> @@ -19,97 +20,39 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod template <> inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t}; +template <typename T> +class MeshTransformation; template <typename OutputType, typename... InputType> -class FunctionAdapter +class MeshTransformation<OutputType(InputType...)> : public PugsFunctionAdapter<OutputType(InputType...)> { - static constexpr size_t OutputDimension = OutputType::Dimension; - - private: - template <typename T, typename... Args> - static void - _convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context) - { - context[sizeof...(args)] = t; - if constexpr (sizeof...(args) > 0) { - _convertArgs(std::forward<Args>(args)..., context); - } - } - - template <typename... Args> - static void - convertArgs(ExecutionPolicy::Context& context, const Args&... args) - { - static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type"); - _convertArgs(args..., context); - } - - static std::function<OutputType(DataVariant&& result)> - _get_result_converter(ASTNodeDataType data_type) - { - switch (data_type) { - case ASTNodeDataType::list_t: { - return [](DataVariant&& result) -> OutputType { - AggregateDataVariant& v = std::get<AggregateDataVariant>(result); - OutputType x; - for (size_t i = 0; i < x.dimension(); ++i) { - x[i] = std::get<double>(v[i]); - } - return x; - }; - } - case ASTNodeDataType::vector_t: { - return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; - } - case ASTNodeDataType::double_t: { - if constexpr (OutputDimension == 1) { - return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; - } else { - throw UnexpectedError("unexpected data_type"); - } - } - default: { - throw UnexpectedError("unexpected data_type"); - } - } - } + static constexpr size_t Dimension = OutputType::Dimension; + using Adapter = PugsFunctionAdapter<OutputType(InputType...)>; public: - template <size_t Dimension> - static inline std::shared_ptr<Mesh<Connectivity<OutputDimension>>> + static inline std::shared_ptr<Mesh<Connectivity<Dimension>>> transform(FunctionSymbolId function_symbol_id, std::shared_ptr<const IMesh> p_mesh) { - auto& symbol_table = function_symbol_id.symbolTable(); - auto& function_expression = *symbol_table.functionTable()[function_symbol_id.id()].definitionNode().children[1]; - auto& function_context = function_expression.m_symbol_table->context(); - - ASTNodeDataType t = function_expression.m_data_type; - auto convert_result = _get_result_converter(t); - - const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); - Array<ExecutionPolicy> context_list(number_of_threads); - for (size_t i = 0; i < context_list.size(); ++i) { - context_list[i] = ExecutionPolicy(ExecutionPolicy{}, - {function_context.id(), - std::make_shared<ExecutionPolicy::Context::Values>(function_context.size())}); - } + using MeshType = Mesh<Connectivity<Dimension>>; + const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); - using MeshType = Mesh<Connectivity<Dimension>>; - const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); - NodeValue<const TinyVector<Dimension>> given_xr = given_mesh.xr(); + auto& expression = Adapter::getFunctionExpression(function_symbol_id); + auto convert_result = Adapter::getResultConverter(expression.m_data_type); + Array<ExecutionPolicy> context_list = Adapter::getContextList(expression); - NodeValue<TinyVector<Dimension>> xr(given_mesh.connectivity()); + NodeValue<const OutputType> given_xr = given_mesh.xr(); + NodeValue<OutputType> xr(given_mesh.connectivity()); using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens; - parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, &tokens](NodeId r) { + parallel_for(given_mesh.numberOfNodes(), [=, &expression, &tokens](NodeId r) { const int32_t t = tokens.acquire(); auto& execution_policy = context_list[t]; - convertArgs(execution_policy.currentContext(), given_xr[r]); + Adapter::convertArgs(execution_policy.currentContext(), given_xr[r]); + xr[r] = convert_result(expression.execute(execution_policy)); - xr[r] = convert_result(function_expression.execute(execution_policy)); tokens.release(t); }); @@ -140,19 +83,19 @@ MeshModule::MeshModule() FunctionSymbolId function_id) -> std::shared_ptr<IMesh> { switch (p_mesh->dimension()) { case 1: { - return FunctionAdapter<TinyVector<1>, TinyVector<1>>::transform<1>(function_id, - p_mesh); + using TransformT = TinyVector<1>(TinyVector<1>); + return MeshTransformation<TransformT>::transform(function_id, p_mesh); } case 2: { - return FunctionAdapter<TinyVector<2>, TinyVector<2>>::transform<2>(function_id, - p_mesh); + using TransformT = TinyVector<2>(TinyVector<2>); + return MeshTransformation<TransformT>::transform(function_id, p_mesh); } case 3: { - return FunctionAdapter<TinyVector<3>, TinyVector<3>>::transform<3>(function_id, - p_mesh); + using TransformT = TinyVector<3>(TinyVector<3>); + return MeshTransformation<TransformT>::transform(function_id, p_mesh); } default: { - throw NormalError("invalid dimension"); + throw NormalError("invalid mesh dimension"); } } }} diff --git a/src/language/PugsFunctionAdapter.hpp b/src/language/PugsFunctionAdapter.hpp new file mode 100644 index 000000000..262299728 --- /dev/null +++ b/src/language/PugsFunctionAdapter.hpp @@ -0,0 +1,90 @@ +#ifndef PUGS_FUNCTION_ADAPTER_HPP +#define PUGS_FUNCTION_ADAPTER_HPP + +#include <language/ASTNode.hpp> +#include <language/SymbolTable.hpp> +#include <language/node_processor/ExecutionPolicy.hpp> +#include <utils/Array.hpp> +#include <utils/Exceptions.hpp> +#include <utils/PugsMacros.hpp> + +#include <Kokkos_Core.hpp> + +template <typename T> +class PugsFunctionAdapter; +template <typename OutputType, typename... InputType> +class PugsFunctionAdapter<OutputType(InputType...)> +{ + private: + template <typename T, typename... Args> + PUGS_INLINE static void + _convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context) + { + context[sizeof...(args)] = t; + if constexpr (sizeof...(args) > 0) { + _convertArgs(std::forward<Args>(args)..., context); + } + } + + protected: + PUGS_INLINE static auto& + getFunctionExpression(FunctionSymbolId function_symbol_id) + { + return *function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()].definitionNode().children[1]; + } + + PUGS_INLINE static auto + getContextList(const ASTNode& expression) + { + Array<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size()); + auto& context = expression.m_symbol_table->context(); + + for (size_t i = 0; i < context_list.size(); ++i) { + context_list[i] = + ExecutionPolicy(ExecutionPolicy{}, + {context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())}); + } + + return context_list; + } + + template <typename... Args> + PUGS_INLINE static void + convertArgs(ExecutionPolicy::Context& context, const Args&... args) + { + static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type"); + _convertArgs(args..., context); + } + + PUGS_INLINE static std::function<OutputType(DataVariant&& result)> + getResultConverter(ASTNodeDataType data_type) + { + switch (data_type) { + case ASTNodeDataType::list_t: { + return [](DataVariant&& result) -> OutputType { + AggregateDataVariant& v = std::get<AggregateDataVariant>(result); + OutputType x; + for (size_t i = 0; i < x.dimension(); ++i) { + x[i] = std::get<double>(v[i]); + } + return x; + }; + } + case ASTNodeDataType::vector_t: { + return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; + } + case ASTNodeDataType::double_t: { + if constexpr (std::is_same_v<OutputType, TinyVector<1>>) { + return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; + } else { + throw UnexpectedError("unexpected data_type"); + } + } + default: { + throw UnexpectedError("unexpected data_type"); + } + } + } +}; + +#endif // PUGS_FUNCTION_ADAPTER_HPP -- GitLab