set input preprocessing per blob in python
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 14 May 2014 02:56:14 +0000 (19:56 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 14 May 2014 20:44:02 +0000 (13:44 -0700)
python/caffe/pycaffe.py

index 8390667..a7bc278 100644 (file)
@@ -12,12 +12,10 @@ from ._caffe import Net, SGDSolver
 # inheritance) so that nets created by caffe (e.g., by SGDSolver) will
 # automatically have the improved interface
 
-Net.input = property(lambda self: self.blobs.values()[0])
-Net.input_scale = None  # for a model that expects data = input * input_scale
-
-Net.output = property(lambda self: self.blobs.values()[-1])
-
-Net.mean = None  # image mean (ndarray, input dimensional or broadcastable)
+# Input preprocessing
+Net.mean = {}   # image mean (ndarray, input dimensional or broadcastable)
+Net.input_scale = {}  # for a model that expects data = input * input_scale
+Net.channel_swap = {}  # for RGB -> BGR and the like
 
 
 @property
@@ -44,33 +42,69 @@ def _Net_params(self):
 Net.params = _Net_params
 
 
-def _Net_set_mean(self, mean_f, mode='image'):
+def _Net_set_mean(self, input_, mean_f, mode='image'):
   """
   Set the mean to subtract for data centering.
 
   Take
+    input_: which input to assign this mean.
     mean_f: path to mean .npy
     mode: image = use the whole-image mean (and check dimensions)
           channel = channel constant (i.e. mean pixel instead of mean image)
   """
+  if input_ not in self.inputs:
+    raise Exception('Input not in {}'.format(self.inputs))
   mean = np.load(mean_f)
   if mode == 'image':
     if mean.shape != self.input.data.shape[1:]:
       raise Exception('The mean shape does not match the input shape.')
-    self.mean = mean
+    self.mean[input_] = mean
   elif mode == 'channel':
-    self.mean = mean.mean(1).mean(1)
+    self.mean[input_] = mean.mean(1).mean(1)
   else:
     raise Exception('Mode not in {}'.format(['image', 'channel']))
 
 Net.set_mean = _Net_set_mean
 
 
-def _Net_format_image(self, image):
+def _Net_set_input_scale(self, input_, scale):
+  """
+  Set the input feature scaling factor s.t. input blob = input * scale.
+
+  Take
+    input_: which input to assign this scale factor
+    scale: scale coefficient
+  """
+  if input_ not in self.inputs:
+    raise Exception('Input not in {}'.format(self.inputs))
+  self.input_scale[input_] = scale
+
+Net.set_input_scale = _Net_set_input_scale
+
+
+def _Net_set_channel_swap(self, input_, order):
+  """
+  Set the input channel order for e.g. RGB to BGR conversion
+  as needed for the reference ImageNet model.
+
+  Take
+    input_: which input to assign this channel order
+    order: the order to take the channels. (2,1,0) maps RGB to BGR for example.
+  """
+  if input_ not in self.inputs:
+    raise Exception('Input not in {}'.format(self.inputs))
+  self.channel_swap[input_] = order
+
+Net.set_channel_swap = _Net_set_channel_swap
+
+
+def _Net_format_image(self, input_, image):
   """
   Format image for input to Caffe:
   - convert to single
-  - reorder color to BGR
+  - scale feature
+  - reorder channels (for instance color to BGR)
+  - subtract mean
   - reshape to 1 x K x H x W
 
   Take
@@ -80,11 +114,15 @@ def _Net_format_image(self, image):
     image: (K x H x W) ndarray
   """
   caf_image = image.astype(np.float32)
-  if self.input_scale:
-    caf_image *= self.input_scale
-  caf_image = caf_image[:, :, ::-1]
-  if self.mean is not None:
-    caf_image -= self.mean
+  input_scale = self.input_scale.get(input_)
+  channel_order = self.channel_swap.get(input_)
+  mean = self.mean.get(input_)
+  if input_scale:
+    caf_image *= input_scale
+  if channel_order:
+    caf_image = caf_image[:, :, channel_order]
+  if mean:
+    caf_image -= mean
   caf_image = caf_image.transpose((2, 0, 1))
   caf_image = caf_image[np.newaxis, :, :, :]
   return caf_image
@@ -92,17 +130,21 @@ def _Net_format_image(self, image):
 Net.format_image = _Net_format_image
 
 
-def _Net_decaffeinate_image(self, image):
+def _Net_decaffeinate_image(self, input_, image):
   """
   Invert Caffe formatting; see _Net_format_image().
   """
   decaf_image = image.squeeze()
   decaf_image = decaf_image.transpose((1,2,0))
-  if self.mean is not None:
-    decaf_image += self.mean
-  decaf_image = decaf_image[:, :, ::-1]
-  if self.input_scale:
-    decaf_image /= self.input_scale
+  input_scale = self.input_scale.get(input_)
+  channel_order = self.channel_swap.get(input_)
+  mean = self.mean.get(input_)
+  if mean:
+    decaf_image += mean
+  if channel_order:
+    decaf_image = decaf_image[:, :, channel_order[::-1]]
+  if input_scale:
+    decaf_image /= input_scale
   return decaf_image
 
 Net.decaffeinate_image = _Net_decaffeinate_image