From 60a0277ad7af545e7196ae60b74b969cab8a987a Mon Sep 17 00:00:00 2001
From: Stephane Del Pino <stephane.delpino44@gmail.com>
Date: Fri, 6 Dec 2019 17:08:26 +0100
Subject: [PATCH] Fix function parameter initialization.

Context management for parameters of functions was incorrect. To fix it we store
a context identifier which allows to identify precisely parameters location (in
memory).

Thanks to this fix, the following code produces the correct output (as it was
the case before context introduction)
``
let f: R -> R, x -> x+3;
let g: R -> R, x -> 2*f(x);

R x = g(2);
``
---
 .../ASTNodeFunctionExpressionBuilder.cpp      |  2 +-
 src/language/SymbolTable.hpp                  | 33 +++++++--
 .../node_processor/BreakProcessor.hpp         |  2 +-
 .../node_processor/CFunctionProcessor.hpp     |  9 ++-
 .../node_processor/ContinueProcessor.hpp      |  2 +-
 .../node_processor/ExecutionPolicy.hpp        | 74 ++++++++++++++++---
 .../node_processor/FunctionProcessor.hpp      | 17 +++--
 .../node_processor/LocalNameProcessor.hpp     |  6 +-
 tests/test_ExecutionPolicy.cpp                |  6 +-
 9 files changed, 114 insertions(+), 37 deletions(-)

diff --git a/src/language/ASTNodeFunctionExpressionBuilder.cpp b/src/language/ASTNodeFunctionExpressionBuilder.cpp
index 3d8000370..358892b66 100644
--- a/src/language/ASTNodeFunctionExpressionBuilder.cpp
+++ b/src/language/ASTNodeFunctionExpressionBuilder.cpp
@@ -218,7 +218,7 @@ ASTNodeFunctionExpressionBuilder::ASTNodeFunctionExpressionBuilder(ASTNode& node
   Assert(function_expression.m_symbol_table->hasContext());
   const SymbolTable::Context& context = function_expression.m_symbol_table->context();
 
-  std::unique_ptr function_processor = std::make_unique<FunctionProcessor>(context.size());
+  std::unique_ptr function_processor = std::make_unique<FunctionProcessor>(context);
 
   this->_buildArgumentProcessors(function_descriptor, node, *function_processor);
 
diff --git a/src/language/SymbolTable.hpp b/src/language/SymbolTable.hpp
index c625d8a0c..c8db0705e 100644
--- a/src/language/SymbolTable.hpp
+++ b/src/language/SymbolTable.hpp
@@ -20,7 +20,7 @@ class SymbolTable
   {
    private:
     TAO_PEGTL_NAMESPACE::position m_position;
-    bool m_has_local_context;
+    int32_t m_context_id;
 
     bool m_is_initialized{false};
 
@@ -31,7 +31,13 @@ class SymbolTable
     bool
     hasLocalContext() const
     {
-      return m_has_local_context;
+      return m_context_id != -1;
+    }
+
+    const int32_t&
+    contextId() const
+    {
+      return m_context_id;
     }
 
     auto&
@@ -96,8 +102,8 @@ class SymbolTable
       return os;
     }
 
-    Attributes(const TAO_PEGTL_NAMESPACE::position& position, const bool& has_local_context)
-      : m_position{position}, m_has_local_context{has_local_context}
+    Attributes(const TAO_PEGTL_NAMESPACE::position& position, const int32_t& context_id)
+      : m_position{position}, m_context_id{context_id}
     {}
 
     Attributes(const Attributes&) = default;
@@ -144,22 +150,34 @@ class SymbolTable
   class Context
   {
    private:
+    inline static int32_t next_context_id{0};
+
+    int32_t m_id;
     size_t m_size{0};
 
    public:
+    PUGS_INLINE
+    int32_t
+    id() const
+    {
+      return m_id;
+    }
+
+    PUGS_INLINE
     size_t
     size() const
     {
       return m_size;
     }
 
+    PUGS_INLINE
     size_t
     getNextSymbolId()
     {
       return m_size++;
     }
 
-    Context() = default;
+    Context() : m_id{next_context_id++} {}
 
     Context& operator=(Context&&) = default;
     Context& operator=(const Context&) = default;
@@ -264,8 +282,11 @@ class SymbolTable
         return std::make_pair(i_stored_symbol, false);
       }
     }
+
+    int32_t context_id = this->hasContext() ? m_context->id() : -1;
+
     auto i_symbol =
-      m_symbol_list.emplace(m_symbol_list.end(), Symbol{symbol_name, Attributes{symbol_position, this->hasContext()}});
+      m_symbol_list.emplace(m_symbol_list.end(), Symbol{symbol_name, Attributes{symbol_position, context_id}});
 
     if (this->hasContext()) {
       i_symbol->attributes().value() = m_context->getNextSymbolId();
diff --git a/src/language/node_processor/BreakProcessor.hpp b/src/language/node_processor/BreakProcessor.hpp
index edf218d48..039729e52 100644
--- a/src/language/node_processor/BreakProcessor.hpp
+++ b/src/language/node_processor/BreakProcessor.hpp
@@ -9,7 +9,7 @@ class BreakProcessor final : public INodeProcessor
   DataVariant
   execute(ExecutionPolicy& exec_policy)
   {
-    exec_policy = ExecutionPolicy(ExecutionPolicy::JumpType::break_jump, exec_policy.sharedContext());
+    exec_policy = ExecutionPolicy(exec_policy, ExecutionPolicy::JumpType::break_jump);
     return {};
   }
 
diff --git a/src/language/node_processor/CFunctionProcessor.hpp b/src/language/node_processor/CFunctionProcessor.hpp
index c7bdbad51..5c5f3c82b 100644
--- a/src/language/node_processor/CFunctionProcessor.hpp
+++ b/src/language/node_processor/CFunctionProcessor.hpp
@@ -36,7 +36,7 @@ class CFunctionExpressionProcessor final : public INodeProcessor
   DataVariant
   execute(ExecutionPolicy& exec_policy)
   {
-    return m_embedded_c_function->apply(exec_policy.context());
+    return m_embedded_c_function->apply(exec_policy.currentContext().values());
   }
 
   CFunctionExpressionProcessor(std::shared_ptr<ICFunctionEmbedder> embedded_c_function)
@@ -67,9 +67,10 @@ class CFunctionProcessor : public INodeProcessor
   DataVariant
   execute(ExecutionPolicy& exec_policy)
   {
-    ExecutionPolicy context_exec_policy{exec_policy.jumpType(),
-                                        std::make_shared<ExecutionPolicy::Context>(m_argument_processors.size())};
-    auto& argument_values = context_exec_policy.context();
+    ExecutionPolicy context_exec_policy{exec_policy,
+                                        ExecutionPolicy::Context{-1, std::make_shared<ExecutionPolicy::Context::Values>(
+                                                                       m_argument_processors.size())}};
+    auto& argument_values = context_exec_policy.currentContext();
 
     for (size_t i = 0; i < argument_values.size(); ++i) {
       argument_values[i] = m_argument_processors[i]->execute(context_exec_policy);
diff --git a/src/language/node_processor/ContinueProcessor.hpp b/src/language/node_processor/ContinueProcessor.hpp
index 0b100e248..002f1e643 100644
--- a/src/language/node_processor/ContinueProcessor.hpp
+++ b/src/language/node_processor/ContinueProcessor.hpp
@@ -9,7 +9,7 @@ class ContinueProcessor final : public INodeProcessor
   DataVariant
   execute(ExecutionPolicy& exec_policy)
   {
-    exec_policy = ExecutionPolicy(ExecutionPolicy::JumpType::continue_jump, exec_policy.sharedContext());
+    exec_policy = ExecutionPolicy(exec_policy, ExecutionPolicy::JumpType::continue_jump);
     return {};
   }
 
diff --git a/src/language/node_processor/ExecutionPolicy.hpp b/src/language/node_processor/ExecutionPolicy.hpp
index 0d5349ee0..df0eb9727 100644
--- a/src/language/node_processor/ExecutionPolicy.hpp
+++ b/src/language/node_processor/ExecutionPolicy.hpp
@@ -18,14 +18,55 @@ class ExecutionPolicy
     continue_jump
   };
 
-  using Context       = std::vector<DataVariant>;
-  using SharedContext = std::shared_ptr<Context>;
+  class Context
+  {
+   public:
+    using Values       = std::vector<DataVariant>;
+    using SharedValues = std::shared_ptr<Values>;
+
+   private:
+    int32_t m_id;
+    SharedValues m_shared_values;
+
+   public:
+    auto
+    size() const
+    {
+      return m_shared_values->size();
+    }
+
+    DataVariant& operator[](size_t i)
+    {
+      return (*m_shared_values)[i];
+    }
+
+    const DataVariant& operator[](size_t i) const
+    {
+      return (*m_shared_values)[i];
+    }
+
+    const Values&
+    values() const
+    {
+      return *m_shared_values;
+    }
+
+    int32_t
+    id() const
+    {
+      return m_id;
+    }
+
+    Context(int32_t id, const SharedValues& shared_values) : m_id{id}, m_shared_values{shared_values} {}
+
+    Context(const Context&) = default;
+  };
 
  private:
   JumpType m_jump_type;
   bool m_exec;
 
-  SharedContext m_shared_context;
+  std::vector<Context> m_context_list;
 
  public:
   PUGS_INLINE
@@ -43,25 +84,36 @@ class ExecutionPolicy
   }
 
   Context&
-  context()
+  contextOfId(const int32_t& context_id)
   {
-    Assert(m_shared_context);
-    return *m_shared_context;
+    for (auto i_context = m_context_list.rbegin(); i_context != m_context_list.rend(); ++i_context) {
+      if (i_context->id() == context_id) {
+        return *i_context;
+      }
+    }
+    throw std::invalid_argument{"unable to find context"};
   }
 
-  SharedContext
-  sharedContext() const
+  Context&
+  currentContext()
   {
-    return m_shared_context;
+    return m_context_list.back();
   }
 
   ExecutionPolicy& operator=(const ExecutionPolicy&) = delete;
   ExecutionPolicy& operator=(ExecutionPolicy&&) = default;
 
+  explicit ExecutionPolicy(const ExecutionPolicy& parent_policy) = default;
+
   ExecutionPolicy() : m_jump_type{JumpType::no_jump}, m_exec{true} {}
 
-  ExecutionPolicy(const JumpType& jump_type, const SharedContext& shared_context)
-    : m_jump_type{jump_type}, m_exec{jump_type == JumpType::no_jump}, m_shared_context{shared_context}
+  ExecutionPolicy(const ExecutionPolicy& parent_policy, const Context& context) : ExecutionPolicy{parent_policy}
+  {
+    m_context_list.push_back(context);
+  }
+
+  ExecutionPolicy(const ExecutionPolicy& parent_policy, const JumpType& jump_type)
+    : m_jump_type{jump_type}, m_exec{jump_type == JumpType::no_jump}, m_context_list{parent_policy.m_context_list}
   {
     ;
   }
diff --git a/src/language/node_processor/FunctionProcessor.hpp b/src/language/node_processor/FunctionProcessor.hpp
index 572e25d29..d584f77e4 100644
--- a/src/language/node_processor/FunctionProcessor.hpp
+++ b/src/language/node_processor/FunctionProcessor.hpp
@@ -19,15 +19,13 @@ class FunctionArgumentProcessor final : public INodeProcessor
   DataVariant
   execute(ExecutionPolicy& exec_policy)
   {
-    ;
-
     if constexpr (std::is_same_v<ExpectedValueType, ProvidedValueType>) {
-      exec_policy.context()[m_symbol_id] = m_provided_value_node.execute(exec_policy);
+      exec_policy.currentContext()[m_symbol_id] = m_provided_value_node.execute(exec_policy);
     } else if constexpr (std::is_same_v<ExpectedValueType, std::string>) {
-      exec_policy.context()[m_symbol_id] =
+      exec_policy.currentContext()[m_symbol_id] =
         std::to_string(std::get<ProvidedValueType>(m_provided_value_node.execute(exec_policy)));
     } else {
-      exec_policy.context()[m_symbol_id] =
+      exec_policy.currentContext()[m_symbol_id] =
         static_cast<ExpectedValueType>(std::get<ProvidedValueType>(m_provided_value_node.execute(exec_policy)));
     }
     return {};
@@ -66,6 +64,7 @@ class FunctionProcessor : public INodeProcessor
 {
  private:
   const size_t m_context_size;
+  const int32_t m_context_id;
 
   std::vector<std::unique_ptr<INodeProcessor>> m_argument_processors;
   std::vector<std::unique_ptr<INodeProcessor>> m_function_expression_processors;
@@ -86,8 +85,10 @@ class FunctionProcessor : public INodeProcessor
   DataVariant
   execute(ExecutionPolicy& exec_policy)
   {
-    ExecutionPolicy context_exec_policy{exec_policy.jumpType(),
-                                        std::make_shared<ExecutionPolicy::Context>(m_context_size)};
+    ExecutionPolicy context_exec_policy{exec_policy,
+                                        ExecutionPolicy::Context{m_context_id,
+                                                                 std::make_shared<ExecutionPolicy::Context::Values>(
+                                                                   m_context_size)}};
 
     for (auto& argument_processor : m_argument_processors) {
       argument_processor->execute(context_exec_policy);
@@ -106,7 +107,7 @@ class FunctionProcessor : public INodeProcessor
     }
   }
 
-  FunctionProcessor(const size_t& context_size) : m_context_size{context_size} {}
+  FunctionProcessor(const SymbolTable::Context& context) : m_context_size{context.size()}, m_context_id{context.id()} {}
 };
 
 #endif   // FUNCTION_PROCESSOR_HPP
diff --git a/src/language/node_processor/LocalNameProcessor.hpp b/src/language/node_processor/LocalNameProcessor.hpp
index 30d78b2da..072e1ce95 100644
--- a/src/language/node_processor/LocalNameProcessor.hpp
+++ b/src/language/node_processor/LocalNameProcessor.hpp
@@ -11,12 +11,13 @@ class LocalNameProcessor final : public INodeProcessor
  private:
   ASTNode& m_node;
   uint64_t m_value_id;
+  int32_t m_context_id;
 
  public:
   DataVariant
   execute(ExecutionPolicy& exec_policy)
   {
-    return exec_policy.context()[m_value_id];
+    return exec_policy.contextOfId(m_context_id)[m_value_id];
   }
 
   LocalNameProcessor(ASTNode& node) : m_node{node}
@@ -24,7 +25,8 @@ class LocalNameProcessor final : public INodeProcessor
     const std::string& symbol = m_node.string();
     auto [i_symbol, found]    = m_node.m_symbol_table->find(symbol, m_node.begin());
     Assert(found);
-    m_value_id = std::get<uint64_t>(i_symbol->attributes().value());
+    m_value_id   = std::get<uint64_t>(i_symbol->attributes().value());
+    m_context_id = i_symbol->attributes().contextId();
   }
 };
 
diff --git a/tests/test_ExecutionPolicy.cpp b/tests/test_ExecutionPolicy.cpp
index f69dc72c0..33d110f9c 100644
--- a/tests/test_ExecutionPolicy.cpp
+++ b/tests/test_ExecutionPolicy.cpp
@@ -11,21 +11,21 @@ TEST_CASE("ExecutionPolicy", "[language]")
   ExecutionPolicy exec_policy;
   SECTION("no jump")
   {
-    exec_policy = ExecutionPolicy{ExecutionPolicy::JumpType::no_jump, nullptr};
+    exec_policy = ExecutionPolicy{ExecutionPolicy{}, ExecutionPolicy::JumpType::no_jump};
     REQUIRE(exec_policy.exec() == true);
     REQUIRE(exec_policy.jumpType() == ExecutionPolicy::JumpType::no_jump);
   }
 
   SECTION("break jump")
   {
-    exec_policy = ExecutionPolicy{ExecutionPolicy::JumpType::break_jump, nullptr};
+    exec_policy = ExecutionPolicy{ExecutionPolicy{}, ExecutionPolicy::JumpType::break_jump};
     REQUIRE(exec_policy.exec() == false);
     REQUIRE(exec_policy.jumpType() == ExecutionPolicy::JumpType::break_jump);
   }
 
   SECTION("continue jump")
   {
-    exec_policy = ExecutionPolicy{ExecutionPolicy::JumpType::continue_jump, nullptr};
+    exec_policy = ExecutionPolicy{ExecutionPolicy{}, ExecutionPolicy::JumpType::continue_jump};
     REQUIRE(exec_policy.exec() == false);
     REQUIRE(exec_policy.jumpType() == ExecutionPolicy::JumpType::continue_jump);
   }
-- 
GitLab