fix pycaffe input processing
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 31 Jul 2014 23:19:20 +0000 (16:19 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 6 Aug 2014 06:17:59 +0000 (23:17 -0700)
- load an image as [0,1] single / np.float32 according to Python convention
- fix input scaling during preprocessing:
  - scale input for preprocessing by `raw_scale` e.g. to map an image
    to [0, 255] for the CaffeNet and AlexNet ImageNet models
  - scale feature space by `input_scale` after mean subtraction
  - switch examples to raw scale for ImageNet models
  - fix #525
- preserve type after resizing.
- resize 1, 3, or K channel images with special casing between
  skimage.transform (1 and 3) and scipy.ndimage (K) for speed

examples/detection.ipynb
examples/filter_visualization.ipynb
examples/imagenet_classification.ipynb
examples/net_surgery.ipynb
python/caffe/_caffe.cpp
python/caffe/classifier.py
python/caffe/detector.py
python/caffe/io.py
python/caffe/pycaffe.py
python/classify.py
python/detect.py

index 3f2cf71..3b0a5b2 100644 (file)
@@ -36,7 +36,7 @@
      "input": [
       "!mkdir -p _temp\n",
       "!echo `pwd`/images/fish-bike.jpg > _temp/det_input.txt\n",
-      "!../python/detect.py --crop_mode=selective_search --pretrained_model=imagenet/caffe_rcnn_imagenet_model --model_def=imagenet/rcnn_imagenet_deploy.prototxt --gpu _temp/det_input.txt _temp/det_output.h5"
+      "!../python/detect.py --crop_mode=selective_search --pretrained_model=imagenet/caffe_rcnn_imagenet_model --model_def=imagenet/rcnn_imagenet_deploy.prototxt --gpu --raw_scale=255 _temp/det_input.txt _temp/det_output.h5"
      ],
      "language": "python",
      "metadata": {},
index 0fe863b..ea99f06 100644 (file)
@@ -66,8 +66,8 @@
       "net.set_mode_cpu()\n",
       "# input preprocessing: 'data' is the name of the input blob == net.inputs[0]\n",
       "net.set_mean('data', caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy')  # ImageNet mean\n",
-      "net.set_channel_swap('data', (2,1,0))  # the reference model has channels in BGR order instead of RGB\n",
-      "net.set_input_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]"
+      "net.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]\n",
+      "net.set_channel_swap('data', (2,1,0))  # the reference model has channels in BGR order instead of RGB"
      ],
      "language": "python",
      "metadata": {},
      "cell_type": "code",
      "collapsed": false,
      "input": [
-      "# our network takes BGR images, so we need to switch color channels\n",
-      "def showimage(im):\n",
-      "    if im.ndim == 3:\n",
-      "        im = im[:, :, ::-1]\n",
-      "    plt.imshow(im)\n",
-      "    \n",
       "# take an array of shape (n, height, width) or (n, height, width, channels)\n",
       "#  and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)\n",
       "def vis_square(data, padsize=1, padval=0):\n",
       "    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))\n",
       "    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])\n",
       "    \n",
-      "    showimage(data)"
+      "    imshow(data)"
      ],
      "language": "python",
      "metadata": {},
      "collapsed": false,
      "input": [
       "# index four is the center crop\n",
-      "image = net.blobs['data'].data[4].copy()\n",
-      "image -= image.min()\n",
-      "image /= image.max()\n",
-      "showimage(image.transpose(1, 2, 0))"
+      "imshow(net.deprocess('data', net.blobs['data'].data[4]))"
      ],
      "language": "python",
      "metadata": {},
    "metadata": {}
   }
  ]
-}
\ No newline at end of file
+}
index 8ab65fd..60e8bd0 100644 (file)
@@ -53,7 +53,7 @@
      "cell_type": "markdown",
      "metadata": {},
      "source": [
-      "Loading a network is easy. `caffe.Classifier` takes care of everything. Note the arguments for configuring input preprocessing: mean subtraction switched on by giving a mean file, input channel swapping takes care of mapping RGB into the reference ImageNet model's BGR order, and input scaling multiplies the feature scale from the input [0,1] to [0,255]."
+      "Loading a network is easy. `caffe.Classifier` takes care of everything. Note the arguments for configuring input preprocessing: mean subtraction switched on by giving a mean file, input channel swapping takes care of mapping RGB into the reference ImageNet model's BGR order, and raw scaling multiplies the feature scale from the input [0,1] to the ImageNet model's [0,255]."
      ]
     },
     {
@@ -63,7 +63,7 @@
       "net = caffe.Classifier(MODEL_FILE, PRETRAINED,\n",
       "                       mean_file=caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy',\n",
       "                       channel_swap=(2,1,0),\n",
-      "                       input_scale=255,\n",
+      "                       raw_scale=255,\n",
       "                       image_dims=(256, 256))"
      ],
      "language": "python",
index bf3d114..31847b2 100644 (file)
       "plt.imshow(im)\n",
       "net_full_conv.set_mean('data', '../python/caffe/imagenet/ilsvrc_2012_mean.npy')\n",
       "net_full_conv.set_channel_swap('data', (2,1,0))\n",
-      "net_full_conv.set_input_scale('data', 255.0)\n",
+      "net_full_conv.set_raw_scale('data', 255.0)\n",
       "# make classification map by forward pass and show top prediction index per location\n",
       "out = net_full_conv.forward_all(data=np.asarray([net_full_conv.preprocess('data', im)]))\n",
       "out['prob'][0].argmax(axis=0)"
    "metadata": {}
   }
  ]
-}
\ No newline at end of file
+}
index 30c86ae..5931772 100644 (file)
@@ -278,6 +278,7 @@ struct CaffeNet {
   // Input preprocessing configuration attributes.
   dict mean_;
   dict input_scale_;
+  dict raw_scale_;
   dict channel_swap_;
   // if taking input from an ndarray, we need to hold references
   object input_data_;
@@ -329,6 +330,7 @@ BOOST_PYTHON_MODULE(_caffe) {
       .add_property("outputs",      &CaffeNet::outputs)
       .add_property("mean",         &CaffeNet::mean_)
       .add_property("input_scale",  &CaffeNet::input_scale_)
+      .add_property("raw_scale",    &CaffeNet::raw_scale_)
       .add_property("channel_swap", &CaffeNet::channel_swap_)
       .def("_set_input_arrays",     &CaffeNet::set_input_arrays)
       .def("save",                  &CaffeNet::save);
index f347be4..48835ba 100644 (file)
@@ -14,13 +14,14 @@ class Classifier(caffe.Net):
     by scaling, center cropping, or oversampling.
     """
     def __init__(self, model_file, pretrained_file, image_dims=None,
-                 gpu=False, mean_file=None, input_scale=None, channel_swap=None):
+                 gpu=False, mean_file=None, input_scale=None, raw_scale=None,
+                 channel_swap=None):
         """
         Take
         image_dims: dimensions to scale input for cropping/sampling.
-                    Default is to scale to net input size for whole-image crop.
-        gpu, mean_file, input_scale, channel_swap: convenience params for
-            setting mode, mean, input scale, and channel order.
+            Default is to scale to net input size for whole-image crop.
+        gpu, mean_file, input_scale, raw_scale, channel_swap: params for
+            preprocessing options.
         """
         caffe.Net.__init__(self, model_file, pretrained_file)
         self.set_phase_test()
@@ -32,9 +33,11 @@ class Classifier(caffe.Net):
 
         if mean_file:
             self.set_mean(self.inputs[0], mean_file)
-        if input_scale:
+        if input_scale is not None:
             self.set_input_scale(self.inputs[0], input_scale)
-        if channel_swap:
+        if raw_scale is not None:
+            self.set_raw_scale(self.inputs[0], raw_scale)
+        if channel_swap is not None:
             self.set_channel_swap(self.inputs[0], channel_swap)
 
         self.crop_dims = np.array(self.blobs[self.inputs[0]].data.shape[2:])
index 56c26ae..a9b06cd 100644 (file)
@@ -25,11 +25,12 @@ class Detector(caffe.Net):
     selective search proposals.
     """
     def __init__(self, model_file, pretrained_file, gpu=False, mean_file=None,
-                 input_scale=None, channel_swap=None, context_pad=None):
+                 input_scale=None, raw_scale=None, channel_swap=None,
+                 context_pad=None):
         """
         Take
-        gpu, mean_file, input_scale, channel_swap: convenience params for
-            setting mode, mean, input scale, and channel order.
+        gpu, mean_file, input_scale, raw_scale, channel_swap: params for
+            preprocessing options.
         context_pad: amount of surrounding context to take s.t. a `context_pad`
             sized border of pixels in the network input image is context, as in
             R-CNN feature extraction.
@@ -44,9 +45,11 @@ class Detector(caffe.Net):
 
         if mean_file:
             self.set_mean(self.inputs[0], mean_file)
-        if input_scale:
+        if input_scale is not None:
             self.set_input_scale(self.inputs[0], input_scale)
-        if channel_swap:
+        if raw_scale is not None:
+            self.set_raw_scale(self.inputs[0], raw_scale)
+        if channel_swap is not None:
             self.set_channel_swap(self.inputs[0], channel_swap)
 
         self.configure_crop(context_pad)
@@ -180,7 +183,7 @@ class Detector(caffe.Net):
         """
         self.context_pad = context_pad
         if self.context_pad:
-            input_scale = self.input_scale.get(self.inputs[0])
+            raw_scale = self.raw_scale.get(self.inputs[0])
             channel_order = self.channel_swap.get(self.inputs[0])
             # Padding context crops needs the mean in unprocessed input space.
             self.crop_mean = self.mean[self.inputs[0]].copy()
@@ -188,4 +191,4 @@ class Detector(caffe.Net):
             channel_order_inverse = [channel_order.index(i)
                                      for i in range(self.crop_mean.shape[2])]
             self.crop_mean = self.crop_mean[:,:, channel_order_inverse]
-            self.crop_mean /= input_scale
+            self.crop_mean /= raw_scale
index 1fc9723..aabcfdd 100644 (file)
@@ -1,6 +1,7 @@
 import numpy as np
 import skimage.io
-import skimage.transform
+from scipy.ndimage import zoom
+from skimage.transform import resize
 
 from caffe.proto import caffe_pb2
 
@@ -15,7 +16,8 @@ def load_image(filename, color=True):
         loads as intensity (if image is already grayscale).
 
     Give
-    image: an image with type np.float32 of size (H x W x 3) in RGB or
+    image: an image with type np.float32 in range [0, 1]
+        of size (H x W x 3) in RGB or
         of size (H x W x 1) in grayscale.
     """
     img = skimage.img_as_float(skimage.io.imread(filename)).astype(np.float32)
@@ -40,7 +42,17 @@ def resize_image(im, new_dims, interp_order=1):
     Give
     im: resized ndarray with shape (new_dims[0], new_dims[1], K)
     """
-    return skimage.transform.resize(im, new_dims, order=interp_order)
+    if im.shape[-1] == 1 or im.shape[-1] == 3:
+        # skimage is fast but only understands {1,3} channel images in [0, 1].
+        im_min, im_max = im.min(), im.max()
+        im_std = (im - im_min) / (im_max - im_min)
+        resized_std = resize(im_std, new_dims, order=interp_order)
+        resized_im = resized_std * (im_max - im_min) + im_min
+    else:
+        # ndimage interpolates anything but more slowly.
+        scale = tuple(np.array(new_dims) / np.array(im.shape[:2]))
+        resized_im = zoom(im, scale + (1,), order=interp_order)
+    return resized_im.astype(np.float32)
 
 
 def oversample(images, crop_dims):
index 64747f3..43648d0 100644 (file)
@@ -216,12 +216,10 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'):
     in_shape = self.blobs[input_].data.shape
     mean = np.load(mean_f)
     if mode == 'elementwise':
-        if mean.shape != in_shape[1:]:
-            # Resize mean (which requires H x W x K input in range [0,1]).
-            m_min, m_max = mean.min(), mean.max()
-            normal_mean = (mean - m_min) / (m_max - m_min)
-            mean = caffe.io.resize_image(normal_mean.transpose((1,2,0)),
-                    in_shape[2:]).transpose((2,0,1)) * (m_max - m_min) + m_min
+        if mean.shape[1:] != in_shape[2:]:
+            # Resize mean (which requires H x W x K input).
+            mean = caffe.io.resize_image(mean.transpose((1,2,0)),
+                                         in_shape[2:]).transpose((2,0,1))
         self.mean[input_] = mean
     elif mode == 'channel':
         self.mean[input_] = mean.mean(1).mean(1).reshape((in_shape[1], 1, 1))
@@ -229,10 +227,11 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'):
         raise Exception('Mode not in {}'.format(['elementwise', 'channel']))
 
 
-
 def _Net_set_input_scale(self, input_, scale):
     """
-    Set the input feature scaling factor s.t. input blob = input * scale.
+    Set the scale of preprocessed inputs s.t. the blob = blob * scale.
+    N.B. input_scale is done AFTER mean subtraction and other preprocessing
+    while raw_scale is done BEFORE.
 
     Take
     input_: which input to assign this scale factor
@@ -243,6 +242,22 @@ def _Net_set_input_scale(self, input_, scale):
     self.input_scale[input_] = scale
 
 
+def _Net_set_raw_scale(self, input_, scale):
+    """
+    Set the scale of raw features s.t. the input blob = input * scale.
+    While Python represents images in [0, 1], certain Caffe models
+    like CaffeNet and AlexNet represent images in [0, 255] so the raw_scale
+    of these models must be 255.
+
+    Take
+    input_: which input to assign this scale factor
+    scale: scale coefficient
+    """
+    if input_ not in self.inputs:
+        raise Exception('Input not in {}'.format(self.inputs))
+    self.raw_scale[input_] = scale
+
+
 def _Net_set_channel_swap(self, input_, order):
     """
     Set the input channel order for e.g. RGB to BGR conversion
@@ -263,10 +278,11 @@ def _Net_preprocess(self, input_name, input_):
     Format input for Caffe:
     - convert to single
     - resize to input dimensions (preserving number of channels)
-    - scale feature
     - reorder channels (for instance color to BGR)
-    - subtract mean
+    - scale raw input (e.g. from [0, 1] to [0, 255] for ImageNet models)
     - transpose dimensions to K x H x W
+    - subtract mean
+    - scale feature
 
     Take
     input_name: name of input blob to preprocess for
@@ -275,20 +291,23 @@ def _Net_preprocess(self, input_name, input_):
     Give
     caffe_inputs: (K x H x W) ndarray
     """
-    caffe_in = input_.astype(np.float32)
+    caffe_in = input_.astype(np.float32, copy=False)
     mean = self.mean.get(input_name)
     input_scale = self.input_scale.get(input_name)
+    raw_scale = self.raw_scale.get(input_name)
     channel_order = self.channel_swap.get(input_name)
     in_size = self.blobs[input_name].data.shape[2:]
     if caffe_in.shape[:2] != in_size:
         caffe_in = caffe.io.resize_image(caffe_in, in_size)
-    if input_scale is not None:
-        caffe_in *= input_scale
     if channel_order is not None:
         caffe_in = caffe_in[:, :, channel_order]
     caffe_in = caffe_in.transpose((2, 0, 1))
+    if raw_scale is not None:
+        caffe_in *= raw_scale
     if mean is not None:
         caffe_in -= mean
+    if input_scale is not None:
+        caffe_in *= input_scale
     return caffe_in
 
 
@@ -299,16 +318,19 @@ def _Net_deprocess(self, input_name, input_):
     decaf_in = input_.copy().squeeze()
     mean = self.mean.get(input_name)
     input_scale = self.input_scale.get(input_name)
+    raw_scale = self.raw_scale.get(input_name)
     channel_order = self.channel_swap.get(input_name)
+    if input_scale is not None:
+        decaf_in /= input_scale
     if mean is not None:
         decaf_in += mean
+    if raw_scale is not None:
+        decaf_in /= raw_scale
     decaf_in = decaf_in.transpose((1,2,0))
     if channel_order is not None:
         channel_order_inverse = [channel_order.index(i)
                                  for i in range(decaf_in.shape[2])]
         decaf_in = decaf_in[:, :, channel_order_inverse]
-    if input_scale is not None:
-        decaf_in /= input_scale
     return decaf_in
 
 
@@ -364,6 +386,7 @@ Net.forward_all = _Net_forward_all
 Net.forward_backward_all = _Net_forward_backward_all
 Net.set_mean = _Net_set_mean
 Net.set_input_scale = _Net_set_input_scale
+Net.set_raw_scale = _Net_set_raw_scale
 Net.set_channel_swap = _Net_set_channel_swap
 Net.preprocess = _Net_preprocess
 Net.deprocess = _Net_deprocess
index fdaeeb0..417f8b5 100755 (executable)
@@ -66,8 +66,12 @@ def main(argv):
     parser.add_argument(
         "--input_scale",
         type=float,
-        default=255,
-        help="Multiply input features by this scale before input to net"
+        help="Multiply input features by this scale to finish input preprocessing."
+    )
+    parser.add_argument(
+        "--raw_scale",
+        type=float,
+        help="Multiply raw input by this scale before preprocessing."
     )
     parser.add_argument(
         "--channel_swap",
index a3bee5c..4cfe082 100755 (executable)
@@ -76,8 +76,12 @@ def main(argv):
     parser.add_argument(
         "--input_scale",
         type=float,
-        default=255,
-        help="Multiply input features by this scale before input to net"
+        help="Multiply input features by this scale to finish input preprocessing."
+    )
+    parser.add_argument(
+        "--raw_scale",
+        type=float,
+        help="Multiply raw input by this scale before preprocessing."
     )
     parser.add_argument(
         "--channel_swap",
@@ -99,7 +103,8 @@ def main(argv):
     # Make detector.
     detector = caffe.Detector(args.model_def, args.pretrained_model,
             gpu=args.gpu, mean_file=args.mean_file,
-            input_scale=args.input_scale, channel_swap=channel_swap,
+            input_scale=args.input_scale, raw_scale=args.raw_scale,
+            channel_swap=channel_swap,
             context_pad=args.context_pad)
 
     if args.gpu: