refactor constant string slicing and guard it against platform specific unicode strin...
authorStefan Behnel <stefan_ml@behnel.de>
Sat, 23 Feb 2013 13:54:41 +0000 (14:54 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Sat, 23 Feb 2013 13:54:41 +0000 (14:54 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
tests/run/constant_folding_cy.pyx

index 732afcf..9f73c59 100755 (executable)
@@ -1071,6 +1071,12 @@ class BytesNode(ConstNode):
     def calculate_constant_result(self):
         self.constant_result = self.value
 
+    def as_sliced_node(self, start, stop, step=None):
+        value = StringEncoding.BytesLiteral(self.value[start:stop:step])
+        value.encoding = self.value.encoding
+        return BytesNode(
+            self.pos, value=value, constant_result=value)
+
     def compile_time_value(self, denv):
         return self.value
 
@@ -1155,6 +1161,22 @@ class UnicodeNode(PyConstNode):
     def calculate_constant_result(self):
         self.constant_result = self.value
 
+    def as_sliced_node(self, start, stop, step=None):
+        if _string_contains_surrogates(self.value[:stop]):
+            # this is unsafe as it may give different results in different runtimes
+            return None
+        value = StringEncoding.EncodedString(self.value[start:stop:step])
+        value.encoding = self.value.encoding
+        if self.bytes_value is not None:
+            bytes_value = StringEncoding.BytesLiteral(
+                self.bytes_value[start:stop:step])
+            bytes_value.encoding = self.bytes_value.encoding
+        else:
+            bytes_value = None
+        return UnicodeNode(
+            self.pos, value=value, bytes_value=bytes_value,
+            constant_result=value)
+
     def coerce_to(self, dst_type, env):
         if dst_type is self.type:
             pass
@@ -1181,21 +1203,7 @@ class UnicodeNode(PyConstNode):
             ##     and (0xDC00 <= self.value[1] <= 0xDFFF))
 
     def contains_surrogates(self):
-        # Check if the unicode string contains surrogate code points
-        # on a CPython platform with wide (UCS-4) or narrow (UTF-16)
-        # Unicode, i.e. characters that would be spelled as two
-        # separate code units on a narrow platform.
-        for c in map(ord, self.value):
-            if c > 65535: # can only happen on wide platforms
-                return True
-            # We only look for the first code unit (D800-DBFF) of a
-            # surrogate pair - if we find one, the other one
-            # (DC00-DFFF) is likely there, too.  If we don't find it,
-            # any second code unit cannot make for a surrogate pair by
-            # itself.
-            if 0xD800 <= c <= 0xDBFF:
-                return True
-        return False
+        return _string_contains_surrogates(self.value)
 
     def generate_evaluation_code(self, code):
         self.result_code = code.get_py_string_const(self.value)
@@ -1223,6 +1231,21 @@ class StringNode(PyConstNode):
     def calculate_constant_result(self):
         self.constant_result = self.value
 
+    def as_sliced_node(self, start, stop, step=None):
+        value = type(self.value)(self.value[start:stop:step])
+        value.encoding = self.value.encoding
+        if self.unicode_value is not None:
+            if _string_contains_surrogates(self.unicode_value[:stop]):
+                # this is unsafe as it may give different results in different runtimes
+                return None
+            unicode_value = StringEncoding.EncodedString(
+                self.unicode_value[start:stop:step])
+        else:
+            unicode_value = None
+        return StringNode(
+            self.pos, value=value, unicode_value=unicode_value,
+            constant_result=value, is_identifier=self.is_identifier)
+
     def coerce_to(self, dst_type, env):
         if dst_type is not py_object_type and not str_type.subtype_of(dst_type):
 #            if dst_type is Builtin.bytes_type:
@@ -1257,6 +1280,26 @@ class IdentifierStringNode(StringNode):
     is_identifier = True
 
 
+def _string_contains_surrogates(ustring):
+    """
+    Check if the unicode string contains surrogate code points
+    on a CPython platform with wide (UCS-4) or narrow (UTF-16)
+    Unicode, i.e. characters that would be spelled as two
+    separate code units on a narrow platform.
+    """
+    for c in map(ord, ustring):
+        if c > 65535: # can only happen on wide platforms
+            return True
+            # We only look for the first code unit (D800-DBFF) of a
+        # surrogate pair - if we find one, the other one
+        # (DC00-DFFF) is likely there, too.  If we don't find it,
+        # any second code unit cannot make for a surrogate pair by
+        # itself.
+        if 0xD800 <= c <= 0xDBFF:
+            return True
+    return False
+
+
 class ImagNode(AtomicExprNode):
     #  Imaginary number literal
     #
index dd4b893..42a7863 100644 (file)
@@ -3206,18 +3206,9 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
                 base.args = base.args[start:stop]
                 return base
             elif base.is_string_literal:
-                value = type(base.value)(node.constant_result)
-                value.encoding = base.value.encoding
-                base.value = value
-                if isinstance(base, ExprNodes.StringNode):
-                    if base.unicode_value is not None:
-                        base.unicode_value = EncodedString(
-                            base.unicode_value[start:stop])
-                elif isinstance(base, ExprNodes.UnicodeNode):
-                    if base.bytes_value is not None:
-                        base.bytes_value = BytesLiteral(
-                            base.bytes_value[start:stop])
-                return base
+                base = base.as_sliced_node(start, stop)
+                if base is not None:
+                    return base
         return node
 
     def visit_ForInStatNode(self, node):
index 7c9e35c..3700a05 100644 (file)
@@ -7,6 +7,7 @@ cimport cython
 
 bstring = b'abc\xE9def'
 ustring = u'abc\xE9def'
+surrogates_ustring = u'abc\U00010000def'
 
 
 @cython.test_fail_if_path_exists(
@@ -53,3 +54,29 @@ def unicode_slicing2():
     str3 = u'abc\xE9def'[2:4]
 
     return str0, str1, str2, str3
+
+
+@cython.test_assert_path_exists(
+    "//SliceIndexNode",
+    )
+def unicode_slicing_unsafe_surrogates2():
+    """
+    >>> unicode_slicing_unsafe_surrogates2() == surrogates_ustring[2:]
+    True
+    """
+    ustring = u'abc\U00010000def'[2:]
+    return ustring
+
+
+@cython.test_fail_if_path_exists(
+    "//SliceIndexNode",
+    )
+def unicode_slicing_safe_surrogates2():
+    """
+    >>> unicode_slicing_safe_surrogates2() == surrogates_ustring[:2]
+    True
+    >>> print(unicode_slicing_safe_surrogates2())
+    ab
+    """
+    ustring = u'abc\U00010000def'[:2]
+    return ustring