fix padding for the last batch
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 19 May 2014 01:25:18 +0000 (18:25 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 20 May 2014 06:56:16 +0000 (23:56 -0700)
python/caffe/pycaffe.py

index 72ae5fb..9caa21b 100644 (file)
@@ -325,14 +325,15 @@ def _Net_batch(self, blobs):
 
     # Yield full batches.
     for b in range(num_batches):
-        for i in [b * batch_size]:
-            yield {name: blobs[name][i:i + batch_size] for name in blobs}
+        i = b * batch_size
+        yield {name: blobs[name][i:i + batch_size] for name in blobs}
 
     # Yield last padded batch, if any.
     if remainder > 0:
         padded_batch = {}
         for name in blobs:
-            padding = np.zeros((remainder,) + blobs[name].shape[1:])
+            padding = np.zeros((batch_size - remainder,)
+                               + blobs[name].shape[1:])
             padded_batch[name] = np.concatenate([blobs[name][-remainder:],
                                                  padding])
         yield padded_batch