fix type coercion in cascaded comparisons
authorStefan Behnel <stefan_ml@behnel.de>
Fri, 9 Nov 2012 21:52:09 +0000 (22:52 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Fri, 9 Nov 2012 21:52:09 +0000 (22:52 +0100)
Cython/Compiler/ExprNodes.py
tests/run/inop.pyx

index 87a531a..ef8b800 100755 (executable)
@@ -8919,9 +8919,10 @@ class PrimaryCmpNode(ExprNode, CmpNode):
     #  Instead, we override all the framework methods
     #  which use it.
 
-    child_attrs = ['operand1', 'operand2', 'cascade']
+    child_attrs = ['operand1', 'operand2', 'coerced_operand2', 'cascade']
 
     cascade = None
+    coerced_operand2 = None
     is_memslice_nonecheck = False
 
     def infer_type(self, env):
@@ -8999,9 +9000,11 @@ class PrimaryCmpNode(ExprNode, CmpNode):
             self.coerce_operands_to(common_type, env)
 
         if self.cascade:
-            self.operand2 = self.cascade.optimise_comparison(
-                self.operand2.coerce_to_simple(env), env)
+            self.operand2 = self.operand2.coerce_to_simple(env)
             self.cascade.coerce_cascaded_operands_to_temp(env)
+            operand2 = self.cascade.optimise_comparison(self.operand2, env)
+            if operand2 is not self.operand2:
+                self.coerced_operand2 = operand2
         if self.is_python_result():
             self.type = PyrexTypes.py_object_type
         else:
@@ -9105,8 +9108,9 @@ class PrimaryCmpNode(ExprNode, CmpNode):
             self.generate_operation_code(code, self.result(),
                 self.operand1, self.operator, self.operand2)
             if self.cascade:
-                self.cascade.generate_evaluation_code(code,
-                    self.result(), self.operand2)
+                self.cascade.generate_evaluation_code(
+                    code, self.result(), self.coerced_operand2 or self.operand2,
+                    needs_evaluation=self.coerced_operand2 is not None)
             self.operand1.generate_disposal_code(code)
             self.operand1.free_temps(code)
             self.operand2.generate_disposal_code(code)
@@ -9141,9 +9145,10 @@ class CascadedCmpNode(Node, CmpNode):
     #  operand2      ExprNode
     #  cascade       CascadedCmpNode
 
-    child_attrs = ['operand2', 'cascade']
+    child_attrs = ['operand2', 'coerced_operand2', 'cascade']
 
     cascade = None
+    coerced_operand2 = None
     constant_result = constant_value_not_set # FIXME: where to calculate this?
 
     def infer_type(self, env):
@@ -9170,7 +9175,9 @@ class CascadedCmpNode(Node, CmpNode):
             if not operand1.type.is_pyobject:
                 operand1 = operand1.coerce_to_pyobject(env)
         if self.cascade:
-            self.operand2 = self.cascade.optimise_comparison(self.operand2, env)
+            operand2 = self.cascade.optimise_comparison(self.operand2, env)
+            if operand2 is not self.operand2:
+                self.coerced_operand2 = operand2
         return operand1
 
     def coerce_operands_to_pyobjects(self, env):
@@ -9186,18 +9193,24 @@ class CascadedCmpNode(Node, CmpNode):
             self.operand2 = self.operand2.coerce_to_simple(env)
             self.cascade.coerce_cascaded_operands_to_temp(env)
 
-    def generate_evaluation_code(self, code, result, operand1):
+    def generate_evaluation_code(self, code, result, operand1, needs_evaluation=False):
         if self.type.is_pyobject:
             code.putln("if (__Pyx_PyObject_IsTrue(%s)) {" % result)
             code.put_decref(result, self.type)
         else:
             code.putln("if (%s) {" % result)
+        if needs_evaluation:
+            operand1.generate_evaluation_code(code)
         self.operand2.generate_evaluation_code(code)
         self.generate_operation_code(code, result,
             operand1, self.operator, self.operand2)
         if self.cascade:
             self.cascade.generate_evaluation_code(
-                code, result, self.operand2)
+                code, result, self.coerced_operand2 or self.operand2,
+                needs_evaluation=self.coerced_operand2 is not None)
+        if needs_evaluation:
+            operand1.generate_disposal_code(code)
+            operand1.free_temps(code)
         # Cascaded cmp result is always temp
         self.operand2.generate_disposal_code(code)
         self.operand2.free_temps(code)
index da3ac79..481c60c 100644 (file)
@@ -376,3 +376,29 @@ def test_inop_cascaded(x):
     False
     """
     return 1 != x in [2]
+
+def test_inop_cascaded_one():
+    """
+    >>> test_inop_cascaded_one()
+    False
+    """
+    # copied from CPython's test_grammar.py
+    return 1 < 1 > 1 == 1 >= 1 <= 1 != 1 in 1 not in 1 is 1 is not 1
+
+def test_inop_cascaded_int_orig(int x):
+    """
+    >>> test_inop_cascaded_int_orig(1)
+    False
+    """
+    return 1 < 1 > 1 == 1 >= 1 <= 1 != x in 1 not in 1 is 1 is not 1
+
+def test_inop_cascaded_int(int x):
+    """
+    >>> test_inop_cascaded_int(1)
+    False
+    >>> test_inop_cascaded_int(2)
+    True
+    >>> test_inop_cascaded_int(3)
+    False
+    """
+    return 1 != x in [1,2]