Skip to content
Snippets Groups Projects
Commit b6571c51 authored by Stéphane Del Pino's avatar Stéphane Del Pino
Browse files

Merge branch 'feature/discrete-function-algebra' into 'develop'

Feature/discrete function algebra

See merge request !81
parents b5d624f9 86b7fed3
No related branches found
No related tags found
1 merge request!81Feature/discrete function algebra
Showing
with 327 additions and 0 deletions
#include <language/ast/ASTModulesImporter.hpp>
#include <language/PEGGrammar.hpp>
#include <language/utils/OperatorRepository.hpp>
void
ASTModulesImporter::_importModule(ASTNode& import_node)
......@@ -20,6 +21,7 @@ ASTModulesImporter::_importModule(ASTNode& import_node)
std::cout << " * importing '" << rang::fgB::green << module_name << rang::style::reset << "' module\n";
m_module_repository.populateSymbolTable(module_name_node, m_symbol_table);
m_module_repository.registerOperators(module_name);
}
void
......@@ -37,6 +39,7 @@ ASTModulesImporter::_importAllModules(ASTNode& node)
ASTModulesImporter::ASTModulesImporter(ASTNode& root_node) : m_symbol_table{*root_node.m_symbol_table}
{
Assert(root_node.is_root());
OperatorRepository::instance().reset();
m_module_repository.populateMandatorySymbolTable(root_node, m_symbol_table);
this->_importAllModules(root_node);
......
#include <language/modules/CoreModule.hpp>
#include <language/utils/AffectationProcessorBuilder.hpp>
#include <language/utils/AffectationRegisterForB.hpp>
#include <language/utils/AffectationRegisterForN.hpp>
#include <language/utils/AffectationRegisterForR.hpp>
#include <language/utils/AffectationRegisterForRn.hpp>
#include <language/utils/AffectationRegisterForRnxn.hpp>
#include <language/utils/AffectationRegisterForString.hpp>
#include <language/utils/AffectationRegisterForZ.hpp>
#include <language/utils/BinaryOperatorRegisterForB.hpp>
#include <language/utils/BinaryOperatorRegisterForN.hpp>
#include <language/utils/BinaryOperatorRegisterForR.hpp>
#include <language/utils/BinaryOperatorRegisterForRn.hpp>
#include <language/utils/BinaryOperatorRegisterForRnxn.hpp>
#include <language/utils/BinaryOperatorRegisterForString.hpp>
#include <language/utils/BinaryOperatorRegisterForZ.hpp>
#include <language/utils/IncDecOperatorRegisterForN.hpp>
#include <language/utils/IncDecOperatorRegisterForR.hpp>
#include <language/utils/IncDecOperatorRegisterForZ.hpp>
#include <language/utils/UnaryOperatorRegisterForB.hpp>
#include <language/utils/UnaryOperatorRegisterForN.hpp>
#include <language/utils/UnaryOperatorRegisterForR.hpp>
#include <language/utils/UnaryOperatorRegisterForRn.hpp>
#include <language/utils/UnaryOperatorRegisterForRnxn.hpp>
#include <language/utils/UnaryOperatorRegisterForZ.hpp>
#include <language/modules/CoreModule.hpp>
#include <language/modules/ModuleRepository.hpp>
#include <language/utils/ASTExecutionInfo.hpp>
......@@ -42,3 +70,46 @@ CoreModule::CoreModule() : BuiltinModule(true)
));
}
void
CoreModule::registerOperators() const
{
AffectationRegisterForB{};
AffectationRegisterForN{};
AffectationRegisterForZ{};
AffectationRegisterForR{};
AffectationRegisterForRn<1>{};
AffectationRegisterForRn<2>{};
AffectationRegisterForRn<3>{};
AffectationRegisterForRnxn<1>{};
AffectationRegisterForRnxn<2>{};
AffectationRegisterForRnxn<3>{};
AffectationRegisterForString{};
BinaryOperatorRegisterForB{};
BinaryOperatorRegisterForN{};
BinaryOperatorRegisterForZ{};
BinaryOperatorRegisterForR{};
BinaryOperatorRegisterForRn<1>{};
BinaryOperatorRegisterForRn<2>{};
BinaryOperatorRegisterForRn<3>{};
BinaryOperatorRegisterForRnxn<1>{};
BinaryOperatorRegisterForRnxn<2>{};
BinaryOperatorRegisterForRnxn<3>{};
BinaryOperatorRegisterForString{};
IncDecOperatorRegisterForN{};
IncDecOperatorRegisterForR{};
IncDecOperatorRegisterForZ{};
UnaryOperatorRegisterForB{};
UnaryOperatorRegisterForN{};
UnaryOperatorRegisterForZ{};
UnaryOperatorRegisterForR{};
UnaryOperatorRegisterForRn<1>{};
UnaryOperatorRegisterForRn<2>{};
UnaryOperatorRegisterForRn<3>{};
UnaryOperatorRegisterForRnxn<1>{};
UnaryOperatorRegisterForRnxn<2>{};
UnaryOperatorRegisterForRnxn<3>{};
}
......@@ -12,6 +12,8 @@ class CoreModule : public BuiltinModule
return "core";
}
void registerOperators() const final;
CoreModule();
~CoreModule() = default;
};
......
......@@ -25,6 +25,8 @@ class IModule
virtual const NameTypeMap& getNameTypeMap() const = 0;
virtual void registerOperators() const = 0;
virtual std::string_view name() const = 0;
virtual ~IModule() = default;
......
......@@ -90,3 +90,7 @@ LinearSolverModule::LinearSolverModule()
));
}
void
LinearSolverModule::registerOperators() const
{}
......@@ -12,6 +12,8 @@ class LinearSolverModule : public BuiltinModule
return "linear_solver";
}
void registerOperators() const final;
LinearSolverModule();
~LinearSolverModule() = default;
};
......
......@@ -70,3 +70,7 @@ MathModule::MathModule()
this->_addBuiltinFunction("round", std::make_shared<BuiltinFunctionEmbedder<int64_t(double)>>(
[](double x) -> int64_t { return std::lround(x); }));
}
void
MathModule::registerOperators() const
{}
......@@ -12,6 +12,8 @@ class MathModule : public BuiltinModule
return "math";
}
void registerOperators() const final;
MathModule();
~MathModule() = default;
......
......@@ -212,3 +212,7 @@ MeshModule::MeshModule()
));
}
void
MeshModule::registerOperators() const
{}
......@@ -20,6 +20,8 @@ class MeshModule : public BuiltinModule
return "mesh";
}
void registerOperators() const final;
MeshModule();
~MeshModule() = default;
......
......@@ -71,6 +71,14 @@ ModuleRepository::populateSymbolTable(const ASTNode& module_name_node, SymbolTab
if (i_module != m_module_set.end()) {
const IModule& populating_module = *i_module->second;
if (populating_module.isMandatory()) {
std::ostringstream error_message;
error_message << "module '" << rang::fgB::blue << module_name << rang::style::reset << rang::style::bold
<< "' is an autoload " << rang::fgB::yellow << "mandatory" << rang::style::reset
<< rang::style::bold << " module. It cannot be imported explicitly!";
throw ParseError(error_message.str(), module_name_node.begin());
}
this->_populateEmbedderTableT(module_name_node, module_name, populating_module.getNameBuiltinFunctionMap(),
ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>(), symbol_table,
symbol_table.builtinFunctionEmbedderTable());
......@@ -100,6 +108,8 @@ ModuleRepository::populateMandatorySymbolTable(const ASTNode& root_node, SymbolT
this->_populateEmbedderTableT(root_node, module_name, i_module->getNameTypeMap(),
ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>(), symbol_table,
symbol_table.typeEmbedderTable());
i_module->registerOperators();
}
}
}
......@@ -123,6 +133,17 @@ ModuleRepository::getAvailableModules() const
return os.str();
}
void
ModuleRepository::registerOperators(const std::string& module_name)
{
auto i_module = m_module_set.find(module_name);
if (i_module != m_module_set.end()) {
i_module->second->registerOperators();
} else {
throw NormalError(std::string{"could not find module "} + module_name);
}
}
std::string
ModuleRepository::getModuleInfo(const std::string& module_name) const
{
......
......@@ -29,6 +29,7 @@ class ModuleRepository
public:
void populateSymbolTable(const ASTNode& module_name_node, SymbolTable& symbol_table);
void populateMandatorySymbolTable(const ASTNode& root_node, SymbolTable& symbol_table);
void registerOperators(const std::string& module_name);
std::string getAvailableModules() const;
std::string getModuleInfo(const std::string& module_name) const;
......
#include <language/modules/SchemeModule.hpp>
#include <language/utils/BinaryOperatorProcessorBuilder.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp>
#include <language/utils/OperatorRepository.hpp>
#include <language/utils/TypeDescriptor.hpp>
#include <mesh/Mesh.hpp>
#include <scheme/AcousticSolver.hpp>
......@@ -280,3 +283,81 @@ SchemeModule::SchemeModule()
));
}
void
SchemeModule::registerOperators() const
{
OperatorRepository& repository = OperatorRepository::instance();
repository.addBinaryOperator<language::plus_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::plus_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::minus_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::divide_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::divide_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
bool, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
int64_t, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
uint64_t, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
double, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
TinyMatrix<1>, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
TinyMatrix<2>, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
TinyMatrix<3>, std::shared_ptr<const IDiscreteFunction>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyVector<1>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyVector<2>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyVector<3>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyMatrix<1>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyMatrix<2>>>());
repository.addBinaryOperator<language::multiply_op>(
std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>,
std::shared_ptr<const IDiscreteFunction>, TinyMatrix<3>>>());
}
......@@ -34,6 +34,8 @@ class SchemeModule : public BuiltinModule
return "scheme";
}
void registerOperators() const final;
SchemeModule();
~SchemeModule() = default;
......
......@@ -65,3 +65,7 @@ UtilsModule::UtilsModule()
));
}
void
UtilsModule::registerOperators() const
{}
......@@ -12,6 +12,8 @@ class UtilsModule : public BuiltinModule
return "utils";
}
void registerOperators() const final;
UtilsModule();
~UtilsModule() = default;
};
......
......@@ -100,3 +100,7 @@ WriterModule::WriterModule()
));
}
void
WriterModule::registerOperators() const
{}
......@@ -29,6 +29,8 @@ class WriterModule : public BuiltinModule
return "writer";
}
void registerOperators() const final;
WriterModule();
~WriterModule() = default;
......
......@@ -6,9 +6,13 @@
#include <language/node_processor/BinaryExpressionProcessor.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/IBinaryOperatorProcessorBuilder.hpp>
#include <language/utils/ParseError.hpp>
#include <type_traits>
template <typename DataType>
class DataHandler;
template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT>
class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder
{
......@@ -40,4 +44,113 @@ class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuil
}
};
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, std::shared_ptr<B_DataT>>
final : public INodeProcessor
{
private:
ASTNode& m_node;
PUGS_INLINE DataVariant
_eval(const DataVariant& a, const DataVariant& b)
{
const auto& embedded_a = std::get<EmbeddedData>(a);
const auto& embedded_b = std::get<EmbeddedData>(b);
std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr();
std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr();
return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_ptr)));
}
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
try {
return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy));
}
catch (const NormalError& error) {
throw ParseError(error.what(), m_node.begin());
}
}
BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, A_DataT, std::shared_ptr<B_DataT>> final
: public INodeProcessor
{
private:
ASTNode& m_node;
PUGS_INLINE DataVariant
_eval(const DataVariant& a, const DataVariant& b)
{
if constexpr ((std::is_arithmetic_v<A_DataT>) or (is_tiny_vector_v<A_DataT>) or (is_tiny_matrix_v<A_DataT>)) {
const auto& a_value = std::get<A_DataT>(a);
const auto& embedded_b = std::get<EmbeddedData>(b);
std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr();
return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_value, b_ptr)));
} else {
static_assert(std::is_arithmetic_v<A_DataT>, "invalid left hand side type");
}
}
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
try {
return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy));
}
catch (const NormalError& error) {
throw ParseError(error.what(), m_node.begin());
}
}
BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, B_DataT> final
: public INodeProcessor
{
private:
ASTNode& m_node;
PUGS_INLINE DataVariant
_eval(const DataVariant& a, const DataVariant& b)
{
if constexpr ((std::is_arithmetic_v<B_DataT>) or (is_tiny_matrix_v<B_DataT>) or (is_tiny_vector_v<B_DataT>)) {
const auto& embedded_a = std::get<EmbeddedData>(a);
const auto& b_value = std::get<B_DataT>(b);
std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr();
return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_value)));
} else {
static_assert(std::is_arithmetic_v<B_DataT>, "invalid right hand side type");
}
}
public:
DataVariant
execute(ExecutionPolicy& exec_policy)
{
try {
return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy));
}
catch (const NormalError& error) {
throw ParseError(error.what(), m_node.begin());
}
}
BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};
#endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP
......@@ -22,6 +22,7 @@ add_library(PugsLanguageUtils
BinaryOperatorRegisterForZ.cpp
DataVariant.cpp
EmbeddedData.cpp
EmbeddedIDiscreteFunctionOperators.cpp
FunctionSymbolId.cpp
IncDecOperatorRegisterForN.cpp
IncDecOperatorRegisterForR.cpp
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment