From 6a15ee1c7f4e59e1c10e98ef56947fdc1cd07d04 Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Fri, 1 May 2020 17:14:22 +0200 Subject: [PATCH] Define a first mechanism to call user functions in C++ This is a concept proof. It requires many improvements: - code is crappy - only works for one kind of function (R^3->R^3 with a result defined as a list of 3 R) - no check is performed - execution context (necessary for multi-threading is not properly defined) Also, one should define an handier way to define such mechanism --- ...STNodeBuiltinFunctionExpressionBuilder.cpp | 5 +- src/language/BuiltinFunctionEmbedder.hpp | 8 ---- src/language/DataVariant.hpp | 2 + src/language/FunctionSymbolId.hpp | 48 +++++++++++++++++++ src/language/FunctionTable.hpp | 7 --- src/language/MeshModule.cpp | 45 ++++++++++------- .../FunctionArgumentConverter.hpp | 20 ++++++++ 7 files changed, 102 insertions(+), 33 deletions(-) create mode 100644 src/language/FunctionSymbolId.hpp diff --git a/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp b/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp index 17704a98d..332f957e4 100644 --- a/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp +++ b/src/language/ASTNodeBuiltinFunctionExpressionBuilder.cpp @@ -79,7 +79,10 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData auto get_function_argument_to_function_id_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { switch (argument_node_sub_data_type.m_data_type) { case ASTNodeDataType::function_t: { - return std::make_unique<FunctionArgumentConverter<FunctionId, FunctionId>>(argument_number); + const ASTNode& parent_node = argument_node_sub_data_type.m_parent_node; + auto symbol_table = parent_node.m_symbol_table; + + return std::make_unique<FunctionArgumentToFunctionSymbolIdConverter>(argument_number, symbol_table); } // LCOV_EXCL_START default: { diff --git a/src/language/BuiltinFunctionEmbedder.hpp b/src/language/BuiltinFunctionEmbedder.hpp index 31517c95b..79eb78e4c 100644 --- a/src/language/BuiltinFunctionEmbedder.hpp +++ b/src/language/BuiltinFunctionEmbedder.hpp @@ -76,14 +76,6 @@ class BuiltinFunctionEmbedder : public IBuiltinFunctionEmbedder } else { throw UnexpectedError("unexpected argument types while casting: expecting EmbeddedData"); } - } else if constexpr (std::is_same_v<Ti_Type, FunctionId>) { - if constexpr (std::is_same_v<Vi_Type, uint64_t>) { - std::get<I>(t) = FunctionId{v_i}; - throw NotImplementedError( - "Should get better descriptor, FunctionId should at least refer to the symbol table."); - } else { - throw UnexpectedError("unexpected argument types while casting: expecting uint64"); - } } else { throw UnexpectedError("Unexpected argument types while casting " + demangle<Vi_Type>() + " -> " + demangle<Ti_Type>()); diff --git a/src/language/DataVariant.hpp b/src/language/DataVariant.hpp index fd2e69b9b..04dac5293 100644 --- a/src/language/DataVariant.hpp +++ b/src/language/DataVariant.hpp @@ -3,6 +3,7 @@ #include <algebra/TinyVector.hpp> #include <language/EmbeddedData.hpp> +#include <language/FunctionSymbolId.hpp> #include <utils/PugsAssert.hpp> #include <tuple> @@ -19,6 +20,7 @@ using DataVariant = std::variant<std::monostate, std::string, EmbeddedData, AggregateDataVariant, + FunctionSymbolId, TinyVector<1>, TinyVector<2>, TinyVector<3>>; diff --git a/src/language/FunctionSymbolId.hpp b/src/language/FunctionSymbolId.hpp new file mode 100644 index 000000000..d6f0973f7 --- /dev/null +++ b/src/language/FunctionSymbolId.hpp @@ -0,0 +1,48 @@ +#ifndef FUNCTION_SYMBOL_ID_HPP +#define FUNCTION_SYMBOL_ID_HPP + +#include <utils/PugsAssert.hpp> +#include <utils/PugsMacros.hpp> + +#include <cstddef> +#include <iostream> +#include <memory> + +class SymbolTable; +class FunctionSymbolId +{ + private: + uint64_t m_function_id; + std::shared_ptr<SymbolTable> m_symbol_table = nullptr; + + public: + PUGS_INLINE uint64_t + id() const noexcept + { + return m_function_id; + } + + PUGS_INLINE + const SymbolTable& + symbolTable() const + { + Assert(m_symbol_table, "FunctionSymbolId is not initialized properly"); + return *m_symbol_table; + } + + friend std::ostream& + operator<<(std::ostream& os, const FunctionSymbolId& function_symbol_id) + { + os << function_symbol_id.m_function_id; + return os; + } + + FunctionSymbolId() = default; + FunctionSymbolId(uint64_t function_id, const std::shared_ptr<SymbolTable>& symbol_table) + : m_function_id(function_id), m_symbol_table(symbol_table) + {} + + ~FunctionSymbolId() = default; +}; + +#endif // FUNCTION_SYMBOL_ID_HPP diff --git a/src/language/FunctionTable.hpp b/src/language/FunctionTable.hpp index 72b19f437..a2b7b95af 100644 --- a/src/language/FunctionTable.hpp +++ b/src/language/FunctionTable.hpp @@ -8,13 +8,6 @@ #include <pegtl/position.hpp> -#include <iostream> - -struct FunctionId -{ - uint64_t m_function_id; -}; - class FunctionDescriptor { std::unique_ptr<ASTNode> m_domain_mapping_node; diff --git a/src/language/MeshModule.cpp b/src/language/MeshModule.cpp index 5c1721fa9..696799455 100644 --- a/src/language/MeshModule.cpp +++ b/src/language/MeshModule.cpp @@ -2,19 +2,19 @@ #include <language/BuiltinFunctionEmbedder.hpp> #include <language/FunctionTable.hpp> +#include <language/SymbolTable.hpp> #include <language/TypeDescriptor.hpp> +#include <language/node_processor/ExecutionPolicy.hpp> #include <mesh/Connectivity.hpp> #include <mesh/GmshReader.hpp> #include <mesh/Mesh.hpp> #include <utils/Exceptions.hpp> -#include <output/VTKWriter.hpp> - template <> inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNodeDataType::type_id_t, "mesh"}; template <> -inline ASTNodeDataType ast_node_data_type_from<FunctionId> = {ASTNodeDataType::function_t}; +inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t}; MeshModule::MeshModule() { @@ -34,10 +34,16 @@ MeshModule::MeshModule() this ->_addBuiltinFunction("transform", std::make_shared< - BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, FunctionId>>( - std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionId)>{ + BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, FunctionSymbolId>>( + std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{ + + [](std::shared_ptr<IMesh> p_mesh, + FunctionSymbolId function_id) -> std::shared_ptr<IMesh> { + auto& symbol_table = function_id.symbolTable(); + auto& function_expression = + *symbol_table.functionTable()[function_id.id()].definitionNode().children[1]; + auto& function_context = function_expression.m_symbol_table->context(); - [](std::shared_ptr<IMesh> p_mesh, FunctionId function_id) -> std::shared_ptr<IMesh> { switch (p_mesh->dimension()) { case 1: { throw NotImplementedError("not implemented in 1d"); @@ -48,23 +54,28 @@ MeshModule::MeshModule() break; } case 3: { - std::cout << "Using function " << function_id.m_function_id << '\n'; - using MeshType = Mesh<Connectivity3D>; const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); NodeValue<const TinyVector<3>> given_xr = given_mesh.xr(); NodeValue<TinyVector<3>> xr(given_mesh.connectivity()); - parallel_for( - given_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId r) { - const TinyVector<3> shift{0, 0.05, 0.05}; - const auto x = given_xr[r] - shift; - const double c = std::cos(0.01 * x[0]); - const double s = std::sin(0.01 * x[0]); - const TinyVector<3> transformed{x[0], x[1] * c + x[2] * s, -x[1] * s + x[2] * c}; - xr[r] = transformed + shift; - }); + parallel_for(given_mesh.numberOfNodes(), [=, &function_expression, + &function_context](NodeId r) { + ExecutionPolicy::Context context{function_context.id(), + std::make_shared<ExecutionPolicy::Context::Values>( + function_context.size())}; + + ExecutionPolicy execution_policy; + ExecutionPolicy context_execution_policy{execution_policy, context}; + + context_execution_policy.currentContext()[0] = given_xr[r]; + + auto&& value = function_expression.execute(context_execution_policy); + + AggregateDataVariant& v = std::get<AggregateDataVariant>(value); + xr[r] = {std::get<double>(v[0]), std::get<double>(v[1]), std::get<double>(v[2])}; + }); return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr); } diff --git a/src/language/node_processor/FunctionArgumentConverter.hpp b/src/language/node_processor/FunctionArgumentConverter.hpp index 990fcde28..467bc7b24 100644 --- a/src/language/node_processor/FunctionArgumentConverter.hpp +++ b/src/language/node_processor/FunctionArgumentConverter.hpp @@ -43,4 +43,24 @@ class FunctionArgumentConverter final : public IFunctionArgumentConverter FunctionArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} }; +class FunctionArgumentToFunctionSymbolIdConverter final : public IFunctionArgumentConverter +{ + private: + size_t m_argument_id; + std::shared_ptr<SymbolTable> m_symbol_table; + + public: + DataVariant + convert(ExecutionPolicy& exec_policy, DataVariant&& value) + { + exec_policy.currentContext()[m_argument_id] = FunctionSymbolId{std::get<uint64_t>(value), m_symbol_table}; + + return {}; + } + + FunctionArgumentToFunctionSymbolIdConverter(size_t argument_id, const std::shared_ptr<SymbolTable>& symbol_table) + : m_argument_id{argument_id}, m_symbol_table{symbol_table} + {} +}; + #endif // FUNCTION_ARGUMENT_CONVERTER_HPP -- GitLab