fix compiler crash for generator expressions with a constant False condition
authorStefan Behnel <stefan_ml@behnel.de>
Fri, 9 Nov 2012 20:26:24 +0000 (21:26 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Fri, 9 Nov 2012 20:26:24 +0000 (21:26 +0100)
--HG--
extra : transplant_source : C%C7%2Ak%C4%89%DA%C1f%85%86%0D%9E%7F_%B4%17%D2t%40

Cython/Compiler/Optimize.py
tests/run/generator_expressions.pyx
tests/run/generator_expressions_nested.pyx [new file with mode: 0644]

index b57c983..de842e9 100644 (file)
@@ -3117,11 +3117,30 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
                 break
             else:
                 assert condition_result == False
+                # prevent killing generators, but simplify them as much as possible
+                yield_expr = self._find_genexpr_yield(if_clause.body)
+                if yield_expr is not None:
+                    if_clause.condition = ExprNodes.BoolNode(if_clause.condition.pos, value=False)
+                    yield_expr.arg = ExprNodes.NoneNode(yield_expr.arg.pos)
+                    if_clauses.append(if_clause)
+                else:
+                    # False clauses outside of generators can safely be deleted
+                    pass
         if not if_clauses:
             return node.else_clause
         node.if_clauses = if_clauses
         return node
 
+    def _find_genexpr_yield(self, node):
+        body_node_types = (Nodes.ForInStatNode, Nodes.IfStatNode)
+        while isinstance(node, body_node_types):
+            node = node.body
+        if isinstance(node, Nodes.ExprStatNode):
+            node = node.expr
+            if isinstance(node, ExprNodes.YieldExprNode):
+                return node
+        return None
+
     # in the future, other nodes can have their own handler method here
     # that can replace them with a constant result node
 
index f525da7..e9a8c13 100644 (file)
@@ -21,6 +21,16 @@ def genexpr_if():
     assert x == 'abc' # don't leak
     return result
 
+def genexpr_if_false():
+    """
+    >>> genexpr_if_false()
+    []
+    """
+    x = 'abc'
+    result = list( x*2 for x in range(5) if False )
+    assert x == 'abc' # don't leak
+    return result
+
 def genexpr_with_lambda():
     """
     >>> genexpr_with_lambda()
diff --git a/tests/run/generator_expressions_nested.pyx b/tests/run/generator_expressions_nested.pyx
new file mode 100644 (file)
index 0000000..8fb2619
--- /dev/null
@@ -0,0 +1,74 @@
+# mode: run
+# cython: language_level=3
+
+"""
+Adapted from CPython's test_grammar.py
+"""
+
+def genexpr_simple():
+    """
+    >>> sum([ x**2 for x in range(10) ])
+    285
+    >>> sum(genexpr_simple())
+    285
+    """
+    return (x**2 for x in range(10))
+
+def genexpr_conditional():
+    """
+    >>> sum([ x*x for x in range(10) if x%2 ])
+    165
+    >>> sum(genexpr_conditional())
+    165
+    """
+    return (x*x for x in range(10) if x%2)
+
+def genexpr_nested2():
+    """
+    >>> sum([x for x in range(10)])
+    45
+    >>> sum(genexpr_nested2())
+    45
+    """
+    return (x for x in (y for y in range(10)))
+
+def genexpr_nested3():
+    """
+    >>> sum([x for x in range(10)])
+    45
+    >>> sum(genexpr_nested3())
+    45
+    """
+    return (x for x in (y for y in (z for z in range(10))))
+
+def genexpr_nested_listcomp():
+    """
+    >>> sum([x for x in range(10)])
+    45
+    >>> sum(genexpr_nested_listcomp())
+    45
+    """
+    return (x for x in [y for y in (z for z in range(10))])
+
+def genexpr_nested_conditional():
+    """
+    >>> sum([ x for x in [y for y in [z for z in range(10) if True]] if True ])
+    45
+    >>> sum(genexpr_nested_conditional())
+    45
+    """
+    return (x for x in (y for y in (z for z in range(10) if True)) if True)
+
+def genexpr_nested2_conditional_empty():
+    """
+    >>> sum(genexpr_nested2_conditional_empty())
+    0
+    """
+    return (y for y in (z for z in range(10) if True) if False)
+
+def genexpr_nested3_conditional_empty():
+    """
+    >>> sum(genexpr_nested3_conditional_empty())
+    0
+    """
+    return (x for x in (y for y in (z for z in range(10) if True) if False) if True)