Skip to content
Snippets Groups Projects
Commit caab4b21 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Improve the code (design, genericity and readability)

- Rename FunctionAdapter -> PugsFunctionAdapter to improve readability
- PugsFunctionAdapter is now defined in its own file
parent 8ed59aba
No related branches found
No related tags found
1 merge request!37Feature/language
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <language/BuiltinFunctionEmbedder.hpp> #include <language/BuiltinFunctionEmbedder.hpp>
#include <language/FunctionTable.hpp> #include <language/FunctionTable.hpp>
#include <language/PugsFunctionAdapter.hpp>
#include <language/SymbolTable.hpp> #include <language/SymbolTable.hpp>
#include <language/TypeDescriptor.hpp> #include <language/TypeDescriptor.hpp>
#include <language/node_processor/ExecutionPolicy.hpp> #include <language/node_processor/ExecutionPolicy.hpp>
...@@ -19,97 +20,39 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod ...@@ -19,97 +20,39 @@ inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNod
template <> template <>
inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t}; inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
template <typename T>
class MeshTransformation;
template <typename OutputType, typename... InputType> template <typename OutputType, typename... InputType>
class FunctionAdapter class MeshTransformation<OutputType(InputType...)> : public PugsFunctionAdapter<OutputType(InputType...)>
{ {
static constexpr size_t OutputDimension = OutputType::Dimension; static constexpr size_t Dimension = OutputType::Dimension;
using Adapter = PugsFunctionAdapter<OutputType(InputType...)>;
private:
template <typename T, typename... Args>
static void
_convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context)
{
context[sizeof...(args)] = t;
if constexpr (sizeof...(args) > 0) {
_convertArgs(std::forward<Args>(args)..., context);
}
}
template <typename... Args>
static void
convertArgs(ExecutionPolicy::Context& context, const Args&... args)
{
static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type");
_convertArgs(args..., context);
}
static std::function<OutputType(DataVariant&& result)>
_get_result_converter(ASTNodeDataType data_type)
{
switch (data_type) {
case ASTNodeDataType::list_t: {
return [](DataVariant&& result) -> OutputType {
AggregateDataVariant& v = std::get<AggregateDataVariant>(result);
OutputType x;
for (size_t i = 0; i < x.dimension(); ++i) {
x[i] = std::get<double>(v[i]);
}
return x;
};
}
case ASTNodeDataType::vector_t: {
return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
}
case ASTNodeDataType::double_t: {
if constexpr (OutputDimension == 1) {
return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
} else {
throw UnexpectedError("unexpected data_type");
}
}
default: {
throw UnexpectedError("unexpected data_type");
}
}
}
public: public:
template <size_t Dimension> static inline std::shared_ptr<Mesh<Connectivity<Dimension>>>
static inline std::shared_ptr<Mesh<Connectivity<OutputDimension>>>
transform(FunctionSymbolId function_symbol_id, std::shared_ptr<const IMesh> p_mesh) transform(FunctionSymbolId function_symbol_id, std::shared_ptr<const IMesh> p_mesh)
{ {
auto& symbol_table = function_symbol_id.symbolTable();
auto& function_expression = *symbol_table.functionTable()[function_symbol_id.id()].definitionNode().children[1];
auto& function_context = function_expression.m_symbol_table->context();
ASTNodeDataType t = function_expression.m_data_type;
auto convert_result = _get_result_converter(t);
const auto number_of_threads = Kokkos::DefaultExecutionSpace::impl_thread_pool_size();
Array<ExecutionPolicy> context_list(number_of_threads);
for (size_t i = 0; i < context_list.size(); ++i) {
context_list[i] = ExecutionPolicy(ExecutionPolicy{},
{function_context.id(),
std::make_shared<ExecutionPolicy::Context::Values>(function_context.size())});
}
using MeshType = Mesh<Connectivity<Dimension>>; using MeshType = Mesh<Connectivity<Dimension>>;
const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh);
NodeValue<const TinyVector<Dimension>> given_xr = given_mesh.xr();
NodeValue<TinyVector<Dimension>> xr(given_mesh.connectivity()); auto& expression = Adapter::getFunctionExpression(function_symbol_id);
auto convert_result = Adapter::getResultConverter(expression.m_data_type);
Array<ExecutionPolicy> context_list = Adapter::getContextList(expression);
NodeValue<const OutputType> given_xr = given_mesh.xr();
NodeValue<OutputType> xr(given_mesh.connectivity());
using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space; using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens; Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens;
parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, &tokens](NodeId r) { parallel_for(given_mesh.numberOfNodes(), [=, &expression, &tokens](NodeId r) {
const int32_t t = tokens.acquire(); const int32_t t = tokens.acquire();
auto& execution_policy = context_list[t]; auto& execution_policy = context_list[t];
convertArgs(execution_policy.currentContext(), given_xr[r]); Adapter::convertArgs(execution_policy.currentContext(), given_xr[r]);
xr[r] = convert_result(expression.execute(execution_policy));
xr[r] = convert_result(function_expression.execute(execution_policy));
tokens.release(t); tokens.release(t);
}); });
...@@ -140,19 +83,19 @@ MeshModule::MeshModule() ...@@ -140,19 +83,19 @@ MeshModule::MeshModule()
FunctionSymbolId function_id) -> std::shared_ptr<IMesh> { FunctionSymbolId function_id) -> std::shared_ptr<IMesh> {
switch (p_mesh->dimension()) { switch (p_mesh->dimension()) {
case 1: { case 1: {
return FunctionAdapter<TinyVector<1>, TinyVector<1>>::transform<1>(function_id, using TransformT = TinyVector<1>(TinyVector<1>);
p_mesh); return MeshTransformation<TransformT>::transform(function_id, p_mesh);
} }
case 2: { case 2: {
return FunctionAdapter<TinyVector<2>, TinyVector<2>>::transform<2>(function_id, using TransformT = TinyVector<2>(TinyVector<2>);
p_mesh); return MeshTransformation<TransformT>::transform(function_id, p_mesh);
} }
case 3: { case 3: {
return FunctionAdapter<TinyVector<3>, TinyVector<3>>::transform<3>(function_id, using TransformT = TinyVector<3>(TinyVector<3>);
p_mesh); return MeshTransformation<TransformT>::transform(function_id, p_mesh);
} }
default: { default: {
throw NormalError("invalid dimension"); throw NormalError("invalid mesh dimension");
} }
} }
}} }}
......
#ifndef PUGS_FUNCTION_ADAPTER_HPP
#define PUGS_FUNCTION_ADAPTER_HPP
#include <language/ASTNode.hpp>
#include <language/SymbolTable.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <utils/Array.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsMacros.hpp>
#include <Kokkos_Core.hpp>
template <typename T>
class PugsFunctionAdapter;
template <typename OutputType, typename... InputType>
class PugsFunctionAdapter<OutputType(InputType...)>
{
private:
template <typename T, typename... Args>
PUGS_INLINE static void
_convertArgs(const Args&&... args, const T& t, ExecutionPolicy::Context& context)
{
context[sizeof...(args)] = t;
if constexpr (sizeof...(args) > 0) {
_convertArgs(std::forward<Args>(args)..., context);
}
}
protected:
PUGS_INLINE static auto&
getFunctionExpression(FunctionSymbolId function_symbol_id)
{
return *function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()].definitionNode().children[1];
}
PUGS_INLINE static auto
getContextList(const ASTNode& expression)
{
Array<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
auto& context = expression.m_symbol_table->context();
for (size_t i = 0; i < context_list.size(); ++i) {
context_list[i] =
ExecutionPolicy(ExecutionPolicy{},
{context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())});
}
return context_list;
}
template <typename... Args>
PUGS_INLINE static void
convertArgs(ExecutionPolicy::Context& context, const Args&... args)
{
static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type");
_convertArgs(args..., context);
}
PUGS_INLINE static std::function<OutputType(DataVariant&& result)>
getResultConverter(ASTNodeDataType data_type)
{
switch (data_type) {
case ASTNodeDataType::list_t: {
return [](DataVariant&& result) -> OutputType {
AggregateDataVariant& v = std::get<AggregateDataVariant>(result);
OutputType x;
for (size_t i = 0; i < x.dimension(); ++i) {
x[i] = std::get<double>(v[i]);
}
return x;
};
}
case ASTNodeDataType::vector_t: {
return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
}
case ASTNodeDataType::double_t: {
if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
} else {
throw UnexpectedError("unexpected data_type");
}
}
default: {
throw UnexpectedError("unexpected data_type");
}
}
}
};
#endif // PUGS_FUNCTION_ADAPTER_HPP
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment