From b2cb180f43a55c264f19ff6d1b0cca2cd9fd8ed7 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Mon, 18 Mar 2013 21:59:49 +0100 Subject: [PATCH] refactor comprehensions by removing separate target node (to simplify a future length-hint optimisation) --HG-- extra : rebase_source : 476b22eeaeaea1ff69ee8069328fb47ffe18ea20 --- Cython/Compiler/ExprNodes.py | 26 +++++++++++++++++++------- Cython/Compiler/FlowControl.py | 1 - Cython/Compiler/Optimize.py | 42 +++++++++++++++++++++--------------------- Cython/Compiler/Parsing.py | 21 +++++++++------------ 4 files changed, 49 insertions(+), 41 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 2e5c133..ec4d52b 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -6103,11 +6103,14 @@ class ScopedExprNode(ExprNode): class ComprehensionNode(ScopedExprNode): - subexprs = ["target"] + # A list/set/dict comprehension + child_attrs = ["loop"] + is_temp = True + def infer_type(self, env): - return self.target.infer_type(env) + return self.type def analyse_declarations(self, env): self.append.target = self # this is used in the PyList_Append of the inner loop @@ -6117,8 +6120,6 @@ class ComprehensionNode(ScopedExprNode): self.loop.analyse_declarations(env) def analyse_types(self, env): - self.target = self.target.analyse_expressions(env) - self.type = self.target.type if not self.has_local_scope: self.loop = self.loop.analyse_expressions(env) return self @@ -6131,13 +6132,23 @@ class ComprehensionNode(ScopedExprNode): def may_be_none(self): return False - def calculate_result_code(self): - return self.target.result() - def generate_result_code(self, code): self.generate_operation_code(code) def generate_operation_code(self, code): + if self.type is Builtin.list_type: + create_code = 'PyList_New(0)' + elif self.type is Builtin.set_type: + create_code = 'PySet_New(NULL)' + elif self.type is Builtin.dict_type: + create_code = 'PyDict_New()' + else: + raise InternalError("illegal type for comprehension: %s" % self.type) + code.putln('%s = %s; %s' % ( + self.result(), create_code, + code.error_goto_if_null(self.result(), self.pos))) + + code.put_gotref(self.result()) self.loop.generate_execution_code(code) def annotate(self, code): @@ -6149,6 +6160,7 @@ class ComprehensionAppendNode(Node): # target must not be in child_attrs/subexprs child_attrs = ['expr'] + target = None type = PyrexTypes.c_int_type diff --git a/Cython/Compiler/FlowControl.py b/Cython/Compiler/FlowControl.py index 4a218d7..7c812cb 100644 --- a/Cython/Compiler/FlowControl.py +++ b/Cython/Compiler/FlowControl.py @@ -1246,7 +1246,6 @@ class ControlFlowAnalysis(CythonTransform): self.env_stack.append(self.env) self.env = node.expr_scope # Skip append node here - self._visit(node.target) self._visit(node.loop) if node.expr_scope: self.env = self.env_stack.pop() diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 98aa743..ea11aa3 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1404,7 +1404,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if len(pos_args) != 1: return node if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \ - and pos_args[0].target.type is Builtin.list_type: + and pos_args[0].type is Builtin.list_type: listcomp_node = pos_args[0] loop_node = listcomp_node.loop elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): @@ -1414,18 +1414,17 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if yield_expression is None: return node - target = ExprNodes.ListNode(node.pos, args = []) append_node = ExprNodes.ComprehensionAppendNode( - yield_expression.pos, expr = yield_expression, - target = ExprNodes.CloneNode(target)) + yield_expression.pos, expr = yield_expression) Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) listcomp_node = ExprNodes.ComprehensionNode( - gen_expr_node.pos, loop = loop_node, target = target, + gen_expr_node.pos, loop = loop_node, append = append_node, type = Builtin.list_type, expr_scope = gen_expr_node.expr_scope, has_local_scope = True) + append_node.target = listcomp_node else: return node @@ -1550,7 +1549,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): # the items into a list and then copy them into a tuple of the # final size. This takes up to twice as much memory, but will # have to do until we have real support for genexps. - result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode) + result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) if result is not node: return ExprNodes.AsTupleNode(node.pos, arg=result) return node @@ -1558,14 +1557,14 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): def _handle_simple_function_list(self, node, pos_args): if not pos_args: return ExprNodes.ListNode(node.pos, args=[], constant_result=[]) - return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode) + return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) def _handle_simple_function_set(self, node, pos_args): if not pos_args: return ExprNodes.SetNode(node.pos, args=[], constant_result=set()) - return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode) + return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type) - def _transform_list_set_genexpr(self, node, pos_args, container_node_class): + def _transform_list_set_genexpr(self, node, pos_args, target_type): """Replace set(genexpr) and list(genexpr) by a literal comprehension. """ if len(pos_args) > 1: @@ -1579,23 +1578,21 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if yield_expression is None: return node - target_node = container_node_class(node.pos, args=[]) append_node = ExprNodes.ComprehensionAppendNode( yield_expression.pos, - expr = yield_expression, - target = ExprNodes.CloneNode(target_node)) + expr = yield_expression) Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) - setcomp = ExprNodes.ComprehensionNode( + comp = ExprNodes.ComprehensionNode( node.pos, has_local_scope = True, expr_scope = gen_expr_node.expr_scope, loop = loop_node, append = append_node, - target = target_node) - append_node.target = setcomp - return setcomp + type = target_type) + append_node.target = comp + return comp def _handle_simple_function_dict(self, node, pos_args): """Replace dict( (a,b) for ... ) by a literal { a:b for ... }. @@ -1618,12 +1615,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if len(yield_expression.args) != 2: return node - target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[]) append_node = ExprNodes.DictComprehensionAppendNode( yield_expression.pos, key_expr = yield_expression.args[0], - value_expr = yield_expression.args[1], - target = ExprNodes.CloneNode(target_node)) + value_expr = yield_expression.args[1]) Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node) @@ -1633,7 +1628,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): expr_scope = gen_expr_node.expr_scope, loop = loop_node, append = append_node, - target = target_node) + type = Builtin.dict_type) append_node.target = dictcomp return dictcomp @@ -3245,7 +3240,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): self.visitchildren(node) if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats: # loop was pruned already => transform into literal - return node.target + if node.type is Builtin.list_type: + return ExprNodes.ListNode(node.pos, args=[]) + elif node.type is Builtin.set_type: + return ExprNodes.SetNode(node.pos, args=[]) + elif node.type is Builtin.dict_type: + return ExprNodes.DictNode(node.pos, key_value_pairs=[]) return node def visit_ForInStatNode(self, node): diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 531dcd4..0f44050 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -7,7 +7,8 @@ import cython cython.declare(Nodes=object, ExprNodes=object, EncodedString=object, StringEncoding=object, lookup_unicodechar=object, re=object, - Future=object, Options=object, error=object, warning=object) + Future=object, Options=object, error=object, warning=object, + Builtin=object) import re from unicodedata import lookup as lookup_unicodechar @@ -15,6 +16,7 @@ from unicodedata import lookup as lookup_unicodechar from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor import Nodes import ExprNodes +import Builtin import StringEncoding from StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes from ModuleNode import ModuleNode @@ -897,13 +899,11 @@ def p_list_maker(s): return ExprNodes.ListNode(pos, args = []) expr = p_test(s) if s.sy == 'for': - target = ExprNodes.ListNode(pos, args = []) - append = ExprNodes.ComprehensionAppendNode( - pos, expr=expr, target=ExprNodes.CloneNode(target)) + append = ExprNodes.ComprehensionAppendNode(pos, expr=expr) loop = p_comp_for(s, append) s.expect(']') return ExprNodes.ComprehensionNode( - pos, loop=loop, append=append, target=target, + pos, loop=loop, append=append, type = Builtin.list_type, # list comprehensions leak their loop variable in Py2 has_local_scope = s.context.language_level >= 3) else: @@ -964,13 +964,12 @@ def p_dict_or_set_maker(s): return ExprNodes.SetNode(pos, args=values) elif s.sy == 'for': # set comprehension - target = ExprNodes.SetNode(pos, args=[]) append = ExprNodes.ComprehensionAppendNode( - item.pos, expr=item, target=ExprNodes.CloneNode(target)) + item.pos, expr=item) loop = p_comp_for(s, append) s.expect('}') return ExprNodes.ComprehensionNode( - pos, loop=loop, append=append, target=target) + pos, loop=loop, append=append, type=Builtin.set_type) elif s.sy == ':': # dict literal or comprehension key = item @@ -978,14 +977,12 @@ def p_dict_or_set_maker(s): value = p_test(s) if s.sy == 'for': # dict comprehension - target = ExprNodes.DictNode(pos, key_value_pairs = []) append = ExprNodes.DictComprehensionAppendNode( - item.pos, key_expr=key, value_expr=value, - target=ExprNodes.CloneNode(target)) + item.pos, key_expr=key, value_expr=value) loop = p_comp_for(s, append) s.expect('}') return ExprNodes.ComprehensionNode( - pos, loop=loop, append=append, target=target) + pos, loop=loop, append=append, type=Builtin.dict_type) else: # dict literal items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)] -- 2.7.4