Skip to content
Snippets Groups Projects
Commit cc7c4fca authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Clean function evaluation code

- use Kokkos::UniqueToken manager to use proper context in
multi-thread context
- still works only for R^3->R^3 functions

Remains to
- check input/output types
- treat the case when the function returns an R^3 and not a list of 3 R
parent 0ce1144f
No related branches found
No related tags found
1 merge request!37Feature/language
......@@ -19,35 +19,6 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod
template <>
inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
struct TagA
{
};
struct TagC
{
};
struct Foo
{
static std::vector<int> m_ids;
KOKKOS_INLINE_FUNCTION
void
operator()(const TagA, const Kokkos::TeamPolicy<>::member_type& team) const
{
m_ids[team.league_rank()] = team.league_rank() * 3;
}
KOKKOS_INLINE_FUNCTION
void
operator()(const TagC, const Kokkos::TeamPolicy<>::member_type& team) const
{
m_ids[team.league_rank()] = team.league_rank() * 2;
}
Foo() {}
};
inline std::vector<int> Foo::m_ids;
MeshModule::MeshModule()
{
this->_addTypeDescriptor(
......@@ -67,7 +38,6 @@ MeshModule::MeshModule()
std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>,
FunctionSymbolId>>(
std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{
[](std::shared_ptr<IMesh> p_mesh,
FunctionSymbolId function_id) -> std::shared_ptr<IMesh> {
auto& symbol_table = function_id.symbolTable();
......@@ -75,31 +45,15 @@ MeshModule::MeshModule()
*symbol_table.functionTable()[function_id.id()].definitionNode().children[1];
auto& function_context = function_expression.m_symbol_table->context();
// Foo foo;
// Foo::m_ids.resize(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
// Kokkos::parallel_for(Kokkos::TeamPolicy<
// TagA>(Kokkos::DefaultExecutionSpace::impl_thread_pool_size(),
// 1),
// foo);
// std::cout << "--------\n" << std::endl;
// for (int i = 0; i < Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); ++i) {
// std::cout << "m_ids[" << i << "] = " << Foo::m_ids[i] << " | " << i * 3 <<
// std::endl;
// }
// std::cout << "********\n" << std::endl;
// Kokkos::parallel_for(Kokkos::TeamPolicy<
// TagC>(Kokkos::DefaultExecutionSpace::impl_thread_pool_size(),
// 1),
// foo);
// std::cout << "--------\n" << std::endl;
// for (int i = 0; i < Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); ++i) {
// std::cout << "m_ids[" << i << "] = " << Foo::m_ids[i] << " | " << i * 2 <<
// std::endl;
// }
const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size();
Array<ExecutionPolicy> context_list(number_of_threads);
for (size_t i = 0; i < context_list.size(); ++i) {
context_list[i] =
ExecutionPolicy(ExecutionPolicy{},
{function_context.id(),
std::make_shared<ExecutionPolicy::Context::Values>(
function_context.size())});
}
switch (p_mesh->dimension()) {
case 1: {
......@@ -117,21 +71,24 @@ MeshModule::MeshModule()
NodeValue<TinyVector<3>> xr(given_mesh.connectivity());
ExecutionPolicy::Context context{function_context.id(),
std::make_shared<ExecutionPolicy::Context::Values>(
function_context.size())};
ExecutionPolicy execution_policy;
ExecutionPolicy context_execution_policy{execution_policy, context};
using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
Kokkos::Experimental::UniqueToken<execution_space,
Kokkos::Experimental::UniqueTokenScope::Global>
tokens;
parallel_for(given_mesh.numberOfNodes(), [=, &function_expression,
&context_execution_policy](NodeId r) {
context_execution_policy.currentContext()[0] = given_xr[r];
&tokens](NodeId r) {
const int32_t t = tokens.acquire();
auto&& value = function_expression.execute(context_execution_policy);
auto& execution_policy = context_list[t];
execution_policy.currentContext()[0] = given_xr[r];
auto&& value = function_expression.execute(execution_policy);
AggregateDataVariant& v = std::get<AggregateDataVariant>(value);
xr[r] = {std::get<double>(v[0]), std::get<double>(v[1]), std::get<double>(v[2])};
tokens.release(t);
});
return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment