optimise indexing and slicing of bytearray
author Stefan Behnel <stefan_ml@behnel.de>
Sun, 3 Nov 2013 15:01:05 +0000 (16:01 +0100)
committer Stefan Behnel <stefan_ml@behnel.de>
Sun, 3 Nov 2013 15:01:05 +0000 (16:01 +0100)
Cython/Compiler/ExprNodes.py
Cython/Utility/StringTools.c
tests/run/bytearray_coercion.pyx

index 2376cbe..cfda7c0 100644 (file)
@@ -2723,7 +2723,8 @@ class IndexNode(ExprNode):
             elif base_type.is_pyunicode_ptr:
                 # sliced Py_UNICODE* strings must coerce to Python
                 return unicode_type
-            elif base_type in (unicode_type, bytes_type, str_type, list_type, tuple_type):
+            elif base_type in (unicode_type, bytes_type, str_type,
+                               bytearray_type, list_type, tuple_type):
                 # slicing these returns the same type
                 return base_type
             else:
@@ -2745,6 +2746,8 @@ class IndexNode(ExprNode):
             elif base_type is str_type:
                 # always returns str - Py2: bytes, Py3: unicode
                 return base_type
+            elif base_type is bytearray_type:
+                return PyrexTypes.c_uchar_type
             elif isinstance(self.base, BytesNode):
                 #if env.global_scope().context.language_level >= 3:
                 #    # inferring 'char' can be made to work in Python 3 mode
@@ -3014,7 +3017,7 @@ class IndexNode(ExprNode):
             if base_type.is_pyobject:
                 if self.index.type.is_int:
                     if (not setting
-                        and (base_type in (list_type, tuple_type))
+                        and (base_type in (list_type, tuple_type, bytearray_type))
                         and (not self.index.type.signed
                              or not env.directives['wraparound']
                              or (isinstance(self.index, IntNode) and
@@ -3032,6 +3035,9 @@ class IndexNode(ExprNode):
                     # Py_UNICODE/Py_UCS4 will automatically coerce to a unicode string
                     # if required, so this is fast and safe
                     self.type = PyrexTypes.c_py_ucs4_type
+                elif self.index.type.is_int and base_type is bytearray_type:
+                    # not using uchar here to enable error reporting as '-1'
+                    self.type = PyrexTypes.c_int_type
                 elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
                     self.type = base_type
                 else:
@@ -3230,15 +3236,21 @@ class IndexNode(ExprNode):
             return "(*%s)" % self.buffer_ptr_code
         elif self.is_memslice_copy:
             return self.base.result()
-        elif self.base.type is list_type:
-            return "PyList_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
-        elif self.base.type is tuple_type:
-            return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
-        elif (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
-            error(self.pos, "Invalid use of pointer slice")
+        elif self.base.type in (list_type, tuple_type, bytearray_type):
+            if self.base.type is list_type:
+                index_code = "PyList_GET_ITEM(%s, %s)"
+            elif self.base.type is tuple_type:
+                index_code = "PyTuple_GET_ITEM(%s, %s)"
+            elif self.base.type is bytearray_type:
+                index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))"
+            else:
+                assert False, "unexpected base type in indexing: %s" % self.base.type
         else:
-            return "(%s[%s])" % (
-                self.base.result(), self.index.result())
+            if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
+                error(self.pos, "Invalid use of pointer slice")
+                return
+            index_code = "(%s[%s])"
+        return index_code % (self.base.result(), self.index.result())
 
     def extra_index_params(self, code):
         if self.index.type.is_int:
@@ -3344,6 +3356,22 @@ class IndexNode(ExprNode):
                         self.extra_index_params(code),
                         self.result(),
                         code.error_goto(self.pos)))
+            elif self.base.type is bytearray_type:
+                assert self.index.type.is_int
+                assert self.type.is_int
+                index_code = self.index.result()
+                function = "__Pyx_GetItemInt_ByteArray"
+                code.globalstate.use_utility_code(
+                    UtilityCode.load_cached("GetItemIntByteArray", "StringTools.c"))
+                code.putln(
+                    "%s = %s(%s, %s%s); if (unlikely(%s == -1)) %s;" % (
+                        self.result(),
+                        function,
+                        self.base.py_result(),
+                        index_code,
+                        self.extra_index_params(code),
+                        self.result(),
+                        code.error_goto(self.pos)))
 
     def generate_setitem_code(self, value_code, code):
         if self.index.type.is_int:
index 1bae48f..2f17e53 100644 (file)
@@ -227,6 +227,49 @@ static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int eq
 #endif
 }
 
+//////////////////// GetItemIntByteArray.proto ////////////////////
+
+#define __Pyx_GetItemInt_ByteArray(o, i, size, to_py_func, is_list, wraparound, boundscheck) \
+    (((size) <= sizeof(Py_ssize_t)) ? \
+    __Pyx_GetItemInt_ByteArray_Fast(o, i, wraparound, boundscheck) : \
+    __Pyx_GetItemInt_ByteArray_Generic(o, to_py_func(i)))  /* index types wider than Py_ssize_t are converted with to_py_func and looked up generically; is_list is accepted but not used here */
+
+static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Fast(PyObject* string, Py_ssize_t i,
+                                                         int wraparound, int boundscheck);  /* direct buffer access for indices that fit Py_ssize_t */
+static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Generic(PyObject* string, PyObject* j);  /* lookup through PyObject_GetItem with a Python index object */
+
+//////////////////// GetItemIntByteArray ////////////////////
+
+static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Fast(PyObject* string, Py_ssize_t i,
+                                                         int wraparound, int boundscheck) {  /* returns the byte at index i as 0..255, or -1 with IndexError set */
+    Py_ssize_t length;
+    if (wraparound | boundscheck) {  /* the length is only read when at least one of the checks is enabled */
+        length = PyByteArray_GET_SIZE(string);
+        if (wraparound & unlikely(i < 0)) i += length;  /* a negative index counts from the end */
+        if ((!boundscheck) || likely((0 <= i) & (i < length))) {
+            return (unsigned char) (PyByteArray_AS_STRING(string)[i]);  /* the cast keeps the value non-negative, so -1 stays free as the error value */
+        } else {
+            PyErr_SetString(PyExc_IndexError, "bytearray index out of range");
+            return -1;
+        }
+    } else {
+        return (unsigned char) (PyByteArray_AS_STRING(string)[i]);  /* no checks requested: i is used as-is */
+    }
+}
+
+static CYTHON_INLINE int __Pyx_GetItemInt_ByteArray_Generic(PyObject* string, PyObject* j) {  /* looks up string[j], releases the reference to j; returns the item value, or -1 with an exception set */
+    long bchar;
+    PyObject *bchar_obj;
+    if (!j) return -1;  /* j is NULL when the to_py_func conversion in the calling macro failed */
+    bchar_obj = PyObject_GetItem(string, j);
+    Py_DECREF(j);
+    if (!bchar_obj) return -1;
+    bchar = PyLong_AsLong(bchar_obj);  /* indexing a bytearray yields a Python integer, not a bytearray, so it must be converted as an integer; a failed conversion returns -1 with the exception set */
+    Py_DECREF(bchar_obj);
+    return (int) bchar;
+}
+
+
 //////////////////// GetItemIntUnicode.proto ////////////////////
 
 #define __Pyx_GetItemInt_Unicode(o, i, size, to_py_func, is_list, wraparound, boundscheck) \
index 47b1ec2..be414be 100644 (file)
@@ -3,6 +3,8 @@
 # NOTE: Py2.6+ only
 
 
+cimport cython
+
 cpdef bytearray coerce_to_charptr(char* b):
     """
     >>> b = bytearray(b'abc')
@@ -35,3 +37,29 @@ cpdef bytearray coerce_charptr_slice(char* b):
     True
     """
     return b[:2]
+
+def infer_index_types(bytearray b):  # checks the values and inferred C types of single-item indexing on a typed bytearray
+    """
+    >>> b = bytearray(b'a\\xFEc')
+    >>> print(infer_index_types(b))
+    (254, 254, 254, 'unsigned char', 'unsigned char', 'unsigned char', 'int')
+    """
+    c = b[1]  # untyped local, so its type comes from type inference; default directives
+    with cython.wraparound(False):
+        d = b[1]  # same lookup with the wraparound directive switched off
+    with cython.boundscheck(False):
+        e = b[1]  # same lookup with the boundscheck directive switched off
+    return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1])  # the last entry is the type of the index expression itself, the three before it are the inferred variable types
+
+def infer_slice_types(bytearray b):  # checks the values and inferred types of slicing a typed bytearray
+    """
+    >>> b = bytearray(b'abc')
+    >>> print(infer_slice_types(b))
+    (bytearray(b'bc'), bytearray(b'bc'), bytearray(b'bc'), 'Python object', 'Python object', 'Python object', 'bytearray object')
+    """
+    c = b[1:]  # untyped local, so its type comes from type inference; default directives
+    with cython.boundscheck(False):
+        d = b[1:]  # same slice with the boundscheck directive switched off
+    with cython.boundscheck(False), cython.wraparound(False):
+        e = b[1:]  # same slice with both directives switched off
+    return c, d, e, cython.typeof(c), cython.typeof(d), cython.typeof(e), cython.typeof(b[1:])  # the last entry is the type of the slice expression itself, the three before it are the inferred variable types