From 96cd02dd538bcdb793070b4c9320eadfc9c7962d Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Tue, 13 May 2014 19:56:14 -0700 Subject: [PATCH] set input preprocessing per blob in python --- python/caffe/pycaffe.py | 86 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 64 insertions(+), 22 deletions(-) diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 8390667..a7bc278 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -12,12 +12,10 @@ from ._caffe import Net, SGDSolver # inheritance) so that nets created by caffe (e.g., by SGDSolver) will # automatically have the improved interface -Net.input = property(lambda self: self.blobs.values()[0]) -Net.input_scale = None # for a model that expects data = input * input_scale - -Net.output = property(lambda self: self.blobs.values()[-1]) - -Net.mean = None # image mean (ndarray, input dimensional or broadcastable) +# Input preprocessing +Net.mean = {} # image mean (ndarray, input dimensional or broadcastable) +Net.input_scale = {} # for a model that expects data = input * input_scale +Net.channel_swap = {} # for RGB -> BGR and the like @property @@ -44,33 +42,69 @@ def _Net_params(self): Net.params = _Net_params -def _Net_set_mean(self, mean_f, mode='image'): +def _Net_set_mean(self, input_, mean_f, mode='image'): """ Set the mean to subtract for data centering. Take + input_: which input to assign this mean. mean_f: path to mean .npy mode: image = use the whole-image mean (and check dimensions) channel = channel constant (i.e. mean pixel instead of mean image) """ + if input_ not in self.inputs: + raise Exception('Input not in {}'.format(self.inputs)) mean = np.load(mean_f) if mode == 'image': if mean.shape != self.input.data.shape[1:]: raise Exception('The mean shape does not match the input shape.') - self.mean = mean + self.mean[input_] = mean elif mode == 'channel': - self.mean = mean.mean(1).mean(1) + self.mean[input_] = mean.mean(1).mean(1) else: raise Exception('Mode not in {}'.format(['image', 'channel'])) Net.set_mean = _Net_set_mean -def _Net_format_image(self, image): +def _Net_set_input_scale(self, input_, scale): + """ + Set the input feature scaling factor s.t. input blob = input * scale. + + Take + input_: which input to assign this scale factor + scale: scale coefficient + """ + if input_ not in self.inputs: + raise Exception('Input not in {}'.format(self.inputs)) + self.input_scale[input_] = scale + +Net.set_input_scale = _Net_set_input_scale + + +def _Net_set_channel_swap(self, input_, order): + """ + Set the input channel order for e.g. RGB to BGR conversion + as needed for the reference ImageNet model. + + Take + input_: which input to assign this channel order + order: the order to take the channels. (2,1,0) maps RGB to BGR for example. + """ + if input_ not in self.inputs: + raise Exception('Input not in {}'.format(self.inputs)) + self.channel_swap[input_] = order + +Net.set_channel_swap = _Net_set_channel_swap + + +def _Net_format_image(self, input_, image): """ Format image for input to Caffe: - convert to single - - reorder color to BGR + - scale feature + - reorder channels (for instance color to BGR) + - subtract mean - reshape to 1 x K x H x W Take @@ -80,11 +114,15 @@ def _Net_format_image(self, image): image: (K x H x W) ndarray """ caf_image = image.astype(np.float32) - if self.input_scale: - caf_image *= self.input_scale - caf_image = caf_image[:, :, ::-1] - if self.mean is not None: - caf_image -= self.mean + input_scale = self.input_scale.get(input_) + channel_order = self.channel_swap.get(input_) + mean = self.mean.get(input_) + if input_scale: + caf_image *= input_scale + if channel_order: + caf_image = caf_image[:, :, channel_order] + if mean: + caf_image -= mean caf_image = caf_image.transpose((2, 0, 1)) caf_image = caf_image[np.newaxis, :, :, :] return caf_image @@ -92,17 +130,21 @@ def _Net_format_image(self, image): Net.format_image = _Net_format_image -def _Net_decaffeinate_image(self, image): +def _Net_decaffeinate_image(self, input_, image): """ Invert Caffe formatting; see _Net_format_image(). """ decaf_image = image.squeeze() decaf_image = decaf_image.transpose((1,2,0)) - if self.mean is not None: - decaf_image += self.mean - decaf_image = decaf_image[:, :, ::-1] - if self.input_scale: - decaf_image /= self.input_scale + input_scale = self.input_scale.get(input_) + channel_order = self.channel_swap.get(input_) + mean = self.mean.get(input_) + if mean: + decaf_image += mean + if channel_order: + decaf_image = decaf_image[:, :, channel_order[::-1]] + if input_scale: + decaf_image /= input_scale return decaf_image Net.decaffeinate_image = _Net_decaffeinate_image -- 2.7.4