diff --git a/src/language/ast/ASTModulesImporter.cpp b/src/language/ast/ASTModulesImporter.cpp index d52b964a8a670ca3af3d0bcd0f3223798e2502c3..fa3f479e307fd2b0c78d265879c22157dfec85b8 100644 --- a/src/language/ast/ASTModulesImporter.cpp +++ b/src/language/ast/ASTModulesImporter.cpp @@ -1,6 +1,7 @@ #include <language/ast/ASTModulesImporter.hpp> #include <language/PEGGrammar.hpp> +#include <language/utils/OperatorRepository.hpp> void ASTModulesImporter::_importModule(ASTNode& import_node) @@ -20,6 +21,7 @@ ASTModulesImporter::_importModule(ASTNode& import_node) std::cout << " * importing '" << rang::fgB::green << module_name << rang::style::reset << "' module\n"; m_module_repository.populateSymbolTable(module_name_node, m_symbol_table); + m_module_repository.registerOperators(module_name); } void @@ -37,6 +39,7 @@ ASTModulesImporter::_importAllModules(ASTNode& node) ASTModulesImporter::ASTModulesImporter(ASTNode& root_node) : m_symbol_table{*root_node.m_symbol_table} { Assert(root_node.is_root()); + OperatorRepository::instance().reset(); m_module_repository.populateMandatorySymbolTable(root_node, m_symbol_table); this->_importAllModules(root_node); diff --git a/src/language/modules/CoreModule.cpp b/src/language/modules/CoreModule.cpp index 14a9548238c91682aac53850a6cdf6f36b4d9e93..0c4f2b354adb6a059fb92cff28e196178a08effb 100644 --- a/src/language/modules/CoreModule.cpp +++ b/src/language/modules/CoreModule.cpp @@ -1,5 +1,33 @@ #include <language/modules/CoreModule.hpp> +#include <language/utils/AffectationProcessorBuilder.hpp> +#include <language/utils/AffectationRegisterForB.hpp> +#include <language/utils/AffectationRegisterForN.hpp> +#include <language/utils/AffectationRegisterForR.hpp> +#include <language/utils/AffectationRegisterForRn.hpp> +#include <language/utils/AffectationRegisterForRnxn.hpp> +#include <language/utils/AffectationRegisterForString.hpp> +#include <language/utils/AffectationRegisterForZ.hpp> + +#include <language/utils/BinaryOperatorRegisterForB.hpp> +#include <language/utils/BinaryOperatorRegisterForN.hpp> +#include <language/utils/BinaryOperatorRegisterForR.hpp> +#include <language/utils/BinaryOperatorRegisterForRn.hpp> +#include <language/utils/BinaryOperatorRegisterForRnxn.hpp> +#include <language/utils/BinaryOperatorRegisterForString.hpp> +#include <language/utils/BinaryOperatorRegisterForZ.hpp> + +#include <language/utils/IncDecOperatorRegisterForN.hpp> +#include <language/utils/IncDecOperatorRegisterForR.hpp> +#include <language/utils/IncDecOperatorRegisterForZ.hpp> + +#include <language/utils/UnaryOperatorRegisterForB.hpp> +#include <language/utils/UnaryOperatorRegisterForN.hpp> +#include <language/utils/UnaryOperatorRegisterForR.hpp> +#include <language/utils/UnaryOperatorRegisterForRn.hpp> +#include <language/utils/UnaryOperatorRegisterForRnxn.hpp> +#include <language/utils/UnaryOperatorRegisterForZ.hpp> + #include <language/modules/CoreModule.hpp> #include <language/modules/ModuleRepository.hpp> #include <language/utils/ASTExecutionInfo.hpp> @@ -42,3 +70,46 @@ CoreModule::CoreModule() : BuiltinModule(true) )); } + +void +CoreModule::registerOperators() const +{ + AffectationRegisterForB{}; + AffectationRegisterForN{}; + AffectationRegisterForZ{}; + AffectationRegisterForR{}; + AffectationRegisterForRn<1>{}; + AffectationRegisterForRn<2>{}; + AffectationRegisterForRn<3>{}; + AffectationRegisterForRnxn<1>{}; + AffectationRegisterForRnxn<2>{}; + AffectationRegisterForRnxn<3>{}; + AffectationRegisterForString{}; + + BinaryOperatorRegisterForB{}; + BinaryOperatorRegisterForN{}; + BinaryOperatorRegisterForZ{}; + BinaryOperatorRegisterForR{}; + BinaryOperatorRegisterForRn<1>{}; + BinaryOperatorRegisterForRn<2>{}; + BinaryOperatorRegisterForRn<3>{}; + BinaryOperatorRegisterForRnxn<1>{}; + BinaryOperatorRegisterForRnxn<2>{}; + BinaryOperatorRegisterForRnxn<3>{}; + BinaryOperatorRegisterForString{}; + + IncDecOperatorRegisterForN{}; + IncDecOperatorRegisterForR{}; + IncDecOperatorRegisterForZ{}; + + UnaryOperatorRegisterForB{}; + UnaryOperatorRegisterForN{}; + UnaryOperatorRegisterForZ{}; + UnaryOperatorRegisterForR{}; + UnaryOperatorRegisterForRn<1>{}; + UnaryOperatorRegisterForRn<2>{}; + UnaryOperatorRegisterForRn<3>{}; + UnaryOperatorRegisterForRnxn<1>{}; + UnaryOperatorRegisterForRnxn<2>{}; + UnaryOperatorRegisterForRnxn<3>{}; +} diff --git a/src/language/modules/CoreModule.hpp b/src/language/modules/CoreModule.hpp index 963719be2bce1b841fe9cee05db5b957ad53ba2b..88c673d65a7d9635aba8698c678dda7bd2081fb1 100644 --- a/src/language/modules/CoreModule.hpp +++ b/src/language/modules/CoreModule.hpp @@ -12,6 +12,8 @@ class CoreModule : public BuiltinModule return "core"; } + void registerOperators() const final; + CoreModule(); ~CoreModule() = default; }; diff --git a/src/language/modules/IModule.hpp b/src/language/modules/IModule.hpp index 48e4b15360de922a83025921d1058095cc205d89..ceb3fd3dffc7d0af7f6423e3a21e7e8aa9686ffc 100644 --- a/src/language/modules/IModule.hpp +++ b/src/language/modules/IModule.hpp @@ -25,6 +25,8 @@ class IModule virtual const NameTypeMap& getNameTypeMap() const = 0; + virtual void registerOperators() const = 0; + virtual std::string_view name() const = 0; virtual ~IModule() = default; diff --git a/src/language/modules/LinearSolverModule.cpp b/src/language/modules/LinearSolverModule.cpp index a5bb2cf90058b9232bd84e2f2bbdd71eb02ef3c7..5dafbadcfedec464b83d7d20ce798e81518e073d 100644 --- a/src/language/modules/LinearSolverModule.cpp +++ b/src/language/modules/LinearSolverModule.cpp @@ -90,3 +90,7 @@ LinearSolverModule::LinearSolverModule() )); } + +void +LinearSolverModule::registerOperators() const +{} diff --git a/src/language/modules/LinearSolverModule.hpp b/src/language/modules/LinearSolverModule.hpp index 5dc7ae64efd3ed43f554e821693f761b0a66d412..7e30c6b2f7c12cf4ede853578e6c064c7dce6292 100644 --- a/src/language/modules/LinearSolverModule.hpp +++ b/src/language/modules/LinearSolverModule.hpp @@ -12,6 +12,8 @@ class LinearSolverModule : public BuiltinModule return "linear_solver"; } + void registerOperators() const final; + LinearSolverModule(); ~LinearSolverModule() = default; }; diff --git a/src/language/modules/MathModule.cpp b/src/language/modules/MathModule.cpp index e774284b385e7e74b8cb715f73b5ebec6f81b02e..aa961e4a0ae250581fad42b644f7d670ef4fdb32 100644 --- a/src/language/modules/MathModule.cpp +++ b/src/language/modules/MathModule.cpp @@ -70,3 +70,7 @@ MathModule::MathModule() this->_addBuiltinFunction("round", std::make_shared<BuiltinFunctionEmbedder<int64_t(double)>>( [](double x) -> int64_t { return std::lround(x); })); } + +void +MathModule::registerOperators() const +{} diff --git a/src/language/modules/MathModule.hpp b/src/language/modules/MathModule.hpp index 1f001e1938691c4227722adda9724c9d88f33137..c80a74d2da37807750f25a286b9ccd9d12928850 100644 --- a/src/language/modules/MathModule.hpp +++ b/src/language/modules/MathModule.hpp @@ -12,6 +12,8 @@ class MathModule : public BuiltinModule return "math"; } + void registerOperators() const final; + MathModule(); ~MathModule() = default; diff --git a/src/language/modules/MeshModule.cpp b/src/language/modules/MeshModule.cpp index 89e678c3fb4095350842b19a9145502cd7f558a8..86c9a0c106e9fe6eece190c93365fc42a47bf864 100644 --- a/src/language/modules/MeshModule.cpp +++ b/src/language/modules/MeshModule.cpp @@ -212,3 +212,7 @@ MeshModule::MeshModule() )); } + +void +MeshModule::registerOperators() const +{} diff --git a/src/language/modules/MeshModule.hpp b/src/language/modules/MeshModule.hpp index ebb1383107ffd288001a48a5abced722ef37a4bd..3ef6f2856d052e7b7c5ab63393843145eec6f230 100644 --- a/src/language/modules/MeshModule.hpp +++ b/src/language/modules/MeshModule.hpp @@ -20,6 +20,8 @@ class MeshModule : public BuiltinModule return "mesh"; } + void registerOperators() const final; + MeshModule(); ~MeshModule() = default; diff --git a/src/language/modules/ModuleRepository.cpp b/src/language/modules/ModuleRepository.cpp index 44a437d04dc177cbfc6274420999c94fafaa22c9..d0cd8260ba6d6a3de9d335833d42002519568592 100644 --- a/src/language/modules/ModuleRepository.cpp +++ b/src/language/modules/ModuleRepository.cpp @@ -71,6 +71,14 @@ ModuleRepository::populateSymbolTable(const ASTNode& module_name_node, SymbolTab if (i_module != m_module_set.end()) { const IModule& populating_module = *i_module->second; + if (populating_module.isMandatory()) { + std::ostringstream error_message; + error_message << "module '" << rang::fgB::blue << module_name << rang::style::reset << rang::style::bold + << "' is an autoload " << rang::fgB::yellow << "mandatory" << rang::style::reset + << rang::style::bold << " module. It cannot be imported explicitly!"; + throw ParseError(error_message.str(), module_name_node.begin()); + } + this->_populateEmbedderTableT(module_name_node, module_name, populating_module.getNameBuiltinFunctionMap(), ASTNodeDataType::build<ASTNodeDataType::builtin_function_t>(), symbol_table, symbol_table.builtinFunctionEmbedderTable()); @@ -100,6 +108,8 @@ ModuleRepository::populateMandatorySymbolTable(const ASTNode& root_node, SymbolT this->_populateEmbedderTableT(root_node, module_name, i_module->getNameTypeMap(), ASTNodeDataType::build<ASTNodeDataType::type_name_id_t>(), symbol_table, symbol_table.typeEmbedderTable()); + + i_module->registerOperators(); } } } @@ -123,6 +133,17 @@ ModuleRepository::getAvailableModules() const return os.str(); } +void +ModuleRepository::registerOperators(const std::string& module_name) +{ + auto i_module = m_module_set.find(module_name); + if (i_module != m_module_set.end()) { + i_module->second->registerOperators(); + } else { + throw NormalError(std::string{"could not find module "} + module_name); + } +} + std::string ModuleRepository::getModuleInfo(const std::string& module_name) const { diff --git a/src/language/modules/ModuleRepository.hpp b/src/language/modules/ModuleRepository.hpp index c224f6491521ad9757ad5253011604e863a10503..c4c9870fcedf550c82f94374f759f94d4c9cc878 100644 --- a/src/language/modules/ModuleRepository.hpp +++ b/src/language/modules/ModuleRepository.hpp @@ -29,6 +29,7 @@ class ModuleRepository public: void populateSymbolTable(const ASTNode& module_name_node, SymbolTable& symbol_table); void populateMandatorySymbolTable(const ASTNode& root_node, SymbolTable& symbol_table); + void registerOperators(const std::string& module_name); std::string getAvailableModules() const; std::string getModuleInfo(const std::string& module_name) const; diff --git a/src/language/modules/SchemeModule.cpp b/src/language/modules/SchemeModule.cpp index 3051b6bb2265754ed8f1a84b5a1952c59325705b..4d1a866326a46c63f71b9b45b5188d7aa11746ca 100644 --- a/src/language/modules/SchemeModule.cpp +++ b/src/language/modules/SchemeModule.cpp @@ -1,6 +1,9 @@ #include <language/modules/SchemeModule.hpp> +#include <language/utils/BinaryOperatorProcessorBuilder.hpp> #include <language/utils/BuiltinFunctionEmbedder.hpp> +#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/TypeDescriptor.hpp> #include <mesh/Mesh.hpp> #include <scheme/AcousticSolver.hpp> @@ -280,3 +283,81 @@ SchemeModule::SchemeModule() )); } + +void +SchemeModule::registerOperators() const +{ + OperatorRepository& repository = OperatorRepository::instance(); + + repository.addBinaryOperator<language::plus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::plus_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::minus_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::minus_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::divide_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::divide_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + bool, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + int64_t, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + uint64_t, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + double, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<1>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<2>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + TinyMatrix<3>, std::shared_ptr<const IDiscreteFunction>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<1>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<2>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyVector<3>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<1>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<2>>>()); + + repository.addBinaryOperator<language::multiply_op>( + std::make_shared<BinaryOperatorProcessorBuilder<language::multiply_op, std::shared_ptr<const IDiscreteFunction>, + std::shared_ptr<const IDiscreteFunction>, TinyMatrix<3>>>()); +} diff --git a/src/language/modules/SchemeModule.hpp b/src/language/modules/SchemeModule.hpp index 3cfbb254a8f6e2ac36d16c1a78e53f3f995f56fc..8e5ee95d8f6797e705cba0436829a565c7714e6c 100644 --- a/src/language/modules/SchemeModule.hpp +++ b/src/language/modules/SchemeModule.hpp @@ -34,6 +34,8 @@ class SchemeModule : public BuiltinModule return "scheme"; } + void registerOperators() const final; + SchemeModule(); ~SchemeModule() = default; diff --git a/src/language/modules/UtilsModule.cpp b/src/language/modules/UtilsModule.cpp index b74ddc54eaab21ebf44dd53dbb585f4d0111c3fe..44d06e5377500e743172ceea12fe9ddfcd0a7d3e 100644 --- a/src/language/modules/UtilsModule.cpp +++ b/src/language/modules/UtilsModule.cpp @@ -65,3 +65,7 @@ UtilsModule::UtilsModule() )); } + +void +UtilsModule::registerOperators() const +{} diff --git a/src/language/modules/UtilsModule.hpp b/src/language/modules/UtilsModule.hpp index f7580a0422d3d8dabbd8c931b99dbfc7601f84da..620f965eeab4e10bdb7542b9862aa1423d2f48bb 100644 --- a/src/language/modules/UtilsModule.hpp +++ b/src/language/modules/UtilsModule.hpp @@ -12,6 +12,8 @@ class UtilsModule : public BuiltinModule return "utils"; } + void registerOperators() const final; + UtilsModule(); ~UtilsModule() = default; }; diff --git a/src/language/modules/WriterModule.cpp b/src/language/modules/WriterModule.cpp index 9f2fd98f103ec101190011b7a0ed1ded288ef009..fcb3beb6186b8b6918f1f51efccabb54a657c896 100644 --- a/src/language/modules/WriterModule.cpp +++ b/src/language/modules/WriterModule.cpp @@ -100,3 +100,7 @@ WriterModule::WriterModule() )); } + +void +WriterModule::registerOperators() const +{} diff --git a/src/language/modules/WriterModule.hpp b/src/language/modules/WriterModule.hpp index 14a961d7a812b34aac1c5010ddda674749fdbfb8..adcd92530298a99f088743e7ae6583bc09faa93c 100644 --- a/src/language/modules/WriterModule.hpp +++ b/src/language/modules/WriterModule.hpp @@ -29,6 +29,8 @@ class WriterModule : public BuiltinModule return "writer"; } + void registerOperators() const final; + WriterModule(); ~WriterModule() = default; diff --git a/src/language/utils/BinaryOperatorProcessorBuilder.hpp b/src/language/utils/BinaryOperatorProcessorBuilder.hpp index b8cb572d1e79c1882ed0288d79a4b73c6c130e5c..eb07b35d028f3eaea3f59537753dfd936a185837 100644 --- a/src/language/utils/BinaryOperatorProcessorBuilder.hpp +++ b/src/language/utils/BinaryOperatorProcessorBuilder.hpp @@ -6,9 +6,13 @@ #include <language/node_processor/BinaryExpressionProcessor.hpp> #include <language/utils/ASTNodeDataTypeTraits.hpp> #include <language/utils/IBinaryOperatorProcessorBuilder.hpp> +#include <language/utils/ParseError.hpp> #include <type_traits> +template <typename DataType> +class DataHandler; + template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT> class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder { @@ -40,4 +44,113 @@ class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuil } }; +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, std::shared_ptr<B_DataT>> + final : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + const auto& embedded_a = std::get<EmbeddedData>(a); + const auto& embedded_b = std::get<EmbeddedData>(b); + + std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); + + std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_ptr))); + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, A_DataT, std::shared_ptr<B_DataT>> final + : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + if constexpr ((std::is_arithmetic_v<A_DataT>) or (is_tiny_vector_v<A_DataT>) or (is_tiny_matrix_v<A_DataT>)) { + const auto& a_value = std::get<A_DataT>(a); + const auto& embedded_b = std::get<EmbeddedData>(b); + + std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_value, b_ptr))); + } else { + static_assert(std::is_arithmetic_v<A_DataT>, "invalid left hand side type"); + } + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + +template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT> +struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, B_DataT> final + : public INodeProcessor +{ + private: + ASTNode& m_node; + + PUGS_INLINE DataVariant + _eval(const DataVariant& a, const DataVariant& b) + { + if constexpr ((std::is_arithmetic_v<B_DataT>) or (is_tiny_matrix_v<B_DataT>) or (is_tiny_vector_v<B_DataT>)) { + const auto& embedded_a = std::get<EmbeddedData>(a); + const auto& b_value = std::get<B_DataT>(b); + + std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr(); + + return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_value))); + } else { + static_assert(std::is_arithmetic_v<B_DataT>, "invalid right hand side type"); + } + } + + public: + DataVariant + execute(ExecutionPolicy& exec_policy) + { + try { + return this->_eval(m_node.children[0]->execute(exec_policy), m_node.children[1]->execute(exec_policy)); + } + catch (const NormalError& error) { + throw ParseError(error.what(), m_node.begin()); + } + } + + BinaryExpressionProcessor(ASTNode& node) : m_node{node} {} +}; + #endif // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 32a75c66f0dd73cc10c09ca2c6aecf640962c45c..80640af6e39ee7febd53562c4c89bc1d2bd9b8c3 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -22,6 +22,7 @@ add_library(PugsLanguageUtils BinaryOperatorRegisterForZ.cpp DataVariant.cpp EmbeddedData.cpp + EmbeddedIDiscreteFunctionOperators.cpp FunctionSymbolId.cpp IncDecOperatorRegisterForN.cpp IncDecOperatorRegisterForR.cpp diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db7a609b9449d27380c2030c75e938554e6a095b --- /dev/null +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.cpp @@ -0,0 +1,696 @@ +#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp> + +#include <language/node_processor/BinaryExpressionProcessor.hpp> +#include <scheme/DiscreteFunctionP0.hpp> +#include <scheme/IDiscreteFunction.hpp> +#include <utils/Exceptions.hpp> + +template <typename T> +PUGS_INLINE std::string +name(const T&) +{ + return dataTypeName(ast_node_data_type_from<T>); +} + +template <> +PUGS_INLINE std::string +name(const IDiscreteFunction& f) +{ + return "Vh(" + dataTypeName(f.dataType()) + ")"; +} + +template <> +PUGS_INLINE std::string +name(const std::shared_ptr<const IDiscreteFunction>& f) +{ + return "Vh(" + dataTypeName(f->dataType()) + ")"; +} + +template <typename LHS_T, typename RHS_T> +PUGS_INLINE std::string +invalid_operands(const LHS_T& f, const RHS_T& g) +{ + std::ostringstream os; + os << "undefined binary operator\n"; + os << "note: incompatible operand types " << name(f) << " and " << name(g); + return os.str(); +} + +template <typename BinOperatorT, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const DiscreteFunctionT& lhs, const DiscreteFunctionT& rhs) +{ + Assert(lhs.mesh() == rhs.mesh()); + using data_type = typename DiscreteFunctionT::data_type; + if constexpr ((std::is_same_v<language::multiply_op, BinOperatorT> and is_tiny_vector_v<data_type>) or + (std::is_same_v<language::divide_op, BinOperatorT> and not std::is_arithmetic_v<data_type>)) { + throw NormalError(invalid_operands(lhs, rhs)); + } else { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs)); + } +} + +template <typename BinOperatorT, size_t Dimension> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(f->mesh() == g->mesh()); + Assert(f->dataType() == g->dataType()); + + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case ASTNodeDataType::vector_t: { + switch (f->dataType().dimension()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + default: { + throw NormalError(invalid_operands(f, g)); + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g); + + return innerCompositionLaw<BinOperatorT>(fh, gh); + } + default: { + throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } + default: { + throw UnexpectedError("invalid data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } +} + +template <typename BinOperatorT> +std::shared_ptr<const IDiscreteFunction> +innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->mesh() != g->mesh()) { + throw NormalError("discrete functions defined on different meshes"); + } + if (f->dataType() != g->dataType()) { + throw NormalError(invalid_operands(f, g)); + } + + switch (f->mesh()->dimension()) { + case 1: { + return innerCompositionLaw<BinOperatorT, 1>(f, g); + } + case 2: { + return innerCompositionLaw<BinOperatorT, 2>(f, g); + } + case 3: { + return innerCompositionLaw<BinOperatorT, 3>(f, g); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +template <typename BinOperatorT, typename LeftDiscreteFunctionT, typename RightDiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const LeftDiscreteFunctionT& lhs, const RightDiscreteFunctionT& rhs) +{ + Assert(lhs.mesh() == rhs.mesh()); + using lhs_data_type = typename LeftDiscreteFunctionT::data_type; + using rhs_data_type = typename RightDiscreteFunctionT::data_type; + + static_assert(not std::is_same_v<rhs_data_type, lhs_data_type>, + "use innerCompositionLaw when data types are the same"); + + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs)); +} + +template <typename BinOperatorT, size_t Dimension, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const DiscreteFunctionT& fh, const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(fh.mesh() == g->mesh()); + Assert(fh.dataType() != g->dataType()); + using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + + switch (g->dataType()) { + case ASTNodeDataType::double_t: { + if constexpr (not std::is_same_v<lhs_data_type, double>) { + if constexpr (not is_tiny_matrix_v<lhs_data_type>) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } else { + throw UnexpectedError("should have called innerCompositionLaw"); + } + } + case ASTNodeDataType::vector_t: { + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + switch (g->dataType().dimension()) { + case 1: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<1>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case 2: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<2>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case 3: { + if constexpr (not is_tiny_vector_v<lhs_data_type> and + (std::is_same_v<lhs_data_type, TinyMatrix<3>> or std::is_same_v<lhs_data_type, double>)) { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + case ASTNodeDataType::matrix_t: { + Assert(g->dataType().nbRows() == g->dataType().nbColumns()); + if constexpr (std::is_same_v<lhs_data_type, double> and std::is_same_v<language::multiply_op, BinOperatorT>) { + switch (g->dataType().nbRows()) { + case 1: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + case 2: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + case 3: { + auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g); + + return applyBinaryOperation<BinOperatorT>(fh, gh); + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } + } else { + throw NormalError(invalid_operands(fh, g)); + } + } + default: { + throw UnexpectedError("invalid rhs data type Vh(" + dataTypeName(g->dataType()) + ")"); + } + } +} + +template <typename BinOperatorT, size_t Dimension> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + Assert(f->mesh() == g->mesh()); + Assert(f->dataType() != g->dataType()); + + switch (f->dataType()) { + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + + return applyBinaryOperation<BinOperatorT, Dimension>(fh, g); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + default: { + throw NormalError(invalid_operands(f, g)); + } + } +} + +template <typename BinOperatorT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f, + const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->mesh() != g->mesh()) { + throw NormalError("functions defined on different meshes"); + } + + Assert(f->dataType() != g->dataType(), "should call inner composition instead"); + + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperation<BinOperatorT, 1>(f, g); + } + case 2: { + return applyBinaryOperation<BinOperatorT, 2>(f, g); + } + case 3: { + return applyBinaryOperation<BinOperatorT, 3>(f, g); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +std::shared_ptr<const IDiscreteFunction> +operator+(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + return innerCompositionLaw<language::plus_op>(f, g); +} + +std::shared_ptr<const IDiscreteFunction> +operator-(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + return innerCompositionLaw<language::minus_op>(f, g); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->dataType() == g->dataType()) { + return innerCompositionLaw<language::multiply_op>(f, g); + } else { + return applyBinaryOperation<language::multiply_op>(f, g); + } +} + +std::shared_ptr<const IDiscreteFunction> +operator/(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g) +{ + if (f->dataType() == g->dataType()) { + return innerCompositionLaw<language::divide_op>(f, g); + } else { + return applyBinaryOperation<language::divide_op>(f, g); + } +} + +template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const DiscreteFunctionT& f) +{ + using lhs_data_type = std::decay_t<DataType>; + using rhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + if constexpr (std::is_same_v<lhs_data_type, double>) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f)); + } else if constexpr (is_tiny_matrix_v<lhs_data_type> and + (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f)); + } else { + throw NormalError(invalid_operands(a, f)); + } + } else { + throw NormalError(invalid_operands(a, f)); + } +} + +template <typename BinOperatorT, size_t Dimension, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->dataType()) { + case ASTNodeDataType::bool_t: + case ASTNodeDataType::unsigned_int_t: + case ASTNodeDataType::int_t: + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case ASTNodeDataType::vector_t: { + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().dimension()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().dimension()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().nbRows()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } else { + throw NormalError(invalid_operands(a, f)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + default: { + throw NormalError(invalid_operands(a, f)); + } + } +} + +template <typename BinOperatorT, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 1>(a, f); + } + case 2: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 2>(a, f); + } + case 3: { + return applyBinaryOperationWithLeftConstant<BinOperatorT, 3>(a, f); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const DiscreteFunctionT& f, const DataType& a) +{ + using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>; + using rhs_data_type = std::decay_t<DataType>; + + if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) { + if constexpr (is_tiny_matrix_v<lhs_data_type> and is_tiny_matrix_v<rhs_data_type>) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a)); + } else if constexpr (std::is_same_v<lhs_data_type, double> and + (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) { + return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a)); + } else { + throw NormalError(invalid_operands(f, a)); + } + } else { + throw NormalError(invalid_operands(f, a)); + } +} + +template <typename BinOperatorT, size_t Dimension, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a) +{ + switch (f->dataType()) { + case ASTNodeDataType::bool_t: + case ASTNodeDataType::unsigned_int_t: + case ASTNodeDataType::int_t: + case ASTNodeDataType::double_t: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case ASTNodeDataType::matrix_t: { + Assert(f->dataType().nbRows() == f->dataType().nbColumns()); + if constexpr (is_tiny_matrix_v<DataType>) { + switch (f->dataType().nbRows()) { + case 1: { + if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + case 2: { + if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + case 3: { + if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } else { + throw NormalError(invalid_operands(f, a)); + } + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } else { + switch (f->dataType().nbRows()) { + case 1: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case 2: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + case 3: { + auto fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f); + return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a); + } + default: { + throw UnexpectedError("invalid lhs data type Vh(" + dataTypeName(f->dataType()) + ")"); + } + } + } + } + default: { + throw NormalError(invalid_operands(f, a)); + } + } +} + +template <typename BinOperatorT, typename DataType> +std::shared_ptr<const IDiscreteFunction> +applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a) +{ + switch (f->mesh()->dimension()) { + case 1: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 1>(f, a); + } + case 2: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 2>(f, a); + } + case 3: { + return applyBinaryOperationWithRightConstant<BinOperatorT, 3>(f, a); + } + default: { + throw UnexpectedError("invalid mesh dimension"); + } + } +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const double& a, const std::shared_ptr<const IDiscreteFunction>& f) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(a, f); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<1>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<2>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const TinyMatrix<3>& A, const std::shared_ptr<const IDiscreteFunction>& B) +{ + return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, B); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<1>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<2>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyVector<3>& u) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, u); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<1>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<2>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} + +std::shared_ptr<const IDiscreteFunction> +operator*(const std::shared_ptr<const IDiscreteFunction>& a, const TinyMatrix<3>& A) +{ + return applyBinaryOperationWithRightConstant<language::multiply_op>(a, A); +} diff --git a/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7969828000c4440ea9aa9d36014ee18d55d643c1 --- /dev/null +++ b/src/language/utils/EmbeddedIDiscreteFunctionOperators.hpp @@ -0,0 +1,52 @@ +#ifndef EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP +#define EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP + +#include <algebra/TinyMatrix.hpp> +#include <algebra/TinyVector.hpp> + +#include <memory> + +class IDiscreteFunction; + +std::shared_ptr<const IDiscreteFunction> operator+(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator-(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator/(const std::shared_ptr<const IDiscreteFunction>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const double&, const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<1>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<2>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const TinyMatrix<3>&, + const std::shared_ptr<const IDiscreteFunction>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<1>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<2>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyVector<3>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<1>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<2>&); + +std::shared_ptr<const IDiscreteFunction> operator*(const std::shared_ptr<const IDiscreteFunction>&, + const TinyMatrix<3>&); + +#endif // EMBEDDED_I_DISCRETE_FUNCTION_OPERATORS_HPP diff --git a/src/language/utils/OperatorRepository.cpp b/src/language/utils/OperatorRepository.cpp index d34344b967f41c72ae09849a554958b8ebfcc1f3..1ab43da119b0c09fe7c2dbe9ae2348bbfbc81996 100644 --- a/src/language/utils/OperatorRepository.cpp +++ b/src/language/utils/OperatorRepository.cpp @@ -1,33 +1,5 @@ #include <language/utils/OperatorRepository.hpp> -#include <language/utils/AffectationProcessorBuilder.hpp> -#include <language/utils/AffectationRegisterForB.hpp> -#include <language/utils/AffectationRegisterForN.hpp> -#include <language/utils/AffectationRegisterForR.hpp> -#include <language/utils/AffectationRegisterForRn.hpp> -#include <language/utils/AffectationRegisterForRnxn.hpp> -#include <language/utils/AffectationRegisterForString.hpp> -#include <language/utils/AffectationRegisterForZ.hpp> - -#include <language/utils/BinaryOperatorRegisterForB.hpp> -#include <language/utils/BinaryOperatorRegisterForN.hpp> -#include <language/utils/BinaryOperatorRegisterForR.hpp> -#include <language/utils/BinaryOperatorRegisterForRn.hpp> -#include <language/utils/BinaryOperatorRegisterForRnxn.hpp> -#include <language/utils/BinaryOperatorRegisterForString.hpp> -#include <language/utils/BinaryOperatorRegisterForZ.hpp> - -#include <language/utils/IncDecOperatorRegisterForN.hpp> -#include <language/utils/IncDecOperatorRegisterForR.hpp> -#include <language/utils/IncDecOperatorRegisterForZ.hpp> - -#include <language/utils/UnaryOperatorRegisterForB.hpp> -#include <language/utils/UnaryOperatorRegisterForN.hpp> -#include <language/utils/UnaryOperatorRegisterForR.hpp> -#include <language/utils/UnaryOperatorRegisterForRn.hpp> -#include <language/utils/UnaryOperatorRegisterForRnxn.hpp> -#include <language/utils/UnaryOperatorRegisterForZ.hpp> - #include <utils/PugsAssert.hpp> OperatorRepository* OperatorRepository::m_instance = nullptr; @@ -39,7 +11,6 @@ OperatorRepository::reset() m_binary_operator_builder_list.clear(); m_inc_dec_operator_builder_list.clear(); m_unary_operator_builder_list.clear(); - this->_initialize(); } void @@ -47,7 +18,6 @@ OperatorRepository::create() { Assert(m_instance == nullptr, "AffectationRepository was already created"); m_instance = new OperatorRepository; - m_instance->_initialize(); } void @@ -57,46 +27,3 @@ OperatorRepository::destroy() delete m_instance; m_instance = nullptr; } - -void -OperatorRepository::_initialize() -{ - AffectationRegisterForB{}; - AffectationRegisterForN{}; - AffectationRegisterForZ{}; - AffectationRegisterForR{}; - AffectationRegisterForRn<1>{}; - AffectationRegisterForRn<2>{}; - AffectationRegisterForRn<3>{}; - AffectationRegisterForRnxn<1>{}; - AffectationRegisterForRnxn<2>{}; - AffectationRegisterForRnxn<3>{}; - AffectationRegisterForString{}; - - BinaryOperatorRegisterForB{}; - BinaryOperatorRegisterForN{}; - BinaryOperatorRegisterForZ{}; - BinaryOperatorRegisterForR{}; - BinaryOperatorRegisterForRn<1>{}; - BinaryOperatorRegisterForRn<2>{}; - BinaryOperatorRegisterForRn<3>{}; - BinaryOperatorRegisterForRnxn<1>{}; - BinaryOperatorRegisterForRnxn<2>{}; - BinaryOperatorRegisterForRnxn<3>{}; - BinaryOperatorRegisterForString{}; - - IncDecOperatorRegisterForN{}; - IncDecOperatorRegisterForR{}; - IncDecOperatorRegisterForZ{}; - - UnaryOperatorRegisterForB{}; - UnaryOperatorRegisterForN{}; - UnaryOperatorRegisterForZ{}; - UnaryOperatorRegisterForR{}; - UnaryOperatorRegisterForRn<1>{}; - UnaryOperatorRegisterForRn<2>{}; - UnaryOperatorRegisterForRn<3>{}; - UnaryOperatorRegisterForRnxn<1>{}; - UnaryOperatorRegisterForRnxn<2>{}; - UnaryOperatorRegisterForRnxn<3>{}; -} diff --git a/src/language/utils/OperatorRepository.hpp b/src/language/utils/OperatorRepository.hpp index 6ac61e6e245f8f545e4794ffb4bd4703d4806b82..cf4b322e7f9da12311a0e1b1dcd47ee0c25e7936 100644 --- a/src/language/utils/OperatorRepository.hpp +++ b/src/language/utils/OperatorRepository.hpp @@ -58,8 +58,6 @@ class OperatorRepository std::unordered_map<std::string, Descriptor<IUnaryOperatorProcessorBuilder>> m_unary_operator_builder_list; - void _initialize(); - public: void reset(); diff --git a/src/scheme/DiscreteFunctionP0.hpp b/src/scheme/DiscreteFunctionP0.hpp index c51e8e09dbb78dc6f40f4fbb18eb737ddb7e1a30..53037ea72b28813cd2b2e619574f8faefa7484f0 100644 --- a/src/scheme/DiscreteFunctionP0.hpp +++ b/src/scheme/DiscreteFunctionP0.hpp @@ -13,9 +13,11 @@ template <size_t Dimension, typename DataType> class DiscreteFunctionP0 : public IDiscreteFunction { - private: - using MeshType = Mesh<Connectivity<Dimension>>; + public: + using data_type = DataType; + using MeshType = Mesh<Connectivity<Dimension>>; + private: std::shared_ptr<const MeshType> m_mesh; CellValue<DataType> m_cell_values; @@ -53,6 +55,72 @@ class DiscreteFunctionP0 : public IDiscreteFunction return m_cell_values[cell_id]; } + friend DiscreteFunctionP0 + operator+(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 sum(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { sum[cell_id] = f[cell_id] + g[cell_id]; }); + return sum; + } + + friend DiscreteFunctionP0 + operator-(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 difference(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { difference[cell_id] = f[cell_id] - g[cell_id]; }); + return difference; + } + + friend DiscreteFunctionP0 + operator*(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; }); + return product; + } + + template <typename DataType2T> + friend DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> + operator*(const DiscreteFunctionP0<Dimension, DataType2T>& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, decltype(DataType2T{} * DataType{})> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * g[cell_id]; }); + return product; + } + + friend DiscreteFunctionP0 + operator*(const double& a, const DiscreteFunctionP0& f) + { + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = a * f[cell_id]; }); + return product; + } + + friend DiscreteFunctionP0 + operator/(const DiscreteFunctionP0& f, const DiscreteFunctionP0& g) + { + Assert(f.mesh() == g.mesh(), "functions are nor defined on the same mesh"); + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0 ratio(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { ratio[cell_id] = f[cell_id] / g[cell_id]; }); + return ratio; + } + DiscreteFunctionP0(const std::shared_ptr<const MeshType>& mesh, const FunctionSymbolId& function_id) : m_mesh(mesh) { using MeshDataType = MeshData<Dimension>; @@ -80,4 +148,64 @@ class DiscreteFunctionP0 : public IDiscreteFunction ~DiscreteFunctionP0() = default; }; +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> +operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>& f) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const TinyMatrix<ValueDimension>& A, const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = A * f[cell_id]; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>& f, const TinyMatrix<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyMatrix<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyMatrix<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + +template <size_t Dimension, size_t ValueDimension> +DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> +operator*(const DiscreteFunctionP0<Dimension, double>& f, const TinyVector<ValueDimension>& A) +{ + using MeshType = typename DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>>::MeshType; + std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(f.mesh()); + DiscreteFunctionP0<Dimension, TinyVector<ValueDimension>> product(mesh); + parallel_for( + mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { product[cell_id] = f[cell_id] * A; }); + return product; +} + #endif // DISCRETE_FUNCTION_P0_HPP diff --git a/tests/test_ASTNodeAffectationExpressionBuilder.cpp b/tests/test_ASTNodeAffectationExpressionBuilder.cpp index b155ab61473f320be06b499e70cc4a32a269ce6f..85c443ebcc807d87ba197bc8f534763fcf46beba 100644 --- a/tests/test_ASTNodeAffectationExpressionBuilder.cpp +++ b/tests/test_ASTNodeAffectationExpressionBuilder.cpp @@ -2,6 +2,7 @@ #include <catch2/matchers/catch_matchers_all.hpp> #include <language/ast/ASTBuilder.hpp> +#include <language/ast/ASTModulesImporter.hpp> #include <language/ast/ASTNodeAffectationExpressionBuilder.hpp> #include <language/ast/ASTNodeDataTypeBuilder.hpp> #include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp> @@ -13,6 +14,7 @@ #include <language/utils/ASTPrinter.hpp> #include <language/utils/BasicAffectationRegistrerFor.hpp> #include <language/utils/EmbeddedData.hpp> +#include <language/utils/OperatorRepository.hpp> #include <language/utils/TypeDescriptor.hpp> #include <utils/Demangle.hpp> #include <utils/Exceptions.hpp> @@ -27,9 +29,9 @@ \ TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \ \ - BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ - \ auto ast = ASTBuilder::build(input); \ + ASTModulesImporter{*ast}; \ + BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ \ ASTSymbolTableBuilder{*ast}; \ ASTNodeDataTypeBuilder{*ast}; \ @@ -43,8 +45,6 @@ ast_output << '\n' << ASTPrinter{*ast, ASTPrinter::Format::raw, {ASTPrinter::Info::exec_type}}; \ \ REQUIRE(ast_output.str() == expected_output); \ - \ - OperatorRepository::instance().reset(); \ } template <> @@ -58,10 +58,10 @@ const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const dou static_assert(std::is_same_v<std::decay_t<decltype(expected_output)>, std::string_view> or \ std::is_same_v<std::decay_t<decltype(expected_output)>, std::string>); \ \ - BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ - \ TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \ auto ast = ASTBuilder::build(input); \ + ASTModulesImporter{*ast}; \ + BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ \ SymbolTable& symbol_table = *ast->m_symbol_table; \ auto [i_symbol, success] = symbol_table.add(builtin_data_type.nameOfTypeId(), ast->begin()); \ @@ -111,6 +111,8 @@ const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const dou \ TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \ auto ast = ASTBuilder::build(input); \ + OperatorRepository::instance().reset(); \ + ASTModulesImporter{*ast}; \ \ ASTSymbolTableBuilder{*ast}; \ ASTNodeDataTypeBuilder{*ast}; \ @@ -127,10 +129,12 @@ const auto builtin_data_type = ast_node_data_type_from<std::shared_ptr<const dou static_assert(std::is_same_v<std::decay_t<decltype(expected_error)>, std::string_view> or \ std::is_same_v<std::decay_t<decltype(expected_error)>, std::string>); \ \ - BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ - \ TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \ auto ast = ASTBuilder::build(input); \ + OperatorRepository::instance().reset(); \ + ASTModulesImporter{*ast}; \ + \ + BasicAffectationRegisterFor<EmbeddedData>{ASTNodeDataType::build<ASTNodeDataType::type_id_t>("builtin_t")}; \ \ SymbolTable& symbol_table = *ast->m_symbol_table; \ auto [i_symbol, success] = symbol_table.add(builtin_data_type.nameOfTypeId(), ast->begin()); \ @@ -1795,6 +1799,8 @@ let (x,y,z):R*R*R, (x,y) = (2,3); TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; auto ast = ASTBuilder::build(input); + OperatorRepository::instance().reset(); + ASTModulesImporter{*ast}; ASTSymbolTableBuilder{*ast}; REQUIRE_THROWS_WITH(ASTSymbolInitializationChecker{*ast}, @@ -1810,6 +1816,8 @@ let x:R, (x,y) = (2,3); TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; auto ast = ASTBuilder::build(input); + OperatorRepository::instance().reset(); + ASTModulesImporter{*ast}; ASTSymbolTableBuilder{*ast}; REQUIRE_THROWS_WITH(ASTSymbolInitializationChecker{*ast}, @@ -1825,6 +1833,8 @@ let x:R, y = 3; TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; auto ast = ASTBuilder::build(input); + OperatorRepository::instance().reset(); + ASTModulesImporter{*ast}; ASTSymbolTableBuilder{*ast}; REQUIRE_THROWS_WITH(ASTSymbolInitializationChecker{*ast}, std::string{"invalid identifier, expecting 'x'"}); @@ -1838,6 +1848,8 @@ let (x,y):R, (y,x) = (3,2); TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; auto ast = ASTBuilder::build(input); + OperatorRepository::instance().reset(); + ASTModulesImporter{*ast}; ASTSymbolTableBuilder{*ast}; REQUIRE_THROWS_WITH(ASTSymbolInitializationChecker{*ast}, std::string{"invalid identifier, expecting 'x'"}); diff --git a/tests/test_AffectationProcessor.cpp b/tests/test_AffectationProcessor.cpp index f64af713bdd5e45ddec46b50869b7fb035b8a0f3..ec130a135320bf1ceba5ee1ff532b57437296d2e 100644 --- a/tests/test_AffectationProcessor.cpp +++ b/tests/test_AffectationProcessor.cpp @@ -2,6 +2,7 @@ #include <catch2/matchers/catch_matchers_all.hpp> #include <language/ast/ASTBuilder.hpp> +#include <language/ast/ASTModulesImporter.hpp> #include <language/ast/ASTNodeAffectationExpressionBuilder.hpp> #include <language/ast/ASTNodeDataTypeBuilder.hpp> #include <language/ast/ASTNodeDeclarationToAffectationConverter.hpp> @@ -20,6 +21,8 @@ TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \ auto ast = ASTBuilder::build(input); \ \ + ASTModulesImporter{*ast}; \ + \ ASTSymbolTableBuilder{*ast}; \ ASTNodeDataTypeBuilder{*ast}; \ \ @@ -48,6 +51,8 @@ TAO_PEGTL_NAMESPACE::string_input input{data, "test.pgs"}; \ auto ast = ASTBuilder::build(input); \ \ + ASTModulesImporter{*ast}; \ + \ ASTSymbolTableBuilder{*ast}; \ ASTNodeDataTypeBuilder{*ast}; \ \