refactor comprehensions by removing separate target node (to simplify a future length...
authorStefan Behnel <stefan_ml@behnel.de>
Mon, 18 Mar 2013 20:59:49 +0000 (21:59 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Mon, 18 Mar 2013 20:59:49 +0000 (21:59 +0100)
--HG--
extra : rebase_source : 476b22eeaeaea1ff69ee8069328fb47ffe18ea20

Cython/Compiler/ExprNodes.py
Cython/Compiler/FlowControl.py
Cython/Compiler/Optimize.py
Cython/Compiler/Parsing.py

index 2e5c133..ec4d52b 100755 (executable)
@@ -6103,11 +6103,14 @@ class ScopedExprNode(ExprNode):
 
 
 class ComprehensionNode(ScopedExprNode):
-    subexprs = ["target"]
+    # A list/set/dict comprehension
+
     child_attrs = ["loop"]
 
+    is_temp = True
+
     def infer_type(self, env):
-        return self.target.infer_type(env)
+        return self.type
 
     def analyse_declarations(self, env):
         self.append.target = self # this is used in the PyList_Append of the inner loop
@@ -6117,8 +6120,6 @@ class ComprehensionNode(ScopedExprNode):
         self.loop.analyse_declarations(env)
 
     def analyse_types(self, env):
-        self.target = self.target.analyse_expressions(env)
-        self.type = self.target.type
         if not self.has_local_scope:
             self.loop = self.loop.analyse_expressions(env)
         return self
@@ -6131,13 +6132,23 @@ class ComprehensionNode(ScopedExprNode):
     def may_be_none(self):
         return False
 
-    def calculate_result_code(self):
-        return self.target.result()
-
     def generate_result_code(self, code):
         self.generate_operation_code(code)
 
     def generate_operation_code(self, code):
+        if self.type is Builtin.list_type:
+            create_code = 'PyList_New(0)'
+        elif self.type is Builtin.set_type:
+            create_code = 'PySet_New(NULL)'
+        elif self.type is Builtin.dict_type:
+            create_code = 'PyDict_New()'
+        else:
+            raise InternalError("illegal type for comprehension: %s" % self.type)
+        code.putln('%s = %s; %s' % (
+            self.result(), create_code,
+            code.error_goto_if_null(self.result(), self.pos)))
+
+        code.put_gotref(self.result())
         self.loop.generate_execution_code(code)
 
     def annotate(self, code):
@@ -6149,6 +6160,7 @@ class ComprehensionAppendNode(Node):
     # target must not be in child_attrs/subexprs
 
     child_attrs = ['expr']
+    target = None
 
     type = PyrexTypes.c_int_type
 
index 4a218d7..7c812cb 100644 (file)
@@ -1246,7 +1246,6 @@ class ControlFlowAnalysis(CythonTransform):
             self.env_stack.append(self.env)
             self.env = node.expr_scope
         # Skip append node here
-        self._visit(node.target)
         self._visit(node.loop)
         if node.expr_scope:
             self.env = self.env_stack.pop()
index 98aa743..ea11aa3 100644 (file)
@@ -1404,7 +1404,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         if len(pos_args) != 1:
             return node
         if isinstance(pos_args[0], ExprNodes.ComprehensionNode) \
-               and pos_args[0].target.type is Builtin.list_type:
+               and pos_args[0].type is Builtin.list_type:
             listcomp_node = pos_args[0]
             loop_node = listcomp_node.loop
         elif isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
@@ -1414,18 +1414,17 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             if yield_expression is None:
                 return node
 
-            target = ExprNodes.ListNode(node.pos, args = [])
             append_node = ExprNodes.ComprehensionAppendNode(
-                yield_expression.pos, expr = yield_expression,
-                target = ExprNodes.CloneNode(target))
+                yield_expression.pos, expr = yield_expression)
 
             Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
 
             listcomp_node = ExprNodes.ComprehensionNode(
-                gen_expr_node.pos, loop = loop_node, target = target,
+                gen_expr_node.pos, loop = loop_node,
                 append = append_node, type = Builtin.list_type,
                 expr_scope = gen_expr_node.expr_scope,
                 has_local_scope = True)
+            append_node.target = listcomp_node
         else:
             return node
 
@@ -1550,7 +1549,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         # the items into a list and then copy them into a tuple of the
         # final size.  This takes up to twice as much memory, but will
         # have to do until we have real support for genexps.
-        result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
+        result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
         if result is not node:
             return ExprNodes.AsTupleNode(node.pos, arg=result)
         return node
@@ -1558,14 +1557,14 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
     def _handle_simple_function_list(self, node, pos_args):
         if not pos_args:
             return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
-        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
+        return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
 
     def _handle_simple_function_set(self, node, pos_args):
         if not pos_args:
             return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
-        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
+        return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
 
-    def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
+    def _transform_list_set_genexpr(self, node, pos_args, target_type):
         """Replace set(genexpr) and list(genexpr) by a literal comprehension.
         """
         if len(pos_args) > 1:
@@ -1579,23 +1578,21 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         if yield_expression is None:
             return node
 
-        target_node = container_node_class(node.pos, args=[])
         append_node = ExprNodes.ComprehensionAppendNode(
             yield_expression.pos,
-            expr = yield_expression,
-            target = ExprNodes.CloneNode(target_node))
+            expr = yield_expression)
 
         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
 
-        setcomp = ExprNodes.ComprehensionNode(
+        comp = ExprNodes.ComprehensionNode(
             node.pos,
             has_local_scope = True,
             expr_scope = gen_expr_node.expr_scope,
             loop = loop_node,
             append = append_node,
-            target = target_node)
-        append_node.target = setcomp
-        return setcomp
+            type = target_type)
+        append_node.target = comp
+        return comp
 
     def _handle_simple_function_dict(self, node, pos_args):
         """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
@@ -1618,12 +1615,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         if len(yield_expression.args) != 2:
             return node
 
-        target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
         append_node = ExprNodes.DictComprehensionAppendNode(
             yield_expression.pos,
             key_expr = yield_expression.args[0],
-            value_expr = yield_expression.args[1],
-            target = ExprNodes.CloneNode(target_node))
+            value_expr = yield_expression.args[1])
 
         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
 
@@ -1633,7 +1628,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             expr_scope = gen_expr_node.expr_scope,
             loop = loop_node,
             append = append_node,
-            target = target_node)
+            type = Builtin.dict_type)
         append_node.target = dictcomp
         return dictcomp
 
@@ -3245,7 +3240,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
         self.visitchildren(node)
         if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
             # loop was pruned already => transform into literal
-            return node.target
+            if node.type is Builtin.list_type:
+                return ExprNodes.ListNode(node.pos, args=[])
+            elif node.type is Builtin.set_type:
+                return ExprNodes.SetNode(node.pos, args=[])
+            elif node.type is Builtin.dict_type:
+                return ExprNodes.DictNode(node.pos, key_value_pairs=[])
         return node
 
     def visit_ForInStatNode(self, node):
index 531dcd4..0f44050 100644 (file)
@@ -7,7 +7,8 @@
 import cython
 cython.declare(Nodes=object, ExprNodes=object, EncodedString=object,
                StringEncoding=object, lookup_unicodechar=object, re=object,
-               Future=object, Options=object, error=object, warning=object)
+               Future=object, Options=object, error=object, warning=object,
+               Builtin=object)
 
 import re
 from unicodedata import lookup as lookup_unicodechar
@@ -15,6 +16,7 @@ from unicodedata import lookup as lookup_unicodechar
 from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor
 import Nodes
 import ExprNodes
+import Builtin
 import StringEncoding
 from StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes
 from ModuleNode import ModuleNode
@@ -897,13 +899,11 @@ def p_list_maker(s):
         return ExprNodes.ListNode(pos, args = [])
     expr = p_test(s)
     if s.sy == 'for':
-        target = ExprNodes.ListNode(pos, args = [])
-        append = ExprNodes.ComprehensionAppendNode(
-            pos, expr=expr, target=ExprNodes.CloneNode(target))
+        append = ExprNodes.ComprehensionAppendNode(pos, expr=expr)
         loop = p_comp_for(s, append)
         s.expect(']')
         return ExprNodes.ComprehensionNode(
-            pos, loop=loop, append=append, target=target,
+            pos, loop=loop, append=append, type = Builtin.list_type,
             # list comprehensions leak their loop variable in Py2
             has_local_scope = s.context.language_level >= 3)
     else:
@@ -964,13 +964,12 @@ def p_dict_or_set_maker(s):
         return ExprNodes.SetNode(pos, args=values)
     elif s.sy == 'for':
         # set comprehension
-        target = ExprNodes.SetNode(pos, args=[])
         append = ExprNodes.ComprehensionAppendNode(
-            item.pos, expr=item, target=ExprNodes.CloneNode(target))
+            item.pos, expr=item)
         loop = p_comp_for(s, append)
         s.expect('}')
         return ExprNodes.ComprehensionNode(
-            pos, loop=loop, append=append, target=target)
+            pos, loop=loop, append=append, type=Builtin.set_type)
     elif s.sy == ':':
         # dict literal or comprehension
         key = item
@@ -978,14 +977,12 @@ def p_dict_or_set_maker(s):
         value = p_test(s)
         if s.sy == 'for':
             # dict comprehension
-            target = ExprNodes.DictNode(pos, key_value_pairs = [])
             append = ExprNodes.DictComprehensionAppendNode(
-                item.pos, key_expr=key, value_expr=value,
-                target=ExprNodes.CloneNode(target))
+                item.pos, key_expr=key, value_expr=value)
             loop = p_comp_for(s, append)
             s.expect('}')
             return ExprNodes.ComprehensionNode(
-                pos, loop=loop, append=append, target=target)
+                pos, loop=loop, append=append, type=Builtin.dict_type)
         else:
             # dict literal
             items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]