pycaffe: allow unspecified mean. Fixes #671.
authorJonathan L Long <jonlong@cs.berkeley.edu>
Fri, 18 Jul 2014 23:31:16 +0000 (16:31 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Sat, 19 Jul 2014 00:08:49 +0000 (17:08 -0700)
python/caffe/pycaffe.py

index 5c1512c..a2fb16b 100644 (file)
@@ -259,7 +259,6 @@ def _Net_preprocess(self, input_name, input_):
     caffe_in = input_.astype(np.float32)
     input_scale = self.input_scale.get(input_name)
     channel_order = self.channel_swap.get(input_name)
-    mean = self.mean.get(input_name)
     in_size = self.blobs[input_name].data.shape[2:]
     if caffe_in.shape[:2] != in_size:
         caffe_in = caffe.io.resize_image(caffe_in, in_size)
@@ -268,8 +267,8 @@ def _Net_preprocess(self, input_name, input_):
     if channel_order:
         caffe_in = caffe_in[:, :, channel_order]
     caffe_in = caffe_in.transpose((2, 0, 1))
-    if mean is not None:
-        caffe_in -= mean
+    if hasattr(self, 'mean'):
+        caffe_in -= self.mean.get(input_name, 0)
     return caffe_in
 
 
@@ -280,9 +279,8 @@ def _Net_deprocess(self, input_name, input_):
     decaf_in = input_.copy().squeeze()
     input_scale = self.input_scale.get(input_name)
     channel_order = self.channel_swap.get(input_name)
-    mean = self.mean.get(input_name)
-    if mean is not None:
-        decaf_in += mean
+    if hasattr(self, 'mean'):
+        decaf_in += self.mean.get(input_name, 0)
     decaf_in = decaf_in.transpose((1,2,0))
     if channel_order:
         channel_order_inverse = [channel_order.index(i)