constant fold some more special cases for list/tuple multiplication
authorStefan Behnel <stefan_ml@behnel.de>
Sun, 1 Dec 2013 15:31:55 +0000 (16:31 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Sun, 1 Dec 2013 15:31:55 +0000 (16:31 +0100)
Cython/Compiler/Optimize.py
tests/run/constant_folding.py

index 95893f1..3742088 100644 (file)
@@ -3265,25 +3265,22 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
         return new_node
 
     def visit_MulNode(self, node):
+        self._calculate_const(node)
         if isinstance(node.operand1, (ExprNodes.ListNode, ExprNodes.TupleNode)):
-            sequence_node = node.operand1
-            factor = node.operand2
-            self._calculate_const(factor)
-            if factor.constant_result != 1 and sequence_node.args:
-                sequence_node.mult_factor = factor
-            self.visitchildren(sequence_node)
-            return sequence_node
+            return self._calculate_constant_seq(node.operand1, node.operand2)
         if isinstance(node.operand1, ExprNodes.IntNode) and \
-               isinstance(node.operand2, (ExprNodes.ListNode, ExprNodes.TupleNode)):
-            sequence_node = node.operand2
-            factor = node.operand1
-            self._calculate_const(factor)
-            if factor.constant_result != 1 and sequence_node.args:
-                sequence_node.mult_factor = factor
-            self.visitchildren(sequence_node)
-            return sequence_node
+                isinstance(node.operand2, (ExprNodes.ListNode, ExprNodes.TupleNode)):
+            return self._calculate_constant_seq(node.operand2, node.operand1)
         return self.visit_BinopNode(node)
 
+    def _calculate_constant_seq(self, sequence_node, factor):
+        if factor.constant_result != 1 and sequence_node.args:
+            if isinstance(factor.constant_result, (int, long)) and factor.constant_result <= 0:
+                del sequence_node.args[:]
+            else:
+                sequence_node.mult_factor = factor
+        return sequence_node
+
     def visit_PrimaryCmpNode(self, node):
         self._calculate_const(node)
         if node.constant_result is not ExprNodes.not_a_constant:
index 0f2a136..1a65235 100644 (file)
@@ -281,7 +281,7 @@ def for_in_nested_listcomp():
 
 
 @cython.test_fail_if_path_exists(
-    "//ListNode//IntNode",
+    "//MulNode",
 )
 def mult_empty_list():
     """
@@ -289,3 +289,27 @@ def mult_empty_list():
     []
     """
     return 5 * [] * 100
+
+
+@cython.test_fail_if_path_exists(
+    "//MulNode",
+    "//ListNode//IntNode",
+)
+def neg_mult_list():
+    """
+    >>> neg_mult_list()
+    []
+    """
+    return -5 * [1, 2] * -100
+
+
+@cython.test_fail_if_path_exists(
+    "//MulNode",
+    "//ListNode//IntNode",
+)
+def zero_mult_list():
+    """
+    >>> zero_mult_list()
+    []
+    """
+    return 0 * [1, 2] * 0