Error checking for NULL strides + tests
authorMark Florisson <markflorisson88@gmail.com>
Sun, 23 Sep 2012 11:05:48 +0000 (12:05 +0100)
committerMark Florisson <markflorisson88@gmail.com>
Sun, 23 Sep 2012 11:05:48 +0000 (12:05 +0100)
Cython/Utility/MemoryView_C.c
tests/memoryview/numpy_memoryview.pyx

index c2524b0..bab3cad 100644 (file)
@@ -207,30 +207,48 @@ static int __Pyx_ValidateAndInit_memviewslice(
         goto fail;
     }
 
-    for(i=0; i<ndim; i++) {
+    for (i = 0; i < ndim; i++) {
         spec = axes_specs[i];
 
-        if (spec & __Pyx_MEMVIEW_CONTIG) {
-            if (spec & (__Pyx_MEMVIEW_PTR|__Pyx_MEMVIEW_FULL)) {
-                if (buf->strides[i] != sizeof(void *)) {
-                    PyErr_Format(PyExc_ValueError,
-                        "Buffer is not indirectly contiguous in dimension %d.", i);
+        if (buf->strides) {
+            if (spec & __Pyx_MEMVIEW_CONTIG) {
+                if (spec & (__Pyx_MEMVIEW_PTR|__Pyx_MEMVIEW_FULL)) {
+                    if (buf->strides[i] != sizeof(void *)) {
+                        PyErr_Format(PyExc_ValueError,
+                            "Buffer is not indirectly contiguous in dimension %d.", i);
+                        goto fail;
+                    }
+                } else if (buf->strides[i] != buf->itemsize) {
+                    PyErr_SetString(PyExc_ValueError,
+                        "Buffer and memoryview are not contiguous in the same dimension.");
                     goto fail;
                 }
-            } else if (buf->strides[i] != buf->itemsize) {
-                PyErr_SetString(PyExc_ValueError,
-                    "Buffer and memoryview are not contiguous in the same dimension.");
-                goto fail;
             }
-        }
 
-        if (spec & __Pyx_MEMVIEW_FOLLOW) {
-            Py_ssize_t stride = buf->strides[i];
-            if (stride < 0)
-                stride = -stride;
-            if (stride < buf->itemsize) {
+            if (spec & __Pyx_MEMVIEW_FOLLOW) {
+                Py_ssize_t stride = buf->strides[i];
+                if (stride < 0)
+                    stride = -stride;
+                if (stride < buf->itemsize) {
+                    PyErr_SetString(PyExc_ValueError,
+                        "Buffer and memoryview are not contiguous in the same dimension.");
+                    goto fail;
+                }
+            }
+        } else {
+            if (spec & __Pyx_MEMVIEW_CONTIG && i != ndim - 1) {
+                PyErr_Format(PyExc_ValueError,
+                             "C-contiguous buffer is not contiguous in "
+                             "dimension %d", i);
+                goto fail;
+            } else if (spec & (__Pyx_MEMVIEW_PTR)) {
+                PyErr_Format(PyExc_ValueError,
+                             "C-contiguous buffer is not indirect in "
+                             "dimension %d", i);
+                goto fail;
+            } else if (buf->suboffsets) {
                 PyErr_SetString(PyExc_ValueError,
-                    "Buffer and memoryview are not contiguous in the same dimension.");
+                                "Buffer exposes suboffsets but no strides");
                 goto fail;
             }
         }
@@ -254,25 +272,27 @@ static int __Pyx_ValidateAndInit_memviewslice(
         }
     }
 
-    if (c_or_f_flag & __Pyx_IS_F_CONTIG) {
-        Py_ssize_t stride = 1;
-        for(i=0; i<ndim; i++) {
-            if(stride * buf->itemsize != buf->strides[i]) {
-                PyErr_SetString(PyExc_ValueError,
-                    "Buffer not fortran contiguous.");
-                goto fail;
+    if (buf->strides) {
+        if (c_or_f_flag & __Pyx_IS_F_CONTIG) {
+            Py_ssize_t stride = 1;
+            for (i=0; i<ndim; i++) {
+                if (stride * buf->itemsize != buf->strides[i]) {
+                    PyErr_SetString(PyExc_ValueError,
+                        "Buffer not fortran contiguous.");
+                    goto fail;
+                }
+                stride = stride * buf->shape[i];
             }
-            stride = stride * buf->shape[i];
-        }
-    } else if (c_or_f_flag & __Pyx_IS_C_CONTIG) {
-        Py_ssize_t stride = 1;
-        for(i=ndim-1; i>-1; i--) {
-            if(stride * buf->itemsize != buf->strides[i]) {
-                PyErr_SetString(PyExc_ValueError,
-                    "Buffer not C contiguous.");
-                goto fail;
+        } else if (c_or_f_flag & __Pyx_IS_C_CONTIG) {
+            Py_ssize_t stride = 1;
+            for (i = ndim-1; i>-1; i--) {
+                if(stride * buf->itemsize != buf->strides[i]) {
+                    PyErr_SetString(PyExc_ValueError,
+                        "Buffer not C contiguous.");
+                    goto fail;
+                }
+                stride = stride * buf->shape[i];
             }
-            stride = stride * buf->shape[i];
         }
     }
 
index c21fee8..b45325b 100644 (file)
@@ -10,6 +10,7 @@ import sys
 cimport numpy as np
 import numpy as np
 cimport cython
+from cython cimport view
 
 include "cythonarrayutil.pxi"
 include "../buffers/mockbuffers.pxi"
@@ -568,3 +569,102 @@ def test_struct_attributes():
     print array[0]['attrib1']
     print array[0]['attrib2']
     print chr(array[0]['attrib3']['c'][0][0])
+
+#
+### Test for NULL strides (C contiguous buffers)
+#
+cdef getbuffer(Buffer self, Py_buffer *info):
+    info.buf = &self.m[0, 0]
+    info.len = 10 * 20
+    info.ndim = 2
+    info.shape = self._shape
+    info.strides = NULL
+    info.suboffsets = NULL
+    info.itemsize = 4
+    info.readonly = 0
+    self.format = b"f"
+    info.format = self.format
+
+cdef class Buffer(object):
+    cdef Py_ssize_t _shape[2]
+    cdef bytes format
+    cdef float[:, :] m
+    cdef object shape, strides
+
+    def __init__(self):
+        a = np.arange(200, dtype=np.float32).reshape(10, 20)
+        self.m = a
+        self.shape = a.shape
+        self.strides = a.strides
+        self._shape[0] = 10
+        self._shape[1] = 20
+
+    def __getbuffer__(self, Py_buffer *info, int flags):
+        getbuffer(self, info)
+
+cdef class SuboffsetsNoStridesBuffer(Buffer):
+    def __getbuffer__(self, Py_buffer *info, int flags):
+        getbuffer(self, info)
+        info.suboffsets = self._shape
+
+@testcase
+def test_null_strides(Buffer buffer_obj):
+    """
+    >>> test_null_strides(Buffer())
+    """
+    cdef float[:, :] m1 = buffer_obj
+    cdef float[:, ::1] m2 = buffer_obj
+    cdef float[:, ::view.contiguous] m3 = buffer_obj
+
+    assert (<object> m1).strides == buffer_obj.strides
+    assert (<object> m2).strides == buffer_obj.strides, ((<object> m2).strides, buffer_obj.strides)
+    assert (<object> m3).strides == buffer_obj.strides
+
+    cdef int i, j
+    for i in range(m1.shape[0]):
+        for j in range(m1.shape[1]):
+            assert m1[i, j] == buffer_obj.m[i, j]
+            assert m2[i, j] == buffer_obj.m[i, j], (i, j, m2[i, j], buffer_obj.m[i, j])
+            assert m3[i, j] == buffer_obj.m[i, j]
+
+@testcase
+def test_null_strides_error(buffer_obj):
+    """
+    >>> test_null_strides_error(Buffer())
+    C-contiguous buffer is not indirect in dimension 1
+    C-contiguous buffer is not indirect in dimension 0
+    C-contiguous buffer is not contiguous in dimension 0
+    C-contiguous buffer is not contiguous in dimension 0
+    >>> test_null_strides_error(SuboffsetsNoStridesBuffer())
+    Traceback (most recent call last):
+        ...
+    ValueError: Buffer exposes suboffsets but no strides
+    """
+    # valid
+    cdef float[::view.generic, ::view.generic] full_buf = buffer_obj
+
+    # invalid
+    cdef float[:, ::view.indirect] indirect_buf1
+    cdef float[::view.indirect, :] indirect_buf2
+    cdef float[::1, :] fortran_buf1
+    cdef float[::view.contiguous, :] fortran_buf2
+
+    try:
+        indirect_buf1 = buffer_obj
+    except ValueError, e:
+        print e
+
+    try:
+        indirect_buf2 = buffer_obj
+    except ValueError, e:
+        print e
+
+    try:
+        fortran_buf1 = buffer_obj
+    except ValueError, e:
+        print e
+
+    try:
+        fortran_buf2 = buffer_obj
+    except ValueError, e:
+        print e
\ No newline at end of file