fix with-statement when context manager comes from an inlined def-function call
authorStefan Behnel <stefan_ml@behnel.de>
Mon, 31 Dec 2012 11:56:04 +0000 (12:56 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Mon, 31 Dec 2012 11:56:04 +0000 (12:56 +0100)
Cython/Compiler/Optimize.py
Cython/Compiler/Visitor.py
tests/run/closure_inlining.pyx

index e9f29ab..872ead0 100644 (file)
@@ -1668,7 +1668,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             return node
         return kwargs
 
-class InlineDefNodeCalls(Visitor.EnvTransform):
+
+class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
     visit_Node = Visitor.VisitorTransform.recurse_to_children
 
     def get_constant_value_node(self, name_node):
@@ -1696,7 +1697,7 @@ class InlineDefNodeCalls(Visitor.EnvTransform):
             node.pos, function_name=function_name,
             function=function, args=node.args)
         if inlined.can_be_inlined():
-            return inlined
+            return self.replace(node, inlined)
         return node
 
 
index e407095..1d7b556 100644 (file)
@@ -374,6 +374,42 @@ class EnvTransform(CythonTransform):
         return node
 
 
+class NodeRefCleanupMixin(object):
+    """
+    Clean up references to nodes that were replaced.
+
+    NOTE: this implementation assumes that the replacement is
+    done first, before hitting any further references during
+    normal tree traversal.  This needs to be arranged by calling
+    "self.visitchildren()" at a proper place in the transform
+    and by ordering the "child_attrs" of nodes appropriately.
+    """
+    def __init__(self, *args):
+        super(NodeRefCleanupMixin, self).__init__(*args)
+        self._replacements = {}
+
+    def visit_CloneNode(self, node):
+        arg = node.arg
+        if arg not in self._replacements:
+            self.visitchildren(node)
+            arg = node.arg
+        node.arg = self._replacements.get(arg, arg)
+        return node
+
+    def visit_ResultRefNode(self, node):
+        expr = node.expression
+        if expr is None or expr not in self._replacements:
+            self.visitchildren(node)
+            expr = node.expression
+        if expr is not None:
+            node.expression = self._replacements.get(expr, expr)
+        return node
+
+    def replace(self, node, replacement):
+        self._replacements[node] = replacement
+        return replacement
+
+
 class MethodDispatcherTransform(EnvTransform):
     """
     Base class for transformations that want to intercept on specific
index 8a0e387..7ea7718 100644 (file)
@@ -128,3 +128,22 @@ def test_redef(redefine):
     else:
         assert inner != inner2
     return inner()
+
+
+def test_with_statement():
+    """
+    >>> test_with_statement()
+    enter
+    running
+    exit
+    """
+    def make_context_manager():
+        class CM(object):
+            def __enter__(self):
+                print "enter"
+            def __exit__(self, *args):
+                print "exit"
+        return CM()
+
+    with make_context_manager():
+        print "running"