Skip to content
Snippets Groups Projects
Commit 6a15ee1c authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Define a first mechanism to call user functions in C++

This is a concept proof. It requires many improvements:
- code is crappy
- only works for one kind of function (R^3->R^3 with a result defined
as a list of 3 R)
- no check is performed
- execution context (necessary for multi-threading  is not properly
defined)

Also, one should define an handier way to define such mechanism
parent 6bda3939
No related branches found
No related tags found
1 merge request!37Feature/language
...@@ -79,7 +79,10 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData ...@@ -79,7 +79,10 @@ ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeData
auto get_function_argument_to_function_id_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> { auto get_function_argument_to_function_id_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> {
switch (argument_node_sub_data_type.m_data_type) { switch (argument_node_sub_data_type.m_data_type) {
case ASTNodeDataType::function_t: { case ASTNodeDataType::function_t: {
return std::make_unique<FunctionArgumentConverter<FunctionId, FunctionId>>(argument_number); const ASTNode& parent_node = argument_node_sub_data_type.m_parent_node;
auto symbol_table = parent_node.m_symbol_table;
return std::make_unique<FunctionArgumentToFunctionSymbolIdConverter>(argument_number, symbol_table);
} }
// LCOV_EXCL_START // LCOV_EXCL_START
default: { default: {
......
...@@ -76,14 +76,6 @@ class BuiltinFunctionEmbedder : public IBuiltinFunctionEmbedder ...@@ -76,14 +76,6 @@ class BuiltinFunctionEmbedder : public IBuiltinFunctionEmbedder
} else { } else {
throw UnexpectedError("unexpected argument types while casting: expecting EmbeddedData"); throw UnexpectedError("unexpected argument types while casting: expecting EmbeddedData");
} }
} else if constexpr (std::is_same_v<Ti_Type, FunctionId>) {
if constexpr (std::is_same_v<Vi_Type, uint64_t>) {
std::get<I>(t) = FunctionId{v_i};
throw NotImplementedError(
"Should get better descriptor, FunctionId should at least refer to the symbol table.");
} else {
throw UnexpectedError("unexpected argument types while casting: expecting uint64");
}
} else { } else {
throw UnexpectedError("Unexpected argument types while casting " + demangle<Vi_Type>() + " -> " + throw UnexpectedError("Unexpected argument types while casting " + demangle<Vi_Type>() + " -> " +
demangle<Ti_Type>()); demangle<Ti_Type>());
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <algebra/TinyVector.hpp> #include <algebra/TinyVector.hpp>
#include <language/EmbeddedData.hpp> #include <language/EmbeddedData.hpp>
#include <language/FunctionSymbolId.hpp>
#include <utils/PugsAssert.hpp> #include <utils/PugsAssert.hpp>
#include <tuple> #include <tuple>
...@@ -19,6 +20,7 @@ using DataVariant = std::variant<std::monostate, ...@@ -19,6 +20,7 @@ using DataVariant = std::variant<std::monostate,
std::string, std::string,
EmbeddedData, EmbeddedData,
AggregateDataVariant, AggregateDataVariant,
FunctionSymbolId,
TinyVector<1>, TinyVector<1>,
TinyVector<2>, TinyVector<2>,
TinyVector<3>>; TinyVector<3>>;
......
#ifndef FUNCTION_SYMBOL_ID_HPP
#define FUNCTION_SYMBOL_ID_HPP
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>
#include <cstddef>
#include <iostream>
#include <memory>
class SymbolTable;
class FunctionSymbolId
{
private:
uint64_t m_function_id;
std::shared_ptr<SymbolTable> m_symbol_table = nullptr;
public:
PUGS_INLINE uint64_t
id() const noexcept
{
return m_function_id;
}
PUGS_INLINE
const SymbolTable&
symbolTable() const
{
Assert(m_symbol_table, "FunctionSymbolId is not initialized properly");
return *m_symbol_table;
}
friend std::ostream&
operator<<(std::ostream& os, const FunctionSymbolId& function_symbol_id)
{
os << function_symbol_id.m_function_id;
return os;
}
FunctionSymbolId() = default;
FunctionSymbolId(uint64_t function_id, const std::shared_ptr<SymbolTable>& symbol_table)
: m_function_id(function_id), m_symbol_table(symbol_table)
{}
~FunctionSymbolId() = default;
};
#endif // FUNCTION_SYMBOL_ID_HPP
...@@ -8,13 +8,6 @@ ...@@ -8,13 +8,6 @@
#include <pegtl/position.hpp> #include <pegtl/position.hpp>
#include <iostream>
struct FunctionId
{
uint64_t m_function_id;
};
class FunctionDescriptor class FunctionDescriptor
{ {
std::unique_ptr<ASTNode> m_domain_mapping_node; std::unique_ptr<ASTNode> m_domain_mapping_node;
......
...@@ -2,19 +2,19 @@ ...@@ -2,19 +2,19 @@
#include <language/BuiltinFunctionEmbedder.hpp> #include <language/BuiltinFunctionEmbedder.hpp>
#include <language/FunctionTable.hpp> #include <language/FunctionTable.hpp>
#include <language/SymbolTable.hpp>
#include <language/TypeDescriptor.hpp> #include <language/TypeDescriptor.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <mesh/Connectivity.hpp> #include <mesh/Connectivity.hpp>
#include <mesh/GmshReader.hpp> #include <mesh/GmshReader.hpp>
#include <mesh/Mesh.hpp> #include <mesh/Mesh.hpp>
#include <utils/Exceptions.hpp> #include <utils/Exceptions.hpp>
#include <output/VTKWriter.hpp>
template <> template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNodeDataType::type_id_t, "mesh"}; inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNodeDataType::type_id_t, "mesh"};
template <> template <>
inline ASTNodeDataType ast_node_data_type_from<FunctionId> = {ASTNodeDataType::function_t}; inline ASTNodeDataType ast_node_data_type_from<FunctionSymbolId> = {ASTNodeDataType::function_t};
MeshModule::MeshModule() MeshModule::MeshModule()
{ {
...@@ -34,10 +34,16 @@ MeshModule::MeshModule() ...@@ -34,10 +34,16 @@ MeshModule::MeshModule()
this this
->_addBuiltinFunction("transform", ->_addBuiltinFunction("transform",
std::make_shared< std::make_shared<
BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, FunctionId>>( BuiltinFunctionEmbedder<std::shared_ptr<IMesh>, std::shared_ptr<IMesh>, FunctionSymbolId>>(
std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionId)>{ std::function<std::shared_ptr<IMesh>(std::shared_ptr<IMesh>, FunctionSymbolId)>{
[](std::shared_ptr<IMesh> p_mesh,
FunctionSymbolId function_id) -> std::shared_ptr<IMesh> {
auto& symbol_table = function_id.symbolTable();
auto& function_expression =
*symbol_table.functionTable()[function_id.id()].definitionNode().children[1];
auto& function_context = function_expression.m_symbol_table->context();
[](std::shared_ptr<IMesh> p_mesh, FunctionId function_id) -> std::shared_ptr<IMesh> {
switch (p_mesh->dimension()) { switch (p_mesh->dimension()) {
case 1: { case 1: {
throw NotImplementedError("not implemented in 1d"); throw NotImplementedError("not implemented in 1d");
...@@ -48,22 +54,27 @@ MeshModule::MeshModule() ...@@ -48,22 +54,27 @@ MeshModule::MeshModule()
break; break;
} }
case 3: { case 3: {
std::cout << "Using function " << function_id.m_function_id << '\n';
using MeshType = Mesh<Connectivity3D>; using MeshType = Mesh<Connectivity3D>;
const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh); const MeshType& given_mesh = dynamic_cast<const MeshType&>(*p_mesh);
NodeValue<const TinyVector<3>> given_xr = given_mesh.xr(); NodeValue<const TinyVector<3>> given_xr = given_mesh.xr();
NodeValue<TinyVector<3>> xr(given_mesh.connectivity()); NodeValue<TinyVector<3>> xr(given_mesh.connectivity());
parallel_for( parallel_for(given_mesh.numberOfNodes(), [=, &function_expression,
given_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId r) { &function_context](NodeId r) {
const TinyVector<3> shift{0, 0.05, 0.05}; ExecutionPolicy::Context context{function_context.id(),
const auto x = given_xr[r] - shift; std::make_shared<ExecutionPolicy::Context::Values>(
const double c = std::cos(0.01 * x[0]); function_context.size())};
const double s = std::sin(0.01 * x[0]);
const TinyVector<3> transformed{x[0], x[1] * c + x[2] * s, -x[1] * s + x[2] * c}; ExecutionPolicy execution_policy;
xr[r] = transformed + shift; ExecutionPolicy context_execution_policy{execution_policy, context};
context_execution_policy.currentContext()[0] = given_xr[r];
auto&& value = function_expression.execute(context_execution_policy);
AggregateDataVariant& v = std::get<AggregateDataVariant>(value);
xr[r] = {std::get<double>(v[0]), std::get<double>(v[1]), std::get<double>(v[2])};
}); });
return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr); return std::make_shared<MeshType>(given_mesh.shared_connectivity(), xr);
......
...@@ -43,4 +43,24 @@ class FunctionArgumentConverter final : public IFunctionArgumentConverter ...@@ -43,4 +43,24 @@ class FunctionArgumentConverter final : public IFunctionArgumentConverter
FunctionArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {} FunctionArgumentConverter(size_t argument_id) : m_argument_id{argument_id} {}
}; };
class FunctionArgumentToFunctionSymbolIdConverter final : public IFunctionArgumentConverter
{
private:
size_t m_argument_id;
std::shared_ptr<SymbolTable> m_symbol_table;
public:
DataVariant
convert(ExecutionPolicy& exec_policy, DataVariant&& value)
{
exec_policy.currentContext()[m_argument_id] = FunctionSymbolId{std::get<uint64_t>(value), m_symbol_table};
return {};
}
FunctionArgumentToFunctionSymbolIdConverter(size_t argument_id, const std::shared_ptr<SymbolTable>& symbol_table)
: m_argument_id{argument_id}, m_symbol_table{symbol_table}
{}
};
#endif // FUNCTION_ARGUMENT_CONVERTER_HPP #endif // FUNCTION_ARGUMENT_CONVERTER_HPP
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment