extend semantics of 'basestring' typed variables to represent exactly bytes/str/unico...
authorStefan Behnel <stefan_ml@behnel.de>
Sun, 6 Oct 2013 09:52:53 +0000 (11:52 +0200)
committerStefan Behnel <stefan_ml@behnel.de>
Sun, 6 Oct 2013 09:52:53 +0000 (11:52 +0200)
CHANGES.rst
Cython/Compiler/Builtin.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py
Cython/Utility/FunctionArguments.c
Cython/Utility/ModuleSetupCode.c
tests/run/builtin_basestring.pyx

index b58f438..4c6ad18 100644 (file)
@@ -8,6 +8,11 @@ Cython Changelog
 Features added
 --------------
 
+* Using ``cdef basestring stringvar`` and function arguments typed as
+  ``basestring`` is now meaningful and allows assigning exactly
+  ``bytes`` (Py2-only), ``str`` and ``unicode`` (Py2/Py3) objects,
+  but no subtypes of these types.
+
 * Support for the ``__debug__`` builtin.
 
 * Assertions in Cython compiled modules are disabled if the running
index 77826b9..c251830 100644 (file)
@@ -408,7 +408,7 @@ def init_builtins():
         '__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type),
         pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True)
     global list_type, tuple_type, dict_type, set_type, frozenset_type
-    global bytes_type, str_type, unicode_type
+    global bytes_type, str_type, unicode_type, basestring_type
     global float_type, bool_type, type_type, complex_type
     type_type  = builtin_scope.lookup('type').type
     list_type  = builtin_scope.lookup('list').type
@@ -419,6 +419,7 @@ def init_builtins():
     bytes_type = builtin_scope.lookup('bytes').type
     str_type   = builtin_scope.lookup('str').type
     unicode_type = builtin_scope.lookup('unicode').type
+    basestring_type = builtin_scope.lookup('basestring').type
     float_type = builtin_scope.lookup('float').type
     bool_type  = builtin_scope.lookup('bool').type
     complex_type  = builtin_scope.lookup('complex').type
index 8d619cf..25f2bca 100755 (executable)
@@ -1160,7 +1160,7 @@ class BytesNode(ConstNode):
         node = BytesNode(self.pos, value=self.value,
                          constant_result=self.constant_result)
         if dst_type.is_pyobject:
-            if dst_type in (py_object_type, Builtin.bytes_type):
+            if dst_type in (py_object_type, Builtin.bytes_type, Builtin.basestring_type):
                 node.type = Builtin.bytes_type
             else:
                 self.check_for_coercion_error(dst_type, env, fail=True)
@@ -1250,9 +1250,8 @@ class UnicodeNode(ConstNode):
                   "Unicode literals do not support coercion to C types other "
                   "than Py_UNICODE/Py_UCS4 (for characters) or Py_UNICODE* "
                   "(for strings).")
-        elif dst_type is not py_object_type:
-            if not self.check_for_coercion_error(dst_type, env):
-                self.fail_assignment(dst_type)
+        elif dst_type not in (py_object_type, Builtin.basestring_type):
+            self.check_for_coercion_error(dst_type, env, fail=True)
         return self
 
     def can_coerce_to_char_literal(self):
@@ -1337,7 +1336,8 @@ class StringNode(PyConstNode):
 #                return BytesNode(self.pos, value=self.value)
             if not dst_type.is_pyobject:
                 return BytesNode(self.pos, value=self.value).coerce_to(dst_type, env)
-            self.check_for_coercion_error(dst_type, env, fail=True)
+            if dst_type is not Builtin.basestring_type:
+                self.check_for_coercion_error(dst_type, env, fail=True)
         return self
 
     def can_coerce_to_char_literal(self):
index 78cfd7c..18eda50 100755 (executable)
@@ -962,7 +962,10 @@ class BuiltinObjectType(PyObjectType):
 
     def assignable_from(self, src_type):
         if isinstance(src_type, BuiltinObjectType):
-            return src_type.name == self.name
+            if self.name == 'basestring':
+                return src_type.name in ('bytes', 'str', 'unicode', 'basestring')
+            else:
+                return src_type.name == self.name
         elif src_type.is_extension_type:
             # FIXME: This is an ugly special case that we currently
             # keep supporting.  It allows users to specify builtin
@@ -1005,7 +1008,15 @@ class BuiltinObjectType(PyObjectType):
         check = 'likely(%s(%s))' % (type_check, arg)
         if not notnone:
             check += '||((%s) == Py_None)' % arg
-        error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg)
+        if self.name == 'basestring':
+            name = '(PY_MAJOR_VERSION < 3 ? "basestring" : "str")'
+            space_for_name = 16
+        else:
+            name = '"%s"' % self.name
+            # avoid wasting too much space but limit number of different format strings
+            space_for_name = (len(self.name) // 16 + 1) * 16
+        error = '(PyErr_Format(PyExc_TypeError, "Expected %%.%ds, got %%.200s", %s, Py_TYPE(%s)->tp_name), 0)' % (
+            space_for_name, name, arg)
         return check + '||' + error
 
     def declaration_code(self, entity_code,
index 1c887a1..12064fc 100644 (file)
@@ -14,7 +14,10 @@ static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed
     }
     if (none_allowed && obj == Py_None) return 1;
     else if (exact) {
-        if (Py_TYPE(obj) == type) return 1;
+        if (likely(Py_TYPE(obj) == type)) return 1;
+        #if PY_MAJOR_VERSION == 2
+        else if ((type == &PyBaseString_Type) && __Pyx_PyBaseString_CheckExact(obj)) return 1;
+        #endif
     }
     else {
         if (PyObject_TypeCheck(obj, type)) return 1;
index bdc4c3e..df0e6b2 100644 (file)
 #else
   #define __Pyx_PyBaseString_Check(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj) || \
                                          PyString_Check(obj) || PyUnicode_Check(obj))
-  #define __Pyx_PyBaseString_CheckExact(obj) (Py_TYPE(obj) == &PyBaseString_Type)
+  #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj))
 #endif
 
 #if PY_VERSION_HEX < 0x02060000
index 22498e3..0a811ea 100644 (file)
@@ -37,3 +37,52 @@ def unicode_subtypes_basestring():
     True
     """
     return issubclass(unicode, basestring)
+
+
+def basestring_typed_variable(obj):
+    """
+    >>> basestring_typed_variable(None) is None
+    True
+    >>> basestring_typed_variable(ustring) is ustring
+    True
+    >>> basestring_typed_variable(sstring) is sstring
+    True
+    >>> if IS_PY3: print(True)
+    ... else: print(basestring_typed_variable(bstring) is bstring)
+    True
+    >>> class S(str): pass
+    >>> basestring_typed_variable(S())   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...got S...
+    """
+    cdef basestring s
+    s = u'abc'
+    assert s
+    s = 'abc'
+    assert s
+    s = b'abc'
+    assert s
+    # make sure coercion also works in conditional expressions
+    s = u'abc' if obj else b'abc' if obj else 'abc'
+    assert s
+    s = obj
+    return s
+
+
+def basestring_typed_argument(basestring obj):
+    """
+    >>> basestring_typed_argument(None) is None
+    True
+    >>> basestring_typed_argument(ustring) is ustring
+    True
+    >>> basestring_typed_argument(sstring) is sstring
+    True
+    >>> if IS_PY3: print(True)
+    ... else: print(basestring_typed_argument(bstring) is bstring)
+    True
+    >>> class S(str): pass
+    >>> basestring_typed_argument(S())   # doctest: +ELLIPSIS
+    Traceback (most recent call last):
+    TypeError: ...got S...
+    """
+    return obj