Merge pull request #2144 from tishibas/load_image-improved
[platform/upstream/caffeonacl.git] / python / caffe / io.py
index acd8a14..7531058 100644 (file)
@@ -8,34 +8,38 @@ try:
     from caffe.proto import caffe_pb2
 except:
     import sys
-    if sys.version_info >= (3,0):
+    if sys.version_info >= (3, 0):
         print("Failed to include caffe_pb2, things might go wrong!")
     else:
         raise
 
-## proto / datum / ndarray conversion
 
+## proto / datum / ndarray conversion
 def blobproto_to_array(blob, return_diff=False):
-    """Convert a blob proto to an array. In default, we will just return the data,
+    """
+    Convert a blob proto to an array. In default, we will just return the data,
     unless return_diff is True, in which case we will return the diff.
     """
+    # Read the data into an array
     if return_diff:
-        return np.array(blob.diff).reshape(
-            blob.num, blob.channels, blob.height, blob.width)
+        data = np.array(blob.diff)
     else:
-        return np.array(blob.data).reshape(
-            blob.num, blob.channels, blob.height, blob.width)
+        data = np.array(blob.data)
 
+    # Reshape the array
+    if blob.HasField('num') or blob.HasField('channels') or blob.HasField('height') or blob.HasField('width'):
+        # Use legacy 4D shape
+        return data.reshape(blob.num, blob.channels, blob.height, blob.width)
+    else:
+        return data.reshape(blob.shape.dim)
 
 def array_to_blobproto(arr, diff=None):
-    """Converts a 4-dimensional array to blob proto. If diff is given, also
+    """Converts a N-dimensional array to blob proto. If diff is given, also
     convert the diff. You need to make sure that arr and diff have the same
     shape, and this function does not do sanity check.
     """
-    if arr.ndim != 4:
-        raise ValueError('Incorrect array shape.')
     blob = caffe_pb2.BlobProto()
-    blob.num, blob.channels, blob.height, blob.width = arr.shape;
+    blob.shape.dim.extend(arr.shape)
     blob.data.extend(arr.astype(float).flat)
     if diff is not None:
         blob.diff.extend(diff.astype(float).flat)
@@ -81,7 +85,7 @@ def datum_to_array(datum):
     as one can easily get it by calling datum.label.
     """
     if len(datum.data):
-        return np.fromstring(datum.data, dtype = np.uint8).reshape(
+        return np.fromstring(datum.data, dtype=np.uint8).reshape(
             datum.channels, datum.height, datum.width)
     else:
         return np.array(datum.float_data).astype(float).reshape(
@@ -97,8 +101,9 @@ class Transformer:
     Note: this is mostly for illustrative purposes and it is likely better
     to define your own input preprocessing routine for your needs.
 
-    Take
-    net: a Net for which the input should be prepared
+    Parameters
+    ----------
+    net : a Net for which the input should be prepared
     """
     def __init__(self, inputs):
         self.inputs = inputs
@@ -108,13 +113,11 @@ class Transformer:
         self.mean = {}
         self.input_scale = {}
 
-
     def __check_input(self, in_):
         if in_ not in self.inputs:
             raise Exception('{} is not one of the net inputs: {}'.format(
                 in_, self.inputs))
 
-
     def preprocess(self, in_, data):
         """
         Format input for Caffe:
@@ -126,12 +129,14 @@ class Transformer:
         - subtract mean
         - scale feature
 
-        Take
-        in_: name of input blob to preprocess for
-        data: (H' x W' x K) ndarray
+        Parameters
+        ----------
+        in_ : name of input blob to preprocess for
+        data : (H' x W' x K) ndarray
 
-        Give
-        caffe_in: (K x H x W) ndarray for input to a Net
+        Returns
+        -------
+        caffe_in : (K x H x W) ndarray for input to a Net
         """
         self.__check_input(in_)
         caffe_in = data.astype(np.float32, copy=False)
@@ -155,7 +160,6 @@ class Transformer:
             caffe_in *= input_scale
         return caffe_in
 
-
     def deprocess(self, in_, data):
         """
         Invert Caffe formatting; see preprocess().
@@ -174,20 +178,20 @@ class Transformer:
         if raw_scale is not None:
             decaf_in /= raw_scale
         if channel_swap is not None:
-            decaf_in = decaf_in[channel_swap, :, :]
+            decaf_in = decaf_in[np.argsort(channel_swap), :, :]
         if transpose is not None:
-            decaf_in = decaf_in.transpose([transpose[t] for t in transpose])
+            decaf_in = decaf_in.transpose(np.argsort(transpose))
         return decaf_in
 
-
     def set_transpose(self, in_, order):
         """
         Set the input channel order for e.g. RGB to BGR conversion
         as needed for the reference ImageNet model.
 
-        Take
-        in_: which input to assign this channel order
-        order: the order to transpose the dimensions
+        Parameters
+        ----------
+        in_ : which input to assign this channel order
+        order : the order to transpose the dimensions
         """
         self.__check_input(in_)
         if len(order) != len(self.inputs[in_]) - 1:
@@ -195,16 +199,16 @@ class Transformer:
                             'dimensions as the input.')
         self.transpose[in_] = order
 
-
     def set_channel_swap(self, in_, order):
         """
         Set the input channel order for e.g. RGB to BGR conversion
         as needed for the reference ImageNet model.
         N.B. this assumes the channels are the first dimension AFTER transpose.
 
-        Take
-        in_: which input to assign this channel order
-        order: the order to take the channels.
+        Parameters
+        ----------
+        in_ : which input to assign this channel order
+        order : the order to take the channels.
             (2,1,0) maps RGB to BGR for example.
         """
         self.__check_input(in_)
@@ -213,7 +217,6 @@ class Transformer:
                             'dimensions as the input channels.')
         self.channel_swap[in_] = order
 
-
     def set_raw_scale(self, in_, scale):
         """
         Set the scale of raw features s.t. the input blob = input * scale.
@@ -221,21 +224,22 @@ class Transformer:
         like CaffeNet and AlexNet represent images in [0, 255] so the raw_scale
         of these models must be 255.
 
-        Take
-        in_: which input to assign this scale factor
-        scale: scale coefficient
+        Parameters
+        ----------
+        in_ : which input to assign this scale factor
+        scale : scale coefficient
         """
         self.__check_input(in_)
         self.raw_scale[in_] = scale
 
-
     def set_mean(self, in_, mean):
         """
         Set the mean to subtract for centering the data.
 
-        Take
-        in_: which input to assign this mean.
-        mean: mean ndarray (input dimensional or broadcastable)
+        Parameters
+        ----------
+        in_ : which input to assign this mean.
+        mean : mean ndarray (input dimensional or broadcastable)
         """
         self.__check_input(in_)
         ms = mean.shape
@@ -254,16 +258,16 @@ class Transformer:
                 raise ValueError('Mean shape incompatible with input shape.')
         self.mean[in_] = mean
 
-
     def set_input_scale(self, in_, scale):
         """
         Set the scale of preprocessed inputs s.t. the blob = blob * scale.
         N.B. input_scale is done AFTER mean subtraction and other preprocessing
         while raw_scale is done BEFORE.
 
-        Take
-        in_: which input to assign this scale factor
-        scale: scale coefficient
+        Parameters
+        ----------
+        in_ : which input to assign this scale factor
+        scale : scale coefficient
         """
         self.__check_input(in_)
         self.input_scale[in_] = scale
@@ -275,13 +279,16 @@ def load_image(filename, color=True):
     """
     Load an image converting from grayscale or alpha as needed.
 
-    Take
-    filename: string
-    color: flag for color format. True (default) loads as RGB while False
+    Parameters
+    ----------
+    filename : string
+    color : boolean
+        flag for color format. True (default) loads as RGB while False
         loads as intensity (if image is already grayscale).
 
-    Give
-    image: an image with type np.float32 in range [0, 1]
+    Returns
+    -------
+    image : an image with type np.float32 in range [0, 1]
         of size (H x W x 3) in RGB or
         of size (H x W x 1) in grayscale.
     """
@@ -299,29 +306,33 @@ def resize_image(im, new_dims, interp_order=1):
     """
     Resize an image array with interpolation.
 
-    Take
-    im: (H x W x K) ndarray
-    new_dims: (height, width) tuple of new dimensions.
-    interp_order: interpolation order, default is linear.
+    Parameters
+    ----------
+    im : (H x W x K) ndarray
+    new_dims : (height, width) tuple of new dimensions.
+    interp_order : interpolation order, default is linear.
 
-    Give
-    im: resized ndarray with shape (new_dims[0], new_dims[1], K)
+    Returns
+    -------
+    im : resized ndarray with shape (new_dims[0], new_dims[1], K)
     """
     if im.shape[-1] == 1 or im.shape[-1] == 3:
         im_min, im_max = im.min(), im.max()
         if im_max > im_min:
-            # skimage is fast but only understands {1,3} channel images in [0, 1].
+            # skimage is fast but only understands {1,3} channel images
+            # in [0, 1].
             im_std = (im - im_min) / (im_max - im_min)
             resized_std = resize(im_std, new_dims, order=interp_order)
             resized_im = resized_std * (im_max - im_min) + im_min
         else:
             # the image is a constant -- avoid divide by 0
-            ret = np.empty((new_dims[0], new_dims[1], im.shape[-1]), dtype=np.float32)
+            ret = np.empty((new_dims[0], new_dims[1], im.shape[-1]),
+                           dtype=np.float32)
             ret.fill(im_min)
             return ret
     else:
         # ndimage interpolates anything but more slowly.
-        scale = tuple(np.array(new_dims) / np.array(im.shape[:2]))
+        scale = tuple(np.array(new_dims, dtype=float) / np.array(im.shape[:2]))
         resized_im = zoom(im, scale + (1,), order=interp_order)
     return resized_im.astype(np.float32)
 
@@ -330,12 +341,14 @@ def oversample(images, crop_dims):
     """
     Crop images into the four corners, center, and their mirrored versions.
 
-    Take
-    image: iterable of (H x W x K) ndarrays
-    crop_dims: (height, width) tuple for the crops.
+    Parameters
+    ----------
+    image : iterable of (H x W x K) ndarrays
+    crop_dims : (height, width) tuple for the crops.
 
-    Give
-    crops: (10*N x H x W x K) ndarray of crops for number of inputs N.
+    Returns
+    -------
+    crops : (10*N x H x W x K) ndarray of crops for number of inputs N.
     """
     # Dimensions and center.
     im_shape = np.array(images[0].shape)
@@ -359,7 +372,7 @@ def oversample(images, crop_dims):
 
     # Extract crops
     crops = np.empty((10 * len(images), crop_dims[0], crop_dims[1],
-                            im_shape[-1]), dtype=np.float32)
+                      im_shape[-1]), dtype=np.float32)
     ix = 0
     for im in images:
         for crop in crops_ix: