improve type inference for string %/+/* operations and use more direct C-API calls...
authorStefan Behnel <stefan_ml@behnel.de>
Fri, 6 Dec 2013 14:22:52 +0000 (15:22 +0100)
committerStefan Behnel <stefan_ml@behnel.de>
Fri, 6 Dec 2013 14:22:52 +0000 (15:22 +0100)
Cython/Compiler/ExprNodes.py
Cython/Utility/ModuleSetupCode.c
tests/run/unicodemethods.pyx

index c13cac2..3fa4783 100644 (file)
@@ -28,8 +28,8 @@ import PyrexTypes
 from PyrexTypes import py_object_type, c_long_type, typecast, error_type, \
     unspecified_type
 import TypeSlots
-from Builtin import list_type, tuple_type, set_type, dict_type, \
-     unicode_type, str_type, bytes_type, bytearray_type, type_type
+from Builtin import list_type, tuple_type, set_type, dict_type, type_type, \
+     unicode_type, str_type, bytes_type, bytearray_type, basestring_type
 import Builtin
 import Symtab
 from Cython import Utils
@@ -8792,30 +8792,20 @@ class BinopNode(ExprNode):
                 type1 = Builtin.bytes_type
             elif type1.is_pyunicode_ptr:
                 type1 = Builtin.unicode_type
-            elif self.operator == '%' \
-                     and type1 in (Builtin.str_type, Builtin.unicode_type):
-                # note that  b'%s' % b'abc'  doesn't work in Py3
-                return type1
-            if type1.is_builtin_type:
-                if type1 is type2:
-                    if self.operator in '**%+|&^':
-                        # FIXME: at least these operators should be safe - others?
-                        return type1
-                elif self.operator == '*':
-                    if type1 in (Builtin.bytes_type, Builtin.str_type, Builtin.unicode_type):
-                        return type1
-                    # multiplication of containers/numbers with an
-                    # integer value always (?) returns the same type
-                    if type2.is_int:
-                        return type1
-            elif type2.is_builtin_type and type1.is_int and self.operator == '*':
-                # multiplication of containers/numbers with an
-                # integer value always (?) returns the same type
-                return type2
+            if type1.is_builtin_type or type2.is_builtin_type:
+                if type1 is type2 and self.operator in '**%+|&^':
+                    # FIXME: at least these operators should be safe - others?
+                    return type1
+                result_type = self.infer_builtin_types_operation(type1, type2)
+                if result_type is not None:
+                    return result_type
             return py_object_type
         else:
             return self.compute_c_result_type(type1, type2)
 
+    def infer_builtin_types_operation(self, type1, type2):
+        return None
+
     def nogil_check(self, env):
         if self.is_py_operation():
             self.gil_error()
@@ -9019,14 +9009,15 @@ class NumBinopNode(BinopNode):
         "%":        "PyNumber_Remainder",
         "**":       "PyNumber_Power"
     }
-    
+
     overflow_op_names = {
-       "+":  "add",
-       "-":  "sub",
-       "*":  "mul",
-       "<<":  "lshift",
+        "+":  "add",
+        "-":  "sub",
+        "*":  "mul",
+        "<<":  "lshift",
     }
 
+
 class IntBinopNode(NumBinopNode):
     #  Binary operation taking integer arguments.
 
@@ -9045,6 +9036,15 @@ class AddNode(NumBinopNode):
         else:
             return NumBinopNode.is_py_operation_types(self, type1, type2)
 
+    def infer_builtin_types_operation(self, type1, type2):
+        # b'abc' + 'abc' raises an exception in Py3,
+        # so we can safely infer the Py2 type for bytes here
+        string_types = [bytes_type, str_type, basestring_type, unicode_type]  # Py2.4 lacks tuple.index()
+        if type1 in string_types and type2 in string_types:
+            return string_types[max(string_types.index(type1),
+                                    string_types.index(type2))]
+        return None
+
     def compute_c_result_type(self, type1, type2):
         #print "AddNode.compute_c_result_type:", type1, self.operator, type2 ###
         if (type1.is_ptr or type1.is_array) and (type2.is_int or type2.is_enum):
@@ -9055,6 +9055,16 @@ class AddNode(NumBinopNode):
             return NumBinopNode.compute_c_result_type(
                 self, type1, type2)
 
+    def py_operation_function(self):
+        type1, type2 = self.operand1.type, self.operand2.type
+        if type1 is unicode_type or type2 is unicode_type:
+            if type1.is_builtin_type and type2.is_builtin_type:
+                if self.operand1.may_be_none() or self.operand2.may_be_none():
+                    return '__Pyx_PyUnicode_Concat'
+                else:
+                    return 'PyUnicode_Concat'
+        return super(AddNode, self).py_operation_function()
+
 
 class SubNode(NumBinopNode):
     #  '-' operator.
@@ -9073,12 +9083,28 @@ class MulNode(NumBinopNode):
     #  '*' operator.
 
     def is_py_operation_types(self, type1, type2):
-        if (type1.is_string and type2.is_int) \
-            or (type2.is_string and type1.is_int):
-                return 1
+        if ((type1.is_string and type2.is_int) or
+                (type2.is_string and type1.is_int)):
+            return 1
         else:
             return NumBinopNode.is_py_operation_types(self, type1, type2)
 
+    def infer_builtin_types_operation(self, type1, type2):
+        # let's assume that whatever builtin type you multiply a string with
+        # will either return a string of the same type or fail with an exception
+        string_types = (bytes_type, str_type, basestring_type, unicode_type)
+        if type1 in string_types and type2.is_builtin_type:
+            return type1
+        if type2 in string_types and type1.is_builtin_type:
+            return type2
+        # multiplication of containers/numbers with an integer value
+        # always (?) returns the same type
+        if type1.is_int:
+            return type2
+        if type2.is_int:
+            return type1
+        return None
+
 
 class DivNode(NumBinopNode):
     #  '/' or '//' operator.
@@ -9218,9 +9244,9 @@ class DivNode(NumBinopNode):
             return "(%s / %s)" % (op1, op2)
         else:
             return "__Pyx_div_%s(%s, %s)" % (
-                    self.type.specialization_name(),
-                    self.operand1.result(),
-                    self.operand2.result())
+                self.type.specialization_name(),
+                self.operand1.result(),
+                self.operand2.result())
 
 
 class ModNode(DivNode):
@@ -9228,8 +9254,25 @@ class ModNode(DivNode):
 
     def is_py_operation_types(self, type1, type2):
         return (type1.is_string
-            or type2.is_string
-            or NumBinopNode.is_py_operation_types(self, type1, type2))
+                or type2.is_string
+                or NumBinopNode.is_py_operation_types(self, type1, type2))
+
+    def infer_builtin_types_operation(self, type1, type2):
+        # b'%s' % xyz  raises an exception in Py3, so it's safe to infer the type for Py2
+        if type1 is unicode_type:
+            # None + xyz  may be implemented by RHS
+            if type2.is_builtin_type or not self.operand1.may_be_none():
+                return type1
+        elif type1 in (bytes_type, str_type, basestring_type):
+            if type2 is unicode_type:
+                return type2
+            elif type2.is_numeric:
+                return type1
+            elif type1 is bytes_type and not type2.is_builtin_type:
+                return None   # RHS might implement '% operator differently in Py3
+            else:
+                return basestring_type  # either str or unicode, can't tell
+        return None
 
     def zero_division_message(self):
         if self.type.is_int:
@@ -9275,6 +9318,15 @@ class ModNode(DivNode):
                     self.operand1.result(),
                     self.operand2.result())
 
+    def py_operation_function(self):
+        if self.operand1.type is unicode_type and self.operand2.type.is_builtin_type:
+            if self.operand1.may_be_none():
+                return '__Pyx_PyUnicode_Format'
+            else:
+                return 'PyUnicode_Format'
+        return super(ModNode, self).py_operation_function()
+
+
 class PowNode(NumBinopNode):
     #  '**' operator.
 
index 1fe66ea..3bb72cd 100644 (file)
   #define __Pyx_PyUnicode_READ(k, d, i)   ((k=k), (Py_UCS4)(((Py_UNICODE*)d)[i]))
 #endif
 
+#define __Pyx_PyUnicode_Format(a, b)  ((unlikely((a) == Py_None)) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b))
+#define __Pyx_PyUnicode_Concat(a, b)  ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ? \
+    PyNumber_Add(a, b) : PyUnicode_Concat(a, b))
+
 #if PY_MAJOR_VERSION >= 3
   #define PyBaseString_Type            PyUnicode_Type
   #define PyStringObject               PyUnicodeObject
index 72606e7..c466d54 100644 (file)
@@ -8,6 +8,9 @@ PY_VERSION = sys.version_info
 
 text = u'ab jd  sdflk as sa  sadas asdas fsdf '
 sep = u'  '
+format1 = u'abc%sdef'
+format2 = u'abc%sdef%sghi'
+unicode_sa = u'sa'
 
 multiline_text = u'''\
 ab jd
@@ -383,6 +386,122 @@ def in_test(unicode s, substring):
     return substring in s
 
 
+# unicode.__concat__(s, suffix)
+
+def concat_any(unicode s, suffix):
+    """
+    >>> concat(text, 'sa') == text + 'sa'  or  concat(text, 'sa')
+    True
+    >>> concat(None, 'sa')   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...
+    >>> concat(text, None)   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...
+    >>> class RAdd(object):
+    ...     def __radd__(self, other):
+    ...         return 123
+    >>> concat(None, 'sa')   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...
+    """
+    assert cython.typeof(s + suffix) == 'Python object', cython.typeof(s + suffix)
+    return s + suffix
+
+
+def concat(unicode s, str suffix):
+    """
+    >>> concat(text, 'sa') == text + 'sa'  or  concat(text, 'sa')
+    True
+    >>> concat(None, 'sa')   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...
+    >>> concat(text, None)   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...
+    >>> class RAdd(object):
+    ...     def __radd__(self, other):
+    ...         return 123
+    >>> concat(None, 'sa')   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...
+    """
+    assert cython.typeof(s + object()) == 'Python object', cython.typeof(s + object())
+    assert cython.typeof(s + suffix) == 'unicode object', cython.typeof(s + suffix)
+    return s + suffix
+
+
+def concat_literal_str(str suffix):
+    """
+    >>> concat_literal_str('sa') == 'abcsa'  or  concat_literal_str('sa')
+    True
+    >>> concat_literal_str(None)  # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...NoneType...
+    """
+    assert cython.typeof(u'abc' + object()) == 'Python object', cython.typeof(u'abc' + object())
+    assert cython.typeof(u'abc' + suffix) == 'unicode object', cython.typeof(u'abc' + suffix)
+    return u'abc' + suffix
+
+
+def concat_literal_unicode(unicode suffix):
+    """
+    >>> concat_literal_unicode(unicode_sa) == 'abcsa'  or  concat_literal_unicode(unicode_sa)
+    True
+    >>> concat_literal_unicode(None)  # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...NoneType...
+    """
+    assert cython.typeof(u'abc' + suffix) == 'unicode object', cython.typeof(u'abc' + suffix)
+    return u'abc' + suffix
+
+
+# unicode.__mod__(format, values)
+
+def mod_format(unicode s, values):
+    """
+    >>> mod_format(format1, 'sa') == 'abcsadef'  or  mod_format(format1, 'sa')
+    True
+    >>> mod_format(format2, ('XYZ', 'ABC')) == 'abcXYZdefABCghi'  or  mod_format(format2, ('XYZ', 'ABC'))
+    True
+    >>> mod_format(None, 'sa')   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: unsupported operand type(s) for %: 'NoneType' and 'str'
+    >>> class RMod(object):
+    ...     def __rmod__(self, other):
+    ...         return 123
+    >>> mod_format(None, RMod())
+    123
+    """
+    assert cython.typeof(s % values) == 'Python object', cython.typeof(s % values)
+    return s % values
+
+
+def mod_format_literal(values):
+    """
+    >>> mod_format_literal('sa') == 'abcsadef'  or  mod_format(format1, 'sa')
+    True
+    >>> mod_format_literal(('sa',)) == 'abcsadef'  or  mod_format(format1, ('sa',))
+    True
+    >>> mod_format_literal(['sa']) == "abc['sa']def"  or  mod_format(format1, ['sa'])
+    True
+    """
+    assert cython.typeof(u'abc%sdef' % values) == 'unicode object', cython.typeof(u'abc%sdef' % values)
+    return u'abc%sdef' % values
+
+
+def mod_format_tuple(*values):
+    """
+    >>> mod_format_tuple('sa') == 'abcsadef'  or  mod_format(format1, 'sa')
+    True
+    >>> mod_format_tuple()
+    Traceback (most recent call last):
+    TypeError: not enough arguments for format string
+    """
+    assert cython.typeof(u'abc%sdef' % values) == 'unicode object', cython.typeof(u'abc%sdef' % values)
+    return u'abc%sdef' % values
+
+
 # unicode.find(s, sub, [start, [end]])
 
 @cython.test_fail_if_path_exists(