Skip to content
Snippets Groups Projects
Commit 63715ef8 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Test simple context strategy: only works for one thread!!

parent 6a15ee1c
No related branches found
No related tags found
1 merge request!37Feature/language
...@@ -10,12 +10,44 @@ ...@@ -10,12 +10,44 @@
#include <mesh/Mesh.hpp> #include <mesh/Mesh.hpp>
#include <utils/Exceptions.hpp> #include <utils/Exceptions.hpp>
#include <Kokkos_Core.hpp>
#include <cstdio>
template <> template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNodeDataType::type_id_t, "mesh"}; inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNodeDataType::type_id_t, "mesh"};
template <> template <>
inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t}; inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
struct TagA
{
};
struct TagC
{
};
struct Foo
{
static std::vector<int> m_ids;
KOKKOS_INLINE_FUNCTION
void
operator()(const TagA, const Kokkos::TeamPolicy<>::member_type& team) const
{
m_ids[team.league_rank()] = team.league_rank() * 3;
}
KOKKOS_INLINE_FUNCTION
void
operator()(const TagC, const Kokkos::TeamPolicy<>::member_type& team) const
{
m_ids[team.league_rank()] = team.league_rank() * 2;
}
Foo() {}
};
inline std::vector<int> Foo::m_ids;
MeshModule::MeshModule() MeshModule::MeshModule()
{ {
this->_addTypeDescriptor( this->_addTypeDescriptor(
...@@ -31,10 +63,9 @@ MeshModule::MeshModule() ...@@ -31,10 +63,9 @@ MeshModule::MeshModule()
)); ));
this this->_addBuiltinFunction("transform",
->_addBuiltinFunction("transform", std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>,
std::make_shared< FunctionSymbolId>>(
BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, FunctionSymbolId>>(
std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{ std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{
[](std::shared_ptr<IMesh> p_mesh, [](std::shared_ptr<IMesh> p_mesh,
...@@ -44,6 +75,32 @@ MeshModule::MeshModule() ...@@ -44,6 +75,32 @@ MeshModule::MeshModule()
*symbol_table.functionTable()[function_id.id()].definitionNode().children[1]; *symbol_table.functionTable()[function_id.id()].definitionNode().children[1];
auto& function_context = function_expression.m_symbol_table->context(); auto& function_context = function_expression.m_symbol_table->context();
// Foo foo;
// Foo::m_ids.resize(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
// Kokkos::parallel_for(Kokkos::TeamPolicy<
// TagA>(Kokkos::DefaultExecutionSpace::impl_thread_pool_size(),
// 1),
// foo);
// std::cout << "--------\n" << std::endl;
// for (int i = 0; i < Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); ++i) {
// std::cout << "m_ids[" << i << "] = " << Foo::m_ids[i] << " | " << i * 3 <<
// std::endl;
// }
// std::cout << "********\n" << std::endl;
// Kokkos::parallel_for(Kokkos::TeamPolicy<
// TagC>(Kokkos::DefaultExecutionSpace::impl_thread_pool_size(),
// 1),
// foo);
// std::cout << "--------\n" << std::endl;
// for (int i = 0; i < Kokkos::DefaultExecutionSpace::impl_thread_pool_size(); ++i) {
// std::cout << "m_ids[" << i << "] = " << Foo::m_ids[i] << " | " << i * 2 <<
// std::endl;
// }
switch (p_mesh->dimension()) { switch (p_mesh->dimension()) {
case 1: { case 1: {
throw NotImplementedError("not implemented in 1d"); throw NotImplementedError("not implemented in 1d");
...@@ -60,8 +117,6 @@ MeshModule::MeshModule() ...@@ -60,8 +117,6 @@ MeshModule::MeshModule()
NodeValue<TinyVector<3>> xr(given_mesh.connectivity()); NodeValue<TinyVector<3>> xr(given_mesh.connectivity());
parallel_for(given_mesh.numberOfNodes(), [=, &function_expression,
&function_context](NodeId r) {
ExecutionPolicy::Context context{function_context.id(), ExecutionPolicy::Context context{function_context.id(),
std::make_shared<ExecutionPolicy::Context::Values>( std::make_shared<ExecutionPolicy::Context::Values>(
function_context.size())}; function_context.size())};
...@@ -69,6 +124,8 @@ MeshModule::MeshModule() ...@@ -69,6 +124,8 @@ MeshModule::MeshModule()
ExecutionPolicy execution_policy; ExecutionPolicy execution_policy;
ExecutionPolicy context_execution_policy{execution_policy, context}; ExecutionPolicy context_execution_policy{execution_policy, context};
parallel_for(given_mesh.numberOfNodes(), [=, &function_expression,
&context_execution_policy](NodeId r) {
context_execution_policy.currentContext()[0] = given_xr[r]; context_execution_policy.currentContext()[0] = given_xr[r];
auto&& value = function_expression.execute(context_execution_policy); auto&& value = function_expression.execute(context_execution_policy);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment