properly propagate string comparison optimisation into cascaded comparisons
authorStefan Behnel <stefan_ml@behnel.de>
Thu, 9 Aug 2012 10:05:22 +0000 (12:05 +0200)
committerStefan Behnel <stefan_ml@behnel.de>
Thu, 9 Aug 2012 10:05:22 +0000 (12:05 +0200)
Cython/Compiler/ExprNodes.py
tests/run/string_comparison.pyx [new file with mode: 0644]

index f89939d..f5eceeb 100755 (executable)
@@ -8696,9 +8696,9 @@ class CmpNode(object):
             return (container_type.is_ptr or container_type.is_array) \
                 and not container_type.is_string
 
-    def find_special_bool_compare_function(self, env):
+    def find_special_bool_compare_function(self, env, operand1):
         if self.operator in ('==', '!='):
-            type1, type2 = self.operand1.type, self.operand2.type
+            type1, type2 = operand1.type, self.operand2.type
             if type1.is_pyobject and type2.is_pyobject:
                 if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type:
                     env.use_utility_code(UtilityCode.load_cached("UnicodeEquals", "StringTools.c"))
@@ -8901,7 +8901,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
                     self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
                 common_type = py_object_type
                 self.is_pycmp = True
-        elif self.find_special_bool_compare_function(env):
+        elif self.find_special_bool_compare_function(env, self.operand1):
             common_type = None # if coercion needed, the method call above has already done it
             self.is_pycmp = False # result is bint
             self.is_temp = True # must check for error return
@@ -8916,6 +8916,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
 
         if self.cascade:
             self.operand2 = self.operand2.coerce_to_simple(env)
+            self.cascade.optimise_comparison(env, self.operand2)
             self.cascade.coerce_cascaded_operands_to_temp(env)
         if self.is_python_result():
             self.type = PyrexTypes.py_object_type
@@ -9079,6 +9080,11 @@ class CascadedCmpNode(Node, CmpNode):
     def has_python_operands(self):
         return self.operand2.type.is_pyobject
 
+    def optimise_comparison(self, env, operand1):
+        self.find_special_bool_compare_function(env, operand1)
+        if self.cascade:
+            self.cascade.optimise_comparison(env, self.operand2)
+
     def coerce_operands_to_pyobjects(self, env):
         self.operand2 = self.operand2.coerce_to_pyobject(env)
         if self.operand2.type is dict_type and self.operator in ('in', 'not_in'):
diff --git a/tests/run/string_comparison.pyx b/tests/run/string_comparison.pyx
new file mode 100644 (file)
index 0000000..590e77b
--- /dev/null
@@ -0,0 +1,193 @@
+
+bstring1 = b"abcdefg"
+bstring2 = b"1234567"
+
+string1 = "abcdefg"
+string2 = "1234567"
+
+ustring1 = u"abcdefg"
+ustring2 = u"1234567"
+
+# unicode
+
+def unicode_eq(unicode s1, unicode s2):
+    """
+    >>> unicode_eq(ustring1, ustring1)
+    True
+    >>> unicode_eq(ustring1+ustring2, ustring1+ustring2)
+    True
+    >>> unicode_eq(ustring1, ustring2)
+    False
+    """
+    return s1 == s2
+
+def unicode_neq(unicode s1, unicode s2):
+    """
+    >>> unicode_neq(ustring1, ustring1)
+    False
+    >>> unicode_neq(ustring1+ustring2, ustring1+ustring2)
+    False
+    >>> unicode_neq(ustring1, ustring2)
+    True
+    """
+    return s1 != s2
+
+def unicode_literal_eq(unicode s):
+    """
+    >>> unicode_literal_eq(ustring1)
+    True
+    >>> unicode_literal_eq((ustring1+ustring2)[:len(ustring1)])
+    True
+    >>> unicode_literal_eq(ustring2)
+    False
+    """
+    return s == u"abcdefg"
+
+def unicode_literal_neq(unicode s):
+    """
+    >>> unicode_literal_neq(ustring1)
+    False
+    >>> unicode_literal_neq((ustring1+ustring2)[:len(ustring1)])
+    False
+    >>> unicode_literal_neq(ustring2)
+    True
+    """
+    return s != u"abcdefg"
+
+def unicode_cascade(unicode s1, unicode s2):
+    """
+    >>> unicode_cascade(ustring1, ustring1)
+    True
+    >>> unicode_cascade(ustring1, (ustring1+ustring2)[:len(ustring1)])
+    True
+    >>> unicode_cascade(ustring1, ustring2)
+    False
+    """
+    return s1 == s2 == u"abcdefg"
+
+''' # NOTE: currently crashes
+def unicode_cascade_untyped_end(unicode s1, unicode s2):
+    """
+    >>> unicode_cascade_untyped_end(ustring1, ustring1)
+    True
+    >>> unicode_cascade_untyped_end(ustring1, (ustring1+ustring2)[:len(ustring1)])
+    True
+    >>> unicode_cascade_untyped_end(ustring1, ustring2)
+    False
+    """
+    return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1
+'''
+
+# str
+
+def str_eq(str s1, str s2):
+    """
+    >>> str_eq(string1, string1)
+    True
+    >>> str_eq(string1+string2, string1+string2)
+    True
+    >>> str_eq(string1, string2)
+    False
+    """
+    return s1 == s2
+
+def str_neq(str s1, str s2):
+    """
+    >>> str_neq(string1, string1)
+    False
+    >>> str_neq(string1+string2, string1+string2)
+    False
+    >>> str_neq(string1, string2)
+    True
+    """
+    return s1 != s2
+
+def str_literal_eq(str s):
+    """
+    >>> str_literal_eq(string1)
+    True
+    >>> str_literal_eq((string1+string2)[:len(string1)])
+    True
+    >>> str_literal_eq(string2)
+    False
+    """
+    return s == "abcdefg"
+
+def str_literal_neq(str s):
+    """
+    >>> str_literal_neq(string1)
+    False
+    >>> str_literal_neq((string1+string2)[:len(string1)])
+    False
+    >>> str_literal_neq(string2)
+    True
+    """
+    return s != "abcdefg"
+
+def str_cascade(str s1, str s2):
+    """
+    >>> str_cascade(string1, string1)
+    True
+    >>> str_cascade(string1, (string1+string2)[:len(string1)])
+    True
+    >>> str_cascade(string1, string2)
+    False
+    """
+    return s1 == s2 == "abcdefg"
+
+# bytes
+
+def bytes_eq(bytes s1, bytes s2):
+    """
+    >>> bytes_eq(bstring1, bstring1)
+    True
+    >>> bytes_eq(bstring1+bstring2, bstring1+bstring2)
+    True
+    >>> bytes_eq(bstring1, bstring2)
+    False
+    """
+    return s1 == s2
+
+def bytes_neq(bytes s1, bytes s2):
+    """
+    >>> bytes_neq(bstring1, bstring1)
+    False
+    >>> bytes_neq(bstring1+bstring2, bstring1+bstring2)
+    False
+    >>> bytes_neq(bstring1, bstring2)
+    True
+    """
+    return s1 != s2
+
+def bytes_literal_eq(bytes s):
+    """
+    >>> bytes_literal_eq(bstring1)
+    True
+    >>> bytes_literal_eq((bstring1+bstring2)[:len(bstring1)])
+    True
+    >>> bytes_literal_eq(bstring2)
+    False
+    """
+    return s == b"abcdefg"
+
+def bytes_literal_neq(bytes s):
+    """
+    >>> bytes_literal_neq(bstring1)
+    False
+    >>> bytes_literal_neq((bstring1+bstring2)[:len(bstring1)])
+    False
+    >>> bytes_literal_neq(bstring2)
+    True
+    """
+    return s != b"abcdefg"
+
+def bytes_cascade(bytes s1, bytes s2):
+    """
+    >>> bytes_cascade(bstring1, bstring1)
+    True
+    >>> bytes_cascade(bstring1, (bstring1+bstring2)[:len(bstring1)])
+    True
+    >>> bytes_cascade(bstring1, bstring2)
+    False
+    """
+    return s1 == s2 == b"abcdefg"