From db1e13e862893e49bd5d3fc6748015530955938e Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Tue, 1 Jan 2013 22:14:32 +0100 Subject: [PATCH] make AnalyseDeclarationsTransform inherit from EnvTransform to fix inconsistencies in scope tracking --- Cython/Compiler/ParseTreeTransforms.pxd | 4 +-- Cython/Compiler/ParseTreeTransforms.py | 40 +++++++++-------------- Cython/Compiler/Symtab.py | 2 +- Cython/Compiler/Visitor.py | 29 ++++++++++------- tests/run/closure_defaultargs.pyx | 21 ------------ tests/run/cyfunction_defaults.pyx | 58 ++++++++++++++++++++++++++++++++- 6 files changed, 94 insertions(+), 60 deletions(-) delete mode 100644 tests/run/closure_defaultargs.pyx diff --git a/Cython/Compiler/ParseTreeTransforms.pxd b/Cython/Compiler/ParseTreeTransforms.pxd index d88e4c5..58d4795 100644 --- a/Cython/Compiler/ParseTreeTransforms.pxd +++ b/Cython/Compiler/ParseTreeTransforms.pxd @@ -34,10 +34,10 @@ cdef map_starred_assignment(list lhs_targets, list starred_assignments, list lhs #class WithTransform(CythonTransform, SkipDeclarations): #class DecoratorTransform(CythonTransform, SkipDeclarations): -#class AnalyseDeclarationsTransform(CythonTransform): +#class AnalyseDeclarationsTransform(EnvTransform): cdef class AnalyseExpressionsTransform(CythonTransform): - cdef list env_stack + pass cdef class ExpandInplaceOperators(EnvTransform): pass diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 4bcfb30..7c849b2 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1349,7 +1349,7 @@ class ForwardDeclareTypes(CythonTransform): return node -class AnalyseDeclarationsTransform(CythonTransform): +class AnalyseDeclarationsTransform(EnvTransform): basic_property = TreeFragment(u""" property NAME: @@ -1398,11 +1398,12 @@ if VALUE is not None: in_lambda = 0 def __call__(self, root): - self.env_stack = [root.scope] # needed to determine if a cdef var is declared after it's used. self.seen_vars_stack = [] self.fused_error_funcs = set() - return super(AnalyseDeclarationsTransform, self).__call__(root) + super_class = super(AnalyseDeclarationsTransform, self) + self._super_visit_FuncDefNode = super_class.visit_FuncDefNode + return super_class.__call__(root) def visit_NameNode(self, node): self.seen_vars_stack[-1].add(node.name) @@ -1410,24 +1411,18 @@ if VALUE is not None: def visit_ModuleNode(self, node): self.seen_vars_stack.append(set()) - node.analyse_declarations(self.env_stack[-1]) + node.analyse_declarations(self.current_env()) self.visitchildren(node) self.seen_vars_stack.pop() return node def visit_LambdaNode(self, node): self.in_lambda += 1 - node.analyse_declarations(self.env_stack[-1]) + node.analyse_declarations(self.current_env()) self.visitchildren(node) self.in_lambda -= 1 return node - def visit_ClassDefNode(self, node): - self.env_stack.append(node.scope) - self.visitchildren(node) - self.env_stack.pop() - return node - def visit_CClassDefNode(self, node): node = self.visit_ClassDefNode(node) if node.scope and node.scope.implemented: @@ -1548,7 +1543,7 @@ if VALUE is not None: analyse its children (which are in turn normal functions). If we're a normal function, just analyse the body of the function. """ - env = self.env_stack[-1] + env = self.current_env() self.seen_vars_stack.append(set()) lenv = node.local_scope @@ -1567,23 +1562,23 @@ if VALUE is not None: else: node.body.analyse_declarations(lenv) self._handle_nogil_cleanup(lenv, node) - - self.env_stack.append(lenv) - self.visitchildren(node) - self.env_stack.pop() + self._super_visit_FuncDefNode(node) self.seen_vars_stack.pop() return node def visit_DefNode(self, node): node = self.visit_FuncDefNode(node) - env = self.env_stack[-1] + env = self.current_env() if (not isinstance(node, Nodes.DefNode) or node.fused_py_func or node.is_generator_body or not node.needs_assignment_synthesis(env)): return node return [node, self._synthesize_assignment(node, env)] + def visit_GeneratorBodyDefNode(self, node): + return self.visit_FuncDefNode(node) + def _synthesize_assignment(self, node, env): # Synthesize assignment node and put it right after defnode genv = env @@ -1622,15 +1617,15 @@ if VALUE is not None: return assmt def visit_ScopedExprNode(self, node): - env = self.env_stack[-1] + env = self.current_env() node.analyse_declarations(env) # the node may or may not have a local scope if node.has_local_scope: self.seen_vars_stack.append(set(self.seen_vars_stack[-1])) - self.env_stack.append(node.expr_scope) + self.enter_scope(node, node.expr_scope) node.analyse_scoped_declarations(node.expr_scope) self.visitchildren(node) - self.env_stack.pop() + self.exit_scope() self.seen_vars_stack.pop() else: node.analyse_scoped_declarations(env) @@ -1639,7 +1634,7 @@ if VALUE is not None: def visit_TempResultFromStatNode(self, node): self.visitchildren(node) - node.analyse_declarations(self.env_stack[-1]) + node.analyse_declarations(self.current_env()) return node def visit_CppClassNode(self, node): @@ -1804,18 +1799,15 @@ if VALUE is not None: class AnalyseExpressionsTransform(CythonTransform): def visit_ModuleNode(self, node): - self.env_stack = [node.scope] node.scope.infer_types() node.body.analyse_expressions(node.scope) self.visitchildren(node) return node def visit_FuncDefNode(self, node): - self.env_stack.append(node.local_scope) node.local_scope.infer_types() node.body.analyse_expressions(node.local_scope) self.visitchildren(node) - self.env_stack.pop() return node def visit_ScopedExprNode(self, node): diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 551db98..afa4da7 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -201,7 +201,7 @@ class Entry(object): self.defining_entry = self def __repr__(self): - return "%s(name=%s, type=%s)" % (type(self).__name__, self.name, self.type) + return "%s(<%x>, name=%s, type=%s)" % (type(self).__name__, id(self), self.name, self.type) def redeclared(self, pos): error(pos, "'%s' does not match previous declaration" % self.name) diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index ce576e3..d5810b1 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -321,7 +321,8 @@ class EnvTransform(CythonTransform): This transformation keeps a stack of the environments. """ def __call__(self, root): - self.env_stack = [(root, root.scope)] + self.env_stack = [] + self.enter_scope(root, root.scope) return super(EnvTransform, self).__call__(root) def current_env(self): @@ -333,10 +334,16 @@ class EnvTransform(CythonTransform): def global_scope(self): return self.current_env().global_scope() + def enter_scope(self, node, scope): + self.env_stack.append((node, scope)) + + def exit_scope(self): + self.env_stack.pop() + def visit_FuncDefNode(self, node): - self.env_stack.append((node, node.local_scope)) + self.enter_scope(node, node.local_scope) self.visitchildren(node) - self.env_stack.pop() + self.exit_scope() return node def visit_GeneratorBodyDefNode(self, node): @@ -344,22 +351,22 @@ class EnvTransform(CythonTransform): return node def visit_ClassDefNode(self, node): - self.env_stack.append((node, node.scope)) + self.enter_scope(node, node.scope) self.visitchildren(node) - self.env_stack.pop() + self.exit_scope() return node def visit_CStructOrUnionDefNode(self, node): - self.env_stack.append((node, node.scope)) + self.enter_scope(node, node.scope) self.visitchildren(node) - self.env_stack.pop() + self.exit_scope() return node def visit_ScopedExprNode(self, node): if node.expr_scope: - self.env_stack.append((node, node.expr_scope)) + self.enter_scope(node, node.expr_scope) self.visitchildren(node) - self.env_stack.pop() + self.exit_scope() else: self.visitchildren(node) return node @@ -369,9 +376,9 @@ class EnvTransform(CythonTransform): if node.default: attrs = [ attr for attr in node.child_attrs if attr != 'default' ] self.visitchildren(node, attrs) - self.env_stack.append((node, self.current_env().outer_scope)) + self.enter_scope(node, self.current_env().outer_scope) self.visitchildren(node, ('default',)) - self.env_stack.pop() + self.exit_scope() else: self.visitchildren(node) return node diff --git a/tests/run/closure_defaultargs.pyx b/tests/run/closure_defaultargs.pyx deleted file mode 100644 index c705cbf..0000000 --- a/tests/run/closure_defaultargs.pyx +++ /dev/null @@ -1,21 +0,0 @@ -# mode: run -# tag: closures - -cimport cython - -@cython.test_fail_if_path_exists( - '//NameNode[@entry.in_closure = True]', - '//NameNode[@entry.from_closure = True]') -def test_func_default(): - """ - >>> func = test_func_default() - >>> func() - 1 - >>> func(2) - 2 - """ - def default(): - return 1 - def func(arg=default()): - return arg - return func diff --git a/tests/run/cyfunction_defaults.pyx b/tests/run/cyfunction_defaults.pyx index cb628e7..8c5cda3 100644 --- a/tests/run/cyfunction_defaults.pyx +++ b/tests/run/cyfunction_defaults.pyx @@ -1,6 +1,6 @@ # cython: binding=True # mode: run -# tag: cyfunction +# tag: cyfunction, closures cimport cython import sys @@ -131,3 +131,59 @@ def test_dynamic_defaults_fused(): for i, f in enumerate(funcs): print "i", i, "func result", f(1.0), "defaults", get_defaults(f) + +@cython.test_fail_if_path_exists( + '//NameNode[@entry.in_closure = True]', + '//NameNode[@entry.from_closure = True]') +def test_func_default_inlined(): + """ + Make sure we don't accidentally generate a closure. + + >>> func = test_func_default_inlined() + >>> func() + 1 + >>> func(2) + 2 + """ + def default(): + return 1 + def func(arg=default()): + return arg + return func + + +@cython.test_fail_if_path_exists( + '//NameNode[@entry.in_closure = True]', + '//NameNode[@entry.from_closure = True]') +def test_func_default_scope(): + """ + Test that the default value expression is evaluated in the outer scope. + + >>> func = test_func_default_scope() + 3 + >>> func() + [0, 1, 2, 3] + >>> func(2) + 2 + """ + i = -1 + def func(arg=[ i for i in range(4) ]): + return arg + print i # list comps leak in Py2 mode => i == 3 + return func + + +def test_func_default_scope_local(): + """ + >>> func = test_func_default_scope_local() + -1 + >>> func() + [0, 1, 2, 3] + >>> func(2) + 2 + """ + i = -1 + def func(arg=list(i for i in range(4))): + return arg + print i # genexprs don't leak + return func -- 2.7.4