Allow old-style shape in blobproto_to_array
authorLuke Yeager <luke.yeager@gmail.com>
Thu, 15 Oct 2015 18:03:09 +0000 (11:03 -0700)
committerLuke Yeager <luke.yeager@gmail.com>
Thu, 15 Oct 2015 20:57:38 +0000 (13:57 -0700)
Fixes #3199
Bug introduced in #3170

python/caffe/io.py
python/caffe/test/test_io.py [new file with mode: 0644]

index 40b7ac1..11c8426 100644 (file)
@@ -20,11 +20,18 @@ def blobproto_to_array(blob, return_diff=False):
     Convert a blob proto to an array. In default, we will just return the data,
     unless return_diff is True, in which case we will return the diff.
     """
+    # Read the data into an array
     if return_diff:
-        return np.array(blob.diff).reshape(*blob.shape.dim)
+        data = np.array(blob.diff)
     else:
-        return np.array(blob.data).reshape(*blob.shape.dim)
+        data = np.array(blob.data)
 
+    # Reshape the array
+    if blob.HasField('num') or blob.HasField('channels') or blob.HasField('height') or blob.HasField('width'):
+        # Use legacy 4D shape
+        return data.reshape(blob.num, blob.channels, blob.height, blob.width)
+    else:
+        return data.reshape(blob.shape.dim)
 
 def array_to_blobproto(arr, diff=None):
     """Converts a N-dimensional array to blob proto. If diff is given, also
diff --git a/python/caffe/test/test_io.py b/python/caffe/test/test_io.py
new file mode 100644 (file)
index 0000000..8c86ef7
--- /dev/null
@@ -0,0 +1,41 @@
+import numpy as np
+import unittest
+
+import caffe
+
+class TestBlobProtoToArray(unittest.TestCase):
+
+    def test_old_format(self):
+        data = np.zeros((10,10))
+        blob = caffe.proto.caffe_pb2.BlobProto()
+        blob.data.extend(list(data.flatten()))
+        shape = (1,1,10,10)
+        blob.num, blob.channels, blob.height, blob.width = shape
+
+        arr = caffe.io.blobproto_to_array(blob)
+        self.assertEqual(arr.shape, shape)
+
+    def test_new_format(self):
+        data = np.zeros((10,10))
+        blob = caffe.proto.caffe_pb2.BlobProto()
+        blob.data.extend(list(data.flatten()))
+        blob.shape.dim.extend(list(data.shape))
+
+        arr = caffe.io.blobproto_to_array(blob)
+        self.assertEqual(arr.shape, data.shape)
+
+    def test_no_shape(self):
+        data = np.zeros((10,10))
+        blob = caffe.proto.caffe_pb2.BlobProto()
+        blob.data.extend(list(data.flatten()))
+
+        with self.assertRaises(ValueError):
+            caffe.io.blobproto_to_array(blob)
+
+    def test_scalar(self):
+        data = np.ones((1)) * 123
+        blob = caffe.proto.caffe_pb2.BlobProto()
+        blob.data.extend(list(data.flatten()))
+
+        arr = caffe.io.blobproto_to_array(blob)
+        self.assertEqual(arr, 123)