make AnalyseDeclarationsTransform inherit from EnvTransform to fix inconsistencies...
authorStefan Behnel <stefan_ml@behnel.de>
Tue, 1 Jan 2013 21:14:32 +0000 (22:14 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Tue, 1 Jan 2013 21:14:32 +0000 (22:14 +0100)
Cython/Compiler/ParseTreeTransforms.pxd
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Symtab.py
Cython/Compiler/Visitor.py
tests/run/closure_defaultargs.pyx [deleted file]
tests/run/cyfunction_defaults.pyx

index d88e4c5..58d4795 100644 (file)
@@ -34,10 +34,10 @@ cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs
 #class WithTransform(CythonTransform, SkipDeclarations):
 #class DecoratorTransform(CythonTransform, SkipDeclarations):
 
-#class AnalyseDeclarationsTransform(CythonTransform):
+#class AnalyseDeclarationsTransform(EnvTransform):
 
 cdef class AnalyseExpressionsTransform(CythonTransform):
-    cdef list env_stack
+    pass
 
 cdef class ExpandInplaceOperators(EnvTransform):
     pass
index 4bcfb30..7c849b2 100644 (file)
@@ -1349,7 +1349,7 @@ class ForwardDeclareTypes(CythonTransform):
         return node
 
 
-class AnalyseDeclarationsTransform(CythonTransform):
+class AnalyseDeclarationsTransform(EnvTransform):
 
     basic_property = TreeFragment(u"""
 property NAME:
@@ -1398,11 +1398,12 @@ if VALUE is not None:
     in_lambda = 0
 
     def __call__(self, root):
-        self.env_stack = [root.scope]
         # needed to determine if a cdef var is declared after it's used.
         self.seen_vars_stack = []
         self.fused_error_funcs = set()
-        return super(AnalyseDeclarationsTransform, self).__call__(root)
+        super_class = super(AnalyseDeclarationsTransform, self)
+        self._super_visit_FuncDefNode = super_class.visit_FuncDefNode
+        return super_class.__call__(root)
 
     def visit_NameNode(self, node):
         self.seen_vars_stack[-1].add(node.name)
@@ -1410,24 +1411,18 @@ if VALUE is not None:
 
     def visit_ModuleNode(self, node):
         self.seen_vars_stack.append(set())
-        node.analyse_declarations(self.env_stack[-1])
+        node.analyse_declarations(self.current_env())
         self.visitchildren(node)
         self.seen_vars_stack.pop()
         return node
 
     def visit_LambdaNode(self, node):
         self.in_lambda += 1
-        node.analyse_declarations(self.env_stack[-1])
+        node.analyse_declarations(self.current_env())
         self.visitchildren(node)
         self.in_lambda -= 1
         return node
 
-    def visit_ClassDefNode(self, node):
-        self.env_stack.append(node.scope)
-        self.visitchildren(node)
-        self.env_stack.pop()
-        return node
-
     def visit_CClassDefNode(self, node):
         node = self.visit_ClassDefNode(node)
         if node.scope and node.scope.implemented:
@@ -1548,7 +1543,7 @@ if VALUE is not None:
         analyse its children (which are in turn normal functions). If we're a
         normal function, just analyse the body of the function.
         """
-        env = self.env_stack[-1]
+        env = self.current_env()
 
         self.seen_vars_stack.append(set())
         lenv = node.local_scope
@@ -1567,23 +1562,23 @@ if VALUE is not None:
         else:
             node.body.analyse_declarations(lenv)
             self._handle_nogil_cleanup(lenv, node)
-
-            self.env_stack.append(lenv)
-            self.visitchildren(node)
-            self.env_stack.pop()
+            self._super_visit_FuncDefNode(node)
 
         self.seen_vars_stack.pop()
         return node
 
     def visit_DefNode(self, node):
         node = self.visit_FuncDefNode(node)
-        env = self.env_stack[-1]
+        env = self.current_env()
         if (not isinstance(node, Nodes.DefNode) or
             node.fused_py_func or node.is_generator_body or
             not node.needs_assignment_synthesis(env)):
             return node
         return [node, self._synthesize_assignment(node, env)]
 
+    def visit_GeneratorBodyDefNode(self, node):
+        return self.visit_FuncDefNode(node)
+
     def _synthesize_assignment(self, node, env):
         # Synthesize assignment node and put it right after defnode
         genv = env
@@ -1622,15 +1617,15 @@ if VALUE is not None:
         return assmt
 
     def visit_ScopedExprNode(self, node):
-        env = self.env_stack[-1]
+        env = self.current_env()
         node.analyse_declarations(env)
         # the node may or may not have a local scope
         if node.has_local_scope:
             self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
-            self.env_stack.append(node.expr_scope)
+            self.enter_scope(node, node.expr_scope)
             node.analyse_scoped_declarations(node.expr_scope)
             self.visitchildren(node)
-            self.env_stack.pop()
+            self.exit_scope()
             self.seen_vars_stack.pop()
         else:
             node.analyse_scoped_declarations(env)
@@ -1639,7 +1634,7 @@ if VALUE is not None:
 
     def visit_TempResultFromStatNode(self, node):
         self.visitchildren(node)
-        node.analyse_declarations(self.env_stack[-1])
+        node.analyse_declarations(self.current_env())
         return node
 
     def visit_CppClassNode(self, node):
@@ -1804,18 +1799,15 @@ if VALUE is not None:
 class AnalyseExpressionsTransform(CythonTransform):
 
     def visit_ModuleNode(self, node):
-        self.env_stack = [node.scope]
         node.scope.infer_types()
         node.body.analyse_expressions(node.scope)
         self.visitchildren(node)
         return node
 
     def visit_FuncDefNode(self, node):
-        self.env_stack.append(node.local_scope)
         node.local_scope.infer_types()
         node.body.analyse_expressions(node.local_scope)
         self.visitchildren(node)
-        self.env_stack.pop()
         return node
 
     def visit_ScopedExprNode(self, node):
index 551db98..afa4da7 100644 (file)
@@ -201,7 +201,7 @@ class Entry(object):
         self.defining_entry = self
 
     def __repr__(self):
-        return "%s(name=%s, type=%s)" % (type(self).__name__, self.name, self.type)
+        return "%s(<%x>, name=%s, type=%s)" % (type(self).__name__, id(self), self.name, self.type)
 
     def redeclared(self, pos):
         error(pos, "'%s' does not match previous declaration" % self.name)
index ce576e3..d5810b1 100644 (file)
@@ -321,7 +321,8 @@ class EnvTransform(CythonTransform):
     This transformation keeps a stack of the environments.
     """
     def __call__(self, root):
-        self.env_stack = [(root, root.scope)]
+        self.env_stack = []
+        self.enter_scope(root, root.scope)
         return super(EnvTransform, self).__call__(root)
 
     def current_env(self):
@@ -333,10 +334,16 @@ class EnvTransform(CythonTransform):
     def global_scope(self):
         return self.current_env().global_scope()
 
+    def enter_scope(self, node, scope):
+        self.env_stack.append((node, scope))
+
+    def exit_scope(self):
+        self.env_stack.pop()
+
     def visit_FuncDefNode(self, node):
-        self.env_stack.append((node, node.local_scope))
+        self.enter_scope(node, node.local_scope)
         self.visitchildren(node)
-        self.env_stack.pop()
+        self.exit_scope()
         return node
 
     def visit_GeneratorBodyDefNode(self, node):
@@ -344,22 +351,22 @@ class EnvTransform(CythonTransform):
         return node
 
     def visit_ClassDefNode(self, node):
-        self.env_stack.append((node, node.scope))
+        self.enter_scope(node, node.scope)
         self.visitchildren(node)
-        self.env_stack.pop()
+        self.exit_scope()
         return node
 
     def visit_CStructOrUnionDefNode(self, node):
-        self.env_stack.append((node, node.scope))
+        self.enter_scope(node, node.scope)
         self.visitchildren(node)
-        self.env_stack.pop()
+        self.exit_scope()
         return node
 
     def visit_ScopedExprNode(self, node):
         if node.expr_scope:
-            self.env_stack.append((node, node.expr_scope))
+            self.enter_scope(node, node.expr_scope)
             self.visitchildren(node)
-            self.env_stack.pop()
+            self.exit_scope()
         else:
             self.visitchildren(node)
         return node
@@ -369,9 +376,9 @@ class EnvTransform(CythonTransform):
         if node.default:
             attrs = [ attr for attr in node.child_attrs if attr != 'default' ]
             self.visitchildren(node, attrs)
-            self.env_stack.append((node, self.current_env().outer_scope))
+            self.enter_scope(node, self.current_env().outer_scope)
             self.visitchildren(node, ('default',))
-            self.env_stack.pop()
+            self.exit_scope()
         else:
             self.visitchildren(node)
         return node
diff --git a/tests/run/closure_defaultargs.pyx b/tests/run/closure_defaultargs.pyx
deleted file mode 100644 (file)
index c705cbf..0000000
+++ /dev/null
@@ -1,21 +0,0 @@
-# mode: run
-# tag: closures
-
-cimport cython
-
-@cython.test_fail_if_path_exists(
-    '//NameNode[@entry.in_closure = True]',
-    '//NameNode[@entry.from_closure = True]')
-def test_func_default():
-    """
-    >>> func = test_func_default()
-    >>> func()
-    1
-    >>> func(2)
-    2
-    """
-    def default():
-        return 1
-    def func(arg=default()):
-        return arg
-    return func
index cb628e7..8c5cda3 100644 (file)
@@ -1,6 +1,6 @@
 # cython: binding=True
 # mode: run
-# tag: cyfunction
+# tag: cyfunction, closures
 
 cimport cython
 import sys
@@ -131,3 +131,59 @@ def test_dynamic_defaults_fused():
     for i, f in enumerate(funcs):
         print "i", i, "func result", f(1.0), "defaults", get_defaults(f)
 
+
+@cython.test_fail_if_path_exists(
+    '//NameNode[@entry.in_closure = True]',
+    '//NameNode[@entry.from_closure = True]')
+def test_func_default_inlined():
+    """
+    Make sure we don't accidentally generate a closure.
+
+    >>> func = test_func_default_inlined()
+    >>> func()
+    1
+    >>> func(2)
+    2
+    """
+    def default():
+        return 1
+    def func(arg=default()):
+        return arg
+    return func
+
+
+@cython.test_fail_if_path_exists(
+    '//NameNode[@entry.in_closure = True]',
+    '//NameNode[@entry.from_closure = True]')
+def test_func_default_scope():
+    """
+    Test that the default value expression is evaluated in the outer scope.
+
+    >>> func = test_func_default_scope()
+    3
+    >>> func()
+    [0, 1, 2, 3]
+    >>> func(2)
+    2
+    """
+    i = -1
+    def func(arg=[ i for i in range(4) ]):
+        return arg
+    print i  # list comps leak in Py2 mode => i == 3
+    return func
+
+
+def test_func_default_scope_local():
+    """
+    >>> func = test_func_default_scope_local()
+    -1
+    >>> func()
+    [0, 1, 2, 3]
+    >>> func(2)
+    2
+    """
+    i = -1
+    def func(arg=list(i for i in range(4))):
+        return arg
+    print i  # genexprs don't leak
+    return func