diff --git a/src/language/MeshModule.cpp b/src/language/MeshModule.cpp index 0a0b559411a18d395468712e8355f2d7e80aecb8..29d1ca1110d469d58fc667bd047ebe0b0eb878c1 100644 --- a/src/language/MeshModule.cpp +++ b/src/language/MeshModule.cpp @@ -19,6 +19,104 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod template <> inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t}; +template <typename OutputType, typename... InputType> +class FunctionAdapter +{ + static constexpr size_t OutputDimension = OutputType::Dimension; + + private: + template <typename T, typename... Args> + static void + _convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context) + { + context[sizeof...(args)] = t; + if constexpr (sizeof...(args) > 0) { + _convertArgs(std::forward<Args>(args)..., context); + } + } + + template <typename... Args> + static void + convertArgs(ExecutionPolicy::Context& context, const Args&... args) + { + static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type"); + _convertArgs(args..., context); + } + + static std::function<OutputType(DataVariant&& result)> + _get_result_converter(ASTNodeDataType data_type) + { + switch (data_type) { + case ASTNodeDataType::list_t: { + return [](DataVariant&& result) -> OutputType { + AggregateDataVariant& v = std::get<AggregateDataVariant>(result); + OutputType x; + for (size_t i = 0; i < x.dimension(); ++i) { + x[i] = std::get<double>(v[i]); + } + return x; + }; + } + case ASTNodeDataType::vector_t: { + return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); }; + } + case ASTNodeDataType::double_t: { + if constexpr (OutputDimension == 1) { + return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; }; + } else { + throw UnexpectedError("unexpected data_type"); + } + } + default: { + throw UnexpectedError("unexpected data_type"); + } + } + } + + public: + template <size_t Dimension> + static inline std::shared_ptr<Mesh<Connectivity<OutputDimension>>> + transform(FunctionSymbolId function_symbol_id, std::shared_ptr<const IMesh> p_mesh) + { + auto& symbol_table = function_symbol_id.symbolTable(); + auto& function_expression = *symbol_table.functionTable()[function_symbol_id.id()].definitionNode().children[1]; + auto& function_context = function_expression.m_symbol_table->context(); + + ASTNodeDataType t = function_expression.m_data_type; + auto convert_result = _get_result_converter(t); + + const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); + Array<ExecutionPolicy> context_list(number_of_threads); + for (size_t i = 0; i < context_list.size(); ++i) { + context_list[i] = ExecutionPolicy(ExecutionPolicy{}, + {function_context.id(), + std::make_shared<ExecutionPolicy::Context::Values>(function_context.size())}); + } + + using MeshType = Mesh<Connectivity<Dimension>>; + const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); + NodeValue<const TinyVector<Dimension>> given_xr = given_mesh.xr(); + + NodeValue<TinyVector<Dimension>> xr(given_mesh.connectivity()); + + using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; + Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens; + + parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, &tokens](NodeId r) { + const int32_t t = tokens.acquire(); + + auto& execution_policy = context_list[t]; + + convertArgs(execution_policy.currentContext(), given_xr[r]); + + xr[r] = convert_result(function_expression.execute(execution_policy)); + tokens.release(t); + }); + + return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr); + } +}; + MeshModule::MeshModule() { this->_addTypeDescriptor( @@ -40,61 +138,21 @@ MeshModule::MeshModule() std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{ [](std::shared_ptr<IMesh> p_mesh, FunctionSymbolId function_id) -> std::shared_ptr<IMesh> { - auto& symbol_table = function_id.symbolTable(); - auto& function_expression = - *symbol_table.functionTable()[function_id.id()].definitionNode().children[1]; - auto& function_context = function_expression.m_symbol_table->context(); - - const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); - Array<ExecutionPolicy> context_list(number_of_threads); - for (size_t i = 0; i < context_list.size(); ++i) { - context_list[i] = - ExecutionPolicy(ExecutionPolicy{}, - {function_context.id(), - std::make_shared<ExecutionPolicy::Context::Values>( - function_context.size())}); - } - switch (p_mesh->dimension()) { case 1: { - throw NotImplementedError("not implemented in 1d"); - break; + return FunctionAdapter<TinyVector<1>, TinyVector<1>>::transform<1>(function_id, + p_mesh); } case 2: { - throw NotImplementedError("not implemented in 2d"); - break; + return FunctionAdapter<TinyVector<2>, TinyVector<2>>::transform<2>(function_id, + p_mesh); } case 3: { - using MeshType = Mesh<Connectivity3D>; - const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); - NodeValue<const TinyVector<3>> given_xr = given_mesh.xr(); - - NodeValue<TinyVector<3>> xr(given_mesh.connectivity()); - - using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; - Kokkos::Experimental::UniqueToken<execution_space, - Kokkos::Experimental::UniqueTokenScope::Global> - tokens; - - parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, - &tokens](NodeId r) { - const int32_t t = tokens.acquire(); - - auto& execution_policy = context_list[t]; - execution_policy.currentContext()[0] = given_xr[r]; - - auto&& value = function_expression.execute(execution_policy); - - AggregateDataVariant& v = std::get<AggregateDataVariant>(value); - xr[r] = {std::get<double>(v[0]), std::get<double>(v[1]), std::get<double>(v[2])}; - - tokens.release(t); - }); - - return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr); + return FunctionAdapter<TinyVector<3>, TinyVector<3>>::transform<3>(function_id, + p_mesh); } default: { - return nullptr; + throw NormalError("invalid dimension"); } } }}