From 3f412d77b84c4bd58637ee0ca3e55fe083ee753d Mon Sep 17 00:00:00 2001 From: Stephane Del Pino <stephane.delpino44@gmail.com> Date: Wed, 16 Oct 2019 12:54:47 +0200 Subject: [PATCH] Build CFunctionEbeddertable - it is populated during module import - the math module now provide both `cos` and `sin` functions --- src/language/ASTBuilder.cpp | 3 +- src/language/ASTModulesImporter.cpp | 26 +++++++----- .../ASTNodeCFunctionExpressionBuilder.cpp | 7 +++- src/language/CFunctionEmbedder.hpp | 6 ++- src/language/CFunctionEmbedderTable.hpp | 40 +++++++++++++++++++ src/language/CMathModule.cpp | 19 ++++++++- src/language/CMathModule.hpp | 14 +++---- src/language/SymbolTable.hpp | 28 +++++++++++-- .../node_processor/CFunctionProcessor.hpp | 16 ++++---- tests/test_SymbolTable.cpp | 6 +-- 10 files changed, 126 insertions(+), 39 deletions(-) create mode 100644 src/language/CFunctionEmbedderTable.hpp diff --git a/src/language/ASTBuilder.cpp b/src/language/ASTBuilder.cpp index be1a9d966..1ba6bfcec 100644 --- a/src/language/ASTBuilder.cpp +++ b/src/language/ASTBuilder.cpp @@ -284,8 +284,7 @@ ASTBuilder::build(InputT& input) std::unique_ptr root_node = parse_tree::parse<language::grammar, ASTNode, selector, nothing, language::errors>(input); // build initial symbol tables - std::shared_ptr function_table = std::make_shared<FunctionTable>(); - std::shared_ptr symbol_table = std::make_shared<SymbolTable>(function_table); + std::shared_ptr symbol_table = std::make_shared<SymbolTable>(); root_node->m_symbol_table = symbol_table; diff --git a/src/language/ASTModulesImporter.cpp b/src/language/ASTModulesImporter.cpp index 257e0d916..1a2e8765c 100644 --- a/src/language/ASTModulesImporter.cpp +++ b/src/language/ASTModulesImporter.cpp @@ -5,8 +5,11 @@ #include <PEGGrammar.hpp> +#include <CFunctionEmbedder.hpp> #include <CMathModule.hpp> +#include <memory> + void ASTModulesImporter::_importModule(ASTNode& import_node) { @@ -21,20 +24,23 @@ ASTModulesImporter::_importModule(ASTNode& import_node) if (module_name == "math") { CMathModule math_module; - std::string symbol_name{"sin"}; + CFunctionEmbedderTable& c_function_embedder_table = m_symbol_table.cFunctionEbedderTable(); - auto [i_symbol, success] = m_symbol_table.add(symbol_name, import_node.begin()); + for (auto [symbol_name, c_function] : math_module.getNameCFunctionsMap()) { + auto [i_symbol, success] = m_symbol_table.add(symbol_name, import_node.begin()); - if (not success) { - std::ostringstream error_message; - error_message << "cannot add symbol '" << symbol_name << "' it is already defined"; - throw parse_error(error_message.str(), import_node.begin()); - } + if (not success) { + std::ostringstream error_message; + error_message << "cannot add symbol '" << symbol_name << "' it is already defined"; + throw parse_error(error_message.str(), import_node.begin()); + } - i_symbol->attributes().setDataType(ASTNodeDataType::c_function_t); - i_symbol->attributes().setIsInitialized(); - i_symbol->attributes().value() = static_cast<uint64_t>(0); + i_symbol->attributes().setDataType(ASTNodeDataType::c_function_t); + i_symbol->attributes().setIsInitialized(); + i_symbol->attributes().value() = c_function_embedder_table.size(); + c_function_embedder_table.add(c_function); + } } else { throw parse_error(std::string{"could not find module "} + module_name, std::vector{module_name_node.begin()}); } diff --git a/src/language/ASTNodeCFunctionExpressionBuilder.cpp b/src/language/ASTNodeCFunctionExpressionBuilder.cpp index 6fc40f2f9..d42420b70 100644 --- a/src/language/ASTNodeCFunctionExpressionBuilder.cpp +++ b/src/language/ASTNodeCFunctionExpressionBuilder.cpp @@ -67,8 +67,13 @@ ASTNodeCFunctionExpressionBuilder::ASTNodeCFunctionExpressionBuilder(ASTNode& no this->_buildArgumentProcessors(node, *c_function_processor); + uint64_t c_function_id = std::get<uint64_t>(i_function_symbol->attributes().value()); + + CFunctionEmbedderTable& c_function_embedder_table = node.m_symbol_table->cFunctionEbedderTable(); + c_function_processor->setFunctionExpressionProcessor( - std::make_unique<CFunctionExpressionProcessor<double, double>>(node, c_function_processor->argumentValues())); + std::make_unique<CFunctionExpressionProcessor<double, double>>(node, c_function_embedder_table[c_function_id], + c_function_processor->argumentValues())); ASTNodeDataType c_function_return_type = ASTNodeDataType::double_t; diff --git a/src/language/CFunctionEmbedder.hpp b/src/language/CFunctionEmbedder.hpp index 37852636b..e5208d901 100644 --- a/src/language/CFunctionEmbedder.hpp +++ b/src/language/CFunctionEmbedder.hpp @@ -1,16 +1,20 @@ #ifndef CFUNCTION_EMBEDDER_HPP #define CFUNCTION_EMBEDDER_HPP +#include <PugsAssert.hpp> #include <PugsMacros.hpp> +#include <ASTNodeDataVariant.hpp> + #include <cmath> #include <functional> #include <iostream> #include <tuple> #include <vector> -struct ICFunctionEmbedder +class ICFunctionEmbedder { + public: virtual void apply(const std::vector<ASTNodeDataVariant>& x, double& f_x) = 0; virtual ~ICFunctionEmbedder() = default; diff --git a/src/language/CFunctionEmbedderTable.hpp b/src/language/CFunctionEmbedderTable.hpp new file mode 100644 index 000000000..e2a7c5be6 --- /dev/null +++ b/src/language/CFunctionEmbedderTable.hpp @@ -0,0 +1,40 @@ +#ifndef CFUNCTION_EMBEDDER_TABLE_HPP +#define CFUNCTION_EMBEDDER_TABLE_HPP + +#include <PugsAssert.hpp> + +#include <memory> +#include <vector> + +class ICFunctionEmbedder; +class CFunctionEmbedderTable +{ + private: + std::vector<std::shared_ptr<ICFunctionEmbedder>> m_c_function_embedder_list; + + public: + PUGS_INLINE + size_t + size() const + { + return m_c_function_embedder_list.size(); + } + + PUGS_INLINE + const std::shared_ptr<ICFunctionEmbedder>& operator[](size_t function_id) const + { + Assert(function_id < m_c_function_embedder_list.size()); + return m_c_function_embedder_list[function_id]; + } + + void + add(std::shared_ptr<ICFunctionEmbedder> c_function_embedder) + { + m_c_function_embedder_list.push_back(c_function_embedder); + } + + CFunctionEmbedderTable() = default; + ~CFunctionEmbedderTable() = default; +}; + +#endif // CFUNCTION_EMBEDDER_TABLE_HPP diff --git a/src/language/CMathModule.cpp b/src/language/CMathModule.cpp index 2508026ae..ec37707b9 100644 --- a/src/language/CMathModule.cpp +++ b/src/language/CMathModule.cpp @@ -1,9 +1,24 @@ #include <CMathModule.hpp> +#include <CFunctionEmbedder.hpp> + #include <iostream> +void +CMathModule::_addFunction(const std::string& name, std::shared_ptr<ICFunctionEmbedder> c_function_embedder) +{ + auto [i_function, success] = m_name_cfunction_map.insert(std::make_pair(name, c_function_embedder)); + if (not success) { + std::cerr << "function " << name << " cannot be add!\n"; + std::exit(1); + } +} + CMathModule::CMathModule() { - std::cerr << __FILE__ << ':' << __LINE__ << ": CMathModule construction NIY\n"; - std::exit(1); + this->_addFunction("sin", std::make_shared<CFunctionEmbedder<double, double>>( + std::function<double(double)>{[](double x) -> double { return std::sin(x); }})); + + this->_addFunction("cos", std::make_shared<CFunctionEmbedder<double, double>>( + std::function<double(double)>{[](double x) -> double { return std::cos(x); }})); } diff --git a/src/language/CMathModule.hpp b/src/language/CMathModule.hpp index d7d929983..e9b56fbd7 100644 --- a/src/language/CMathModule.hpp +++ b/src/language/CMathModule.hpp @@ -3,24 +3,24 @@ #include <PugsMacros.hpp> -#include <map> #include <memory> #include <string> +#include <unordered_map> #include <vector> class ICFunctionEmbedder; class CMathModule { private: - std::map<std::string, uint64_t> m_name_fid_map; - std::vector<std::shared_ptr<ICFunctionEmbedder>> m_cfunction_list; + std::unordered_map<std::string, std::shared_ptr<ICFunctionEmbedder>> m_name_cfunction_map; + + void _addFunction(const std::string& name, std::shared_ptr<ICFunctionEmbedder> c_function_embedder); public: - PUGS_INLINE - auto - find(const std::string& name) const + const auto& + getNameCFunctionsMap() const { - return m_name_fid_map.find(name); + return m_name_cfunction_map; } CMathModule(); diff --git a/src/language/SymbolTable.hpp b/src/language/SymbolTable.hpp index c8430916b..1039af0cd 100644 --- a/src/language/SymbolTable.hpp +++ b/src/language/SymbolTable.hpp @@ -10,6 +10,7 @@ #include <iostream> +#include <CFunctionEmbedderTable.hpp> #include <FunctionTable.hpp> class SymbolTable @@ -134,20 +135,37 @@ class SymbolTable std::vector<Symbol> m_symbol_list; std::shared_ptr<SymbolTable> m_parent_table; std::shared_ptr<FunctionTable> m_function_table; + std::shared_ptr<CFunctionEmbedderTable> m_c_function_embedder_table; public: const FunctionTable& functionTable() const { + Assert(m_function_table); return *m_function_table; } FunctionTable& functionTable() { + Assert(m_function_table); return *m_function_table; } + const CFunctionEmbedderTable& + cFunctionEbedderTable() const + { + Assert(m_c_function_embedder_table); + return *m_c_function_embedder_table; + } + + CFunctionEmbedderTable& + cFunctionEbedderTable() + { + Assert(m_c_function_embedder_table); + return *m_c_function_embedder_table; + } + friend std::ostream& operator<<(std::ostream& os, const SymbolTable& symbol_table) { @@ -195,13 +213,17 @@ class SymbolTable } SymbolTable(const std::shared_ptr<SymbolTable>& parent_table) - : m_parent_table(parent_table), m_function_table(parent_table->m_function_table) + : m_parent_table(parent_table), + m_function_table(parent_table->m_function_table), + m_c_function_embedder_table(parent_table->m_c_function_embedder_table) { ; } - SymbolTable(const std::shared_ptr<FunctionTable>& function_table) - : m_parent_table(nullptr), m_function_table(function_table) + SymbolTable() + : m_parent_table(nullptr), + m_function_table(std::make_shared<FunctionTable>()), + m_c_function_embedder_table(std::make_shared<CFunctionEmbedderTable>()) { ; } diff --git a/src/language/node_processor/CFunctionProcessor.hpp b/src/language/node_processor/CFunctionProcessor.hpp index 15c56da81..0b587afed 100644 --- a/src/language/node_processor/CFunctionProcessor.hpp +++ b/src/language/node_processor/CFunctionProcessor.hpp @@ -38,16 +38,15 @@ class CFunctionExpressionProcessor final : public INodeProcessor private: ASTNode& m_node; + std::shared_ptr<ICFunctionEmbedder> m_embedded_c_function; std::vector<ASTNodeDataVariant>& m_argument_values; - std::unique_ptr<ICFunctionEmbedder> m_embedded_function; - public: void execute(ExecUntilBreakOrContinue&) { ReturnType result; - m_embedded_function->apply(m_argument_values, result); + m_embedded_c_function->apply(m_argument_values, result); if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) { m_node.m_value = result; } else { @@ -55,12 +54,11 @@ class CFunctionExpressionProcessor final : public INodeProcessor } } - CFunctionExpressionProcessor(ASTNode& node, std::vector<ASTNodeDataVariant>& argument_values) - : m_node{node}, m_argument_values{argument_values} - { - m_embedded_function = std::make_unique<CFunctionEmbedder<double, double>>( - std::function{[](double x) -> double { return std::sin(x); }}); - } + CFunctionExpressionProcessor(ASTNode& node, + std::shared_ptr<ICFunctionEmbedder> embedded_c_function, + std::vector<ASTNodeDataVariant>& argument_values) + : m_node{node}, m_embedded_c_function(embedded_c_function), m_argument_values{argument_values} + {} }; class CFunctionProcessor : public INodeProcessor diff --git a/tests/test_SymbolTable.cpp b/tests/test_SymbolTable.cpp index b3f466a76..1bf8a63b8 100644 --- a/tests/test_SymbolTable.cpp +++ b/tests/test_SymbolTable.cpp @@ -9,8 +9,7 @@ TEST_CASE("SymbolTable", "[language]") { SECTION("Simple Symbol Table") { - std::shared_ptr function_table = std::make_shared<FunctionTable>(); - std::shared_ptr root_st = std::make_shared<SymbolTable>(function_table); + std::shared_ptr root_st = std::make_shared<SymbolTable>(); using namespace TAO_PEGTL_NAMESPACE; position begin_position{internal::iterator{"fixture"}, "fixture"}; @@ -98,8 +97,7 @@ TEST_CASE("SymbolTable", "[language]") SECTION("Hierarchy Symbol Table") { - std::shared_ptr function_table = std::make_shared<FunctionTable>(); - std::shared_ptr root_st = std::make_shared<SymbolTable>(function_table); + std::shared_ptr root_st = std::make_shared<SymbolTable>(); using namespace TAO_PEGTL_NAMESPACE; position begin_position{internal::iterator{"fixture"}, "fixture"}; -- GitLab