From 1da05d6e1e930b365491eb81c1b8b6bd6b8f158d Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Fri, 6 Dec 2013 15:22:52 +0100 Subject: [PATCH] improve type inference for string %/+/* operations and use more direct C-API calls for these unicode operations --- Cython/Compiler/ExprNodes.py | 122 ++++++++++++++++++++++++++++----------- Cython/Utility/ModuleSetupCode.c | 4 ++ tests/run/unicodemethods.pyx | 119 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 210 insertions(+), 35 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index c13cac2..3fa4783 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -28,8 +28,8 @@ import PyrexTypes from PyrexTypes import py_object_type, c_long_type, typecast, error_type, \ unspecified_type import TypeSlots -from Builtin import list_type, tuple_type, set_type, dict_type, \ - unicode_type, str_type, bytes_type, bytearray_type, type_type +from Builtin import list_type, tuple_type, set_type, dict_type, type_type, \ + unicode_type, str_type, bytes_type, bytearray_type, basestring_type import Builtin import Symtab from Cython import Utils @@ -8792,30 +8792,20 @@ class BinopNode(ExprNode): type1 = Builtin.bytes_type elif type1.is_pyunicode_ptr: type1 = Builtin.unicode_type - elif self.operator == '%' \ - and type1 in (Builtin.str_type, Builtin.unicode_type): - # note that b'%s' % b'abc' doesn't work in Py3 - return type1 - if type1.is_builtin_type: - if type1 is type2: - if self.operator in '**%+|&^': - # FIXME: at least these operators should be safe - others? - return type1 - elif self.operator == '*': - if type1 in (Builtin.bytes_type, Builtin.str_type, Builtin.unicode_type): - return type1 - # multiplication of containers/numbers with an - # integer value always (?) 
returns the same type - if type2.is_int: - return type1 - elif type2.is_builtin_type and type1.is_int and self.operator == '*': - # multiplication of containers/numbers with an - # integer value always (?) returns the same type - return type2 + if type1.is_builtin_type or type2.is_builtin_type: + if type1 is type2 and self.operator in '**%+|&^': + # FIXME: at least these operators should be safe - others? + return type1 + result_type = self.infer_builtin_types_operation(type1, type2) + if result_type is not None: + return result_type return py_object_type else: return self.compute_c_result_type(type1, type2) + def infer_builtin_types_operation(self, type1, type2): + return None + def nogil_check(self, env): if self.is_py_operation(): self.gil_error() @@ -9019,14 +9009,15 @@ class NumBinopNode(BinopNode): "%": "PyNumber_Remainder", "**": "PyNumber_Power" } - + overflow_op_names = { - "+": "add", - "-": "sub", - "*": "mul", - "<<": "lshift", + "+": "add", + "-": "sub", + "*": "mul", + "<<": "lshift", } + class IntBinopNode(NumBinopNode): # Binary operation taking integer arguments. 
@@ -9045,6 +9036,15 @@ class AddNode(NumBinopNode): else: return NumBinopNode.is_py_operation_types(self, type1, type2) + def infer_builtin_types_operation(self, type1, type2): + # b'abc' + 'abc' raises an exception in Py3, + # so we can safely infer the Py2 type for bytes here + string_types = [bytes_type, str_type, basestring_type, unicode_type] # Py2.4 lacks tuple.index() + if type1 in string_types and type2 in string_types: + return string_types[max(string_types.index(type1), + string_types.index(type2))] + return None + def compute_c_result_type(self, type1, type2): #print "AddNode.compute_c_result_type:", type1, self.operator, type2 ### if (type1.is_ptr or type1.is_array) and (type2.is_int or type2.is_enum): @@ -9055,6 +9055,16 @@ class AddNode(NumBinopNode): return NumBinopNode.compute_c_result_type( self, type1, type2) + def py_operation_function(self): + type1, type2 = self.operand1.type, self.operand2.type + if type1 is unicode_type or type2 is unicode_type: + if type1.is_builtin_type and type2.is_builtin_type: + if self.operand1.may_be_none() or self.operand2.may_be_none(): + return '__Pyx_PyUnicode_Concat' + else: + return 'PyUnicode_Concat' + return super(AddNode, self).py_operation_function() + class SubNode(NumBinopNode): # '-' operator. @@ -9073,12 +9083,28 @@ class MulNode(NumBinopNode): # '*' operator. 
 def is_py_operation_types(self, type1, type2): - if (type1.is_string and type2.is_int) \ - or (type2.is_string and type1.is_int): - return 1 + if ((type1.is_string and type2.is_int) or + (type2.is_string and type1.is_int)): + return 1 else: return NumBinopNode.is_py_operation_types(self, type1, type2) + def infer_builtin_types_operation(self, type1, type2): + # let's assume that whatever builtin type you multiply a string with + # will either return a string of the same type or fail with an exception + string_types = (bytes_type, str_type, basestring_type, unicode_type) + if type1 in string_types and type2.is_builtin_type: + return type1 + if type2 in string_types and type1.is_builtin_type: + return type2 + # multiplication of containers/numbers with an integer value + # always (?) returns the same type + if type1.is_int: + return type2 + if type2.is_int: + return type1 + return None + class DivNode(NumBinopNode): # '/' or '//' operator. @@ -9218,9 +9244,9 @@ class DivNode(NumBinopNode): return "(%s / %s)" % (op1, op2) else: return "__Pyx_div_%s(%s, %s)" % ( - self.type.specialization_name(), - self.operand1.result(), - self.operand2.result()) + self.type.specialization_name(), + self.operand1.result(), + self.operand2.result()) class ModNode(DivNode): @@ -9228,8 +9254,25 @@ class ModNode(DivNode): def is_py_operation_types(self, type1, type2): return (type1.is_string - or type2.is_string - or NumBinopNode.is_py_operation_types(self, type1, type2)) + or type2.is_string + or NumBinopNode.is_py_operation_types(self, type1, type2)) + + def infer_builtin_types_operation(self, type1, type2): + # b'%s' % xyz raises an exception in Py3, so it's safe to infer the type for Py2 + if type1 is unicode_type: + # None % xyz may be implemented by RHS + if type2.is_builtin_type or not self.operand1.may_be_none(): + return type1 + elif type1 in (bytes_type, str_type, basestring_type): + if type2 is unicode_type: + return type2 + elif type2.is_numeric: + return type1 + elif type1 is 
 bytes_type and not type2.is_builtin_type: + return None # RHS might implement '%' operator differently in Py3 + else: + return basestring_type # either str or unicode, can't tell + return None def zero_division_message(self): if self.type.is_int: @@ -9275,6 +9318,15 @@ class ModNode(DivNode): self.operand1.result(), self.operand2.result()) + def py_operation_function(self): + if self.operand1.type is unicode_type and self.operand2.type.is_builtin_type: + if self.operand1.may_be_none(): + return '__Pyx_PyUnicode_Format' + else: + return 'PyUnicode_Format' + return super(ModNode, self).py_operation_function() + + class PowNode(NumBinopNode): # '**' operator. diff --git a/Cython/Utility/ModuleSetupCode.c b/Cython/Utility/ModuleSetupCode.c index 1fe66ea..3bb72cd 100644 --- a/Cython/Utility/ModuleSetupCode.c +++ b/Cython/Utility/ModuleSetupCode.c @@ -157,6 +157,10 @@ #define __Pyx_PyUnicode_READ(k, d, i) ((k=k), (Py_UCS4)(((Py_UNICODE*)d)[i])) #endif +#define __Pyx_PyUnicode_Format(a, b) ((unlikely((a) == Py_None)) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b)) +#define __Pyx_PyUnicode_Concat(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ? 
\ + PyNumber_Add(a, b) : PyUnicode_Concat(a, b)) + #if PY_MAJOR_VERSION >= 3 #define PyBaseString_Type PyUnicode_Type #define PyStringObject PyUnicodeObject diff --git a/tests/run/unicodemethods.pyx b/tests/run/unicodemethods.pyx index 72606e7..c466d54 100644 --- a/tests/run/unicodemethods.pyx +++ b/tests/run/unicodemethods.pyx @@ -8,6 +8,9 @@ PY_VERSION = sys.version_info text = u'ab jd sdflk as sa sadas asdas fsdf ' sep = u' ' +format1 = u'abc%sdef' +format2 = u'abc%sdef%sghi' +unicode_sa = u'sa' multiline_text = u'''\ ab jd @@ -383,6 +386,122 @@ def in_test(unicode s, substring): return substring in s +# unicode.__concat__(s, suffix) + +def concat_any(unicode s, suffix): + """ + >>> concat(text, 'sa') == text + 'sa' or concat(text, 'sa') + True + >>> concat(None, 'sa') # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ... + >>> concat(text, None) # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ... + >>> class RAdd(object): + ... def __radd__(self, other): + ... return 123 + >>> concat(None, 'sa') # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ... + """ + assert cython.typeof(s + suffix) == 'Python object', cython.typeof(s + suffix) + return s + suffix + + +def concat(unicode s, str suffix): + """ + >>> concat(text, 'sa') == text + 'sa' or concat(text, 'sa') + True + >>> concat(None, 'sa') # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ... + >>> concat(text, None) # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ... + >>> class RAdd(object): + ... def __radd__(self, other): + ... return 123 + >>> concat(None, 'sa') # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ... 
+ """ + assert cython.typeof(s + object()) == 'Python object', cython.typeof(s + object()) + assert cython.typeof(s + suffix) == 'unicode object', cython.typeof(s + suffix) + return s + suffix + + +def concat_literal_str(str suffix): + """ + >>> concat_literal_str('sa') == 'abcsa' or concat_literal_str('sa') + True + >>> concat_literal_str(None) # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ...NoneType... + """ + assert cython.typeof(u'abc' + object()) == 'Python object', cython.typeof(u'abc' + object()) + assert cython.typeof(u'abc' + suffix) == 'unicode object', cython.typeof(u'abc' + suffix) + return u'abc' + suffix + + +def concat_literal_unicode(unicode suffix): + """ + >>> concat_literal_unicode(unicode_sa) == 'abcsa' or concat_literal_unicode(unicode_sa) + True + >>> concat_literal_unicode(None) # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: ...NoneType... + """ + assert cython.typeof(u'abc' + suffix) == 'unicode object', cython.typeof(u'abc' + suffix) + return u'abc' + suffix + + +# unicode.__mod__(format, values) + +def mod_format(unicode s, values): + """ + >>> mod_format(format1, 'sa') == 'abcsadef' or mod_format(format1, 'sa') + True + >>> mod_format(format2, ('XYZ', 'ABC')) == 'abcXYZdefABCghi' or mod_format(format2, ('XYZ', 'ABC')) + True + >>> mod_format(None, 'sa') # doctest: +ELLIPSIS + Traceback (most recent call last): + TypeError: unsupported operand type(s) for %: 'NoneType' and 'str' + >>> class RMod(object): + ... def __rmod__(self, other): + ... 
return 123 + >>> mod_format(None, RMod()) + 123 + """ + assert cython.typeof(s % values) == 'Python object', cython.typeof(s % values) + return s % values + + +def mod_format_literal(values): + """ + >>> mod_format_literal('sa') == 'abcsadef' or mod_format(format1, 'sa') + True + >>> mod_format_literal(('sa',)) == 'abcsadef' or mod_format(format1, ('sa',)) + True + >>> mod_format_literal(['sa']) == "abc['sa']def" or mod_format(format1, ['sa']) + True + """ + assert cython.typeof(u'abc%sdef' % values) == 'unicode object', cython.typeof(u'abc%sdef' % values) + return u'abc%sdef' % values + + +def mod_format_tuple(*values): + """ + >>> mod_format_tuple('sa') == 'abcsadef' or mod_format(format1, 'sa') + True + >>> mod_format_tuple() + Traceback (most recent call last): + TypeError: not enough arguments for format string + """ + assert cython.typeof(u'abc%sdef' % values) == 'unicode object', cython.typeof(u'abc%sdef' % values) + return u'abc%sdef' % values + + # unicode.find(s, sub, [start, [end]]) @cython.test_fail_if_path_exists( -- 2.7.4