fix python mean subtraction
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Mon, 19 May 2014 00:11:38 +0000 (17:11 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 20 May 2014 06:54:25 +0000 (23:54 -0700)
python/caffe/pycaffe.py

index acd9f2f..9b1ed80 100644 (file)
@@ -193,13 +193,15 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'):
     """
     if input_ not in self.inputs:
         raise Exception('Input not in {}'.format(self.inputs))
+    in_shape = self.blobs[input_].data.shape
     mean = np.load(mean_f)
     if mode == 'elementwise':
-        if mean.shape != self.input.data.shape[1:]:
-            raise Exception('The mean shape does not match the input shape.')
+        if mean.shape != in_shape[1:]:
+            mean = caffe.io.resize_image(mean.transpose((1,2,0)),
+                    in_shape[2:]).transpose((2,0,1))
         self.mean[input_] = mean
     elif mode == 'channel':
-        self.mean[input_] = mean.mean(1).mean(1)
+        self.mean[input_] = mean.mean(1).mean(1).reshape((in_shape[1], 1, 1))
     else:
         raise Exception('Mode not in {}'.format(['elementwise', 'channel']))
 
@@ -265,9 +267,9 @@ def _Net_preprocess(self, input_name, inputs):
             caffe_in *= input_scale
         if channel_order:
             caffe_in = caffe_in[:, :, channel_order]
-        if mean:
-            caffe_in -= mean
         caffe_in = caffe_in.transpose((2, 0, 1))
+        if mean is not None:
+            caffe_in -= mean
         caffe_inputs.append(caffe_in)
     return np.asarray(caffe_inputs)
 
@@ -279,12 +281,12 @@ def _Net_deprocess(self, input_name, inputs):
     decaf_inputs = []
     for in_ in inputs:
         decaf_in = in_.squeeze()
-        decaf_in = decaf_in.transpose((1,2,0))
         input_scale = self.input_scale.get(input_name)
         channel_order = self.channel_swap.get(input_name)
         mean = self.mean.get(input_name)
-        if mean:
+        if mean is not None:
             decaf_in += mean
+        decaf_in = decaf_in.transpose((1,2,0))
         if channel_order:
             decaf_in = decaf_in[:, :, channel_order[::-1]]
         if input_scale: