bad forward/backward inputs throw exceptions instead of crashing python
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 14 May 2014 21:02:54 +0000 (14:02 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 14 May 2014 21:02:54 +0000 (14:02 -0700)
python/caffe/_caffe.cpp

index e1ee652..18b96b9 100644 (file)
@@ -162,15 +162,12 @@ struct CaffeNet {
   // Check that an array is acceptable for blob assignment
   // as described in the preface to Forward().
   inline void check_array_against_blob(
-      PyArrayObject* arr, Blob<float>* blob) {
-    CHECK(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS);
-    CHECK_EQ(PyArray_NDIM(arr), 4);
-    CHECK_EQ(PyArray_ITEMSIZE(arr), 4);
-    npy_intp* dims = PyArray_DIMS(arr);
-    CHECK_EQ(dims[0], blob->num());
-    CHECK_EQ(dims[1], blob->channels());
-    CHECK_EQ(dims[2], blob->height());
-    CHECK_EQ(dims[3], blob->width());
+      PyArrayObject* arr, Blob<float>* blob, string name) {
+    check_contiguous_array(arr, name, blob->channels(), blob->height(),
+        blob->width());
+    if (PyArray_DIMS(arr)[0] != blob->num()) {
+      throw std::runtime_error(name + " has wrong batch size");
+    }
   }
 
   // generate Python exceptions for badly shaped or discontiguous arrays
@@ -207,7 +204,8 @@ struct CaffeNet {
     for (int i = 0; i < input_blobs.size(); ++i) {
       object elem = bottom[i];
       PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, input_blobs[i]);
+      check_array_against_blob(arr, input_blobs[i],
+          net_->blob_names()[net_->input_blob_indices()[i]]);
       switch (Caffe::mode()) {
       case Caffe::CPU:
         memcpy(input_blobs[i]->mutable_cpu_data(), PyArray_DATA(arr),
@@ -227,7 +225,8 @@ struct CaffeNet {
     for (int i = 0; i < output_blobs.size(); ++i) {
       object elem = top[i];
       PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, output_blobs[i]);
+      check_array_against_blob(arr, output_blobs[i],
+          net_->blob_names()[net_->input_blob_indices()[i]]);
       switch (Caffe::mode()) {
       case Caffe::CPU:
         memcpy(PyArray_DATA(arr), output_blobs[i]->cpu_data(),
@@ -252,7 +251,8 @@ struct CaffeNet {
     for (int i = 0; i < output_blobs.size(); ++i) {
       object elem = top_diff[i];
       PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, output_blobs[i]);
+      check_array_against_blob(arr, output_blobs[i],
+          net_->blob_names()[net_->input_blob_indices()[i]]);
       switch (Caffe::mode()) {
       case Caffe::CPU:
         memcpy(output_blobs[i]->mutable_cpu_diff(), PyArray_DATA(arr),
@@ -272,7 +272,8 @@ struct CaffeNet {
     for (int i = 0; i < input_blobs.size(); ++i) {
       object elem = bottom_diff[i];
       PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(elem.ptr());
-      check_array_against_blob(arr, input_blobs[i]);
+      check_array_against_blob(arr, input_blobs[i],
+          net_->blob_names()[net_->input_blob_indices()[i]]);
       switch (Caffe::mode()) {
       case Caffe::CPU:
         memcpy(PyArray_DATA(arr), input_blobs[i]->cpu_diff(),