Skip to content
Snippets Groups Projects
Commit 8ed59aba authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Define FunctionAdapter: pugs function caller helper

This
- improves a lot implementation
- manages correctly language execution context (for multi-thread calls)

This is an important step, but some work is still necessary: one
should define an interface to ease this kind of construction.
parent 9d6fece6
No related branches found
No related tags found
1 merge request!37Feature/language
...@@ -19,82 +19,140 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod ...@@ -19,82 +19,140 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod
template <> template <>
inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t}; inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
MeshModule::MeshModule() template <typename OutputType, typename... InputType>
class FunctionAdapter
{ {
this->_addTypeDescriptor( static constexpr size_t OutputDimension = OutputType::Dimension;
std::make_shared<TypeDescriptor>(ast_node_data_type_from<std::shared_ptr<IMesh>>.typeName()));
this->_addBuiltinFunction("readGmsh", std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::string>>( private:
std::function<std::shared_ptr<IMesh>(std::string)>{ template <typename T, typename... Args>
static void
_convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context)
{
context[sizeof...(args)] = t;
if constexpr (sizeof...(args) > 0) {
_convertArgs(std::forward<Args>(args)..., context);
}
}
[](const std::string& file_name) -> std::shared_ptr<IMesh> { template <typename... Args>
GmshReader gmsh_reader(file_name); static void
return gmsh_reader.mesh(); convertArgs(ExecutionPolicy::Context& context, const Args&... args)
}} {
static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type");
_convertArgs(args..., context);
}
)); static std::function<OutputType(DataVariant&& result)>
_get_result_converter(ASTNodeDataType data_type)
{
switch (data_type) {
case ASTNodeDataType::list_t: {
return [](DataVariant&& result) -> OutputType {
AggregateDataVariant& v = std::get<AggregateDataVariant>(result);
OutputType x;
for (size_t i = 0; i < x.dimension(); ++i) {
x[i] = std::get<double>(v[i]);
}
return x;
};
}
case ASTNodeDataType::vector_t: {
return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
}
case ASTNodeDataType::double_t: {
if constexpr (OutputDimension == 1) {
return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
} else {
throw UnexpectedError("unexpected data_type");
}
}
default: {
throw UnexpectedError("unexpected data_type");
}
}
}
this->_addBuiltinFunction("transform", public:
std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, template <size_t Dimension>
FunctionSymbolId>>( static inline std::shared_ptr<Mesh<Connectivity<OutputDimension>>>
std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{ transform(FunctionSymbolId function_symbol_id, std::shared_ptr<const IMesh> p_mesh)
[](std::shared_ptr<IMesh> p_mesh, {
FunctionSymbolId function_id) -> std::shared_ptr<IMesh> { auto& symbol_table = function_symbol_id.symbolTable();
auto& symbol_table = function_id.symbolTable(); auto& function_expression = *symbol_table.functionTable()[function_symbol_id.id()].definitionNode().children[1];
auto& function_expression =
*symbol_table.functionTable()[function_id.id()].definitionNode().children[1];
auto& function_context = function_expression.m_symbol_table->context(); auto& function_context = function_expression.m_symbol_table->context();
ASTNodeDataType t = function_expression.m_data_type;
auto convert_result = _get_result_converter(t);
const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size();
Array<ExecutionPolicy> context_list(number_of_threads); Array<ExecutionPolicy> context_list(number_of_threads);
for (size_t i = 0; i < context_list.size(); ++i) { for (size_t i = 0; i < context_list.size(); ++i) {
context_list[i] = context_list[i] = ExecutionPolicy(ExecutionPolicy{},
ExecutionPolicy(ExecutionPolicy{},
{function_context.id(), {function_context.id(),
std::make_shared<ExecutionPolicy::Context::Values>( std::make_shared<ExecutionPolicy::Context::Values>(function_context.size())});
function_context.size())});
} }
switch (p_mesh->dimension()) { using MeshType = Mesh<Connectivity<Dimension>>;
case 1: {
throw NotImplementedError("not implemented in 1d");
break;
}
case 2: {
throw NotImplementedError("not implemented in 2d");
break;
}
case 3: {
using MeshType = Mesh<Connectivity3D>;
const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh);
NodeValue<const TinyVector<3>> given_xr = given_mesh.xr(); NodeValue<const TinyVector<Dimension>> given_xr = given_mesh.xr();
NodeValue<TinyVector<3>> xr(given_mesh.connectivity()); NodeValue<TinyVector<Dimension>> xr(given_mesh.connectivity());
using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens;
Kokkos::Experimental::UniqueTokenScope::Global>
tokens;
parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, &tokens](NodeId r) {
&tokens](NodeId r) {
const int32_t t = tokens.acquire(); const int32_t t = tokens.acquire();
auto& execution_policy = context_list[t]; auto& execution_policy = context_list[t];
execution_policy.currentContext()[0] = given_xr[r];
auto&& value = function_expression.execute(execution_policy);
AggregateDataVariant& v = std::get<AggregateDataVariant>(value); convertArgs(execution_policy.currentContext(), given_xr[r]);
xr[r] = {std::get<double>(v[0]), std::get<double>(v[1]), std::get<double>(v[2])};
xr[r] = convert_result(function_expression.execute(execution_policy));
tokens.release(t); tokens.release(t);
}); });
return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr); return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr);
} }
};
MeshModule::MeshModule()
{
this->_addTypeDescriptor(
std::make_shared<TypeDescriptor>(ast_node_data_type_from<std::shared_ptr<IMesh>>.typeName()));
this->_addBuiltinFunction("readGmsh", std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::string>>(
std::function<std::shared_ptr<IMesh>(std::string)>{
[](const std::string& file_name) -> std::shared_ptr<IMesh> {
GmshReader gmsh_reader(file_name);
return gmsh_reader.mesh();
}}
));
this->_addBuiltinFunction("transform",
std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>,
FunctionSymbolId>>(
std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{
[](std::shared_ptr<IMesh> p_mesh,
FunctionSymbolId function_id) -> std::shared_ptr<IMesh> {
switch (p_mesh->dimension()) {
case 1: {
return FunctionAdapter<TinyVector<1>, TinyVector<1>>::transform<1>(function_id,
p_mesh);
}
case 2: {
return FunctionAdapter<TinyVector<2>, TinyVector<2>>::transform<2>(function_id,
p_mesh);
}
case 3: {
return FunctionAdapter<TinyVector<3>, TinyVector<3>>::transform<3>(function_id,
p_mesh);
}
default: { default: {
return nullptr; throw NormalError("invalid dimension");
} }
} }
}} }}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment