From 2fc32d558aa45f86d542b30bccbdf299bf93003e Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Mon, 19 May 2014 15:31:49 -0700 Subject: [PATCH] image classification in python --- python/caffe/__init__.py | 1 + python/caffe/classifier.py | 85 +++++++++++++++++++++++++ python/caffe/imagenet/__init__.py | 1 - python/caffe/imagenet/wrapper.py | 128 -------------------------------------- python/classify.py | 120 +++++++++++++++++++++++++++++++++++ 5 files changed, 206 insertions(+), 129 deletions(-) create mode 100644 python/caffe/classifier.py delete mode 100644 python/caffe/imagenet/__init__.py delete mode 100644 python/caffe/imagenet/wrapper.py create mode 100755 python/classify.py diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py index e5e1062..e07b013 100644 --- a/python/caffe/__init__.py +++ b/python/caffe/__init__.py @@ -1,2 +1,3 @@ from .pycaffe import Net, SGDSolver +from .classifier import Classifier import io diff --git a/python/caffe/classifier.py b/python/caffe/classifier.py new file mode 100644 index 0000000..d1875c2 --- /dev/null +++ b/python/caffe/classifier.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +""" +Classifier is an image classifier specialization of Net. +""" + +import numpy as np + +import caffe + + +class Classifier(caffe.Net): + """ + Classifier extends Net for image class prediction + by scaling, center cropping, or oversampling. + """ + def __init__(self, model_file, pretrained_file, image_dims=None, + gpu=False, mean_file=None, input_scale=None, channel_swap=None): + """ + Take + image_dims: dimensions to scale input for cropping/sampling. + Default is to scale to net input size for whole-image crop. + gpu, mean_file, input_scale, channel_swap: convenience params for + setting mode, mean, input scale, and channel order. + """ + caffe.Net.__init__(self, model_file, pretrained_file) + self.set_phase_test() + + if gpu: + self.set_mode_gpu() + else: + self.set_mode_cpu() + + if mean_file: + self.set_mean(self.inputs[0], mean_file) + if input_scale: + self.set_input_scale(self.inputs[0], input_scale) + if channel_swap: + self.set_channel_swap(self.inputs[0], channel_swap) + + self.crop_dims = np.array(self.blobs[self.inputs[0]].data.shape[2:]) + if not image_dims: + image_dims = self.crop_dims + self.image_dims = image_dims + + + def predict(self, inputs, oversample=True): + """ + Predict classification probabilities of inputs. + + Take + inputs: iterable of (H x W x K) input ndarrays. + oversample: average predictions across center, corners, and mirrors + when True (default). Center-only prediction when False. + + Give + predictions: (N x C) ndarray of class probabilities + for N images and C classes. + """ + # Scale to standardize input dimensions. + inputs = np.asarray([caffe.io.resize_image(im, self.image_dims) + for im in inputs]) + + if oversample: + # Generate center, corner, and mirrored crops. + inputs = caffe.io.oversample(inputs, self.crop_dims) + else: + # Take center crop. + center = np.array(self.image_dims) / 2.0 + crop = np.tile(center, (1, 2))[0] + np.concatenate([ + -self.crop_dims / 2.0, + self.crop_dims / 2.0 + ]) + inputs = inputs[:, crop[0]:crop[2], crop[1]:crop[3], :] + + # Classify + caffe_in = self.preprocess(self.inputs[0], inputs) + out = self.forward_all(**{self.inputs[0]: caffe_in}) + predictions = out[self.outputs[0]].squeeze(axis=(2,3)) + + # For oversampling, average predictions across crops. + if oversample: + predictions = predictions.reshape((len(predictions) / 10, 10, -1)) + predictions = predictions.mean(1) + + return predictions diff --git a/python/caffe/imagenet/__init__.py b/python/caffe/imagenet/__init__.py deleted file mode 100644 index 88cd447..0000000 --- a/python/caffe/imagenet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .wrapper import * diff --git a/python/caffe/imagenet/wrapper.py b/python/caffe/imagenet/wrapper.py deleted file mode 100644 index dd505e4..0000000 --- a/python/caffe/imagenet/wrapper.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python -"""wrapper.py implements an end-to-end wrapper that classifies an image read -from disk, using the imagenet classifier. -""" - -import numpy as np -import os -from skimage import io -from skimage import transform - -import caffe - -IMAGE_DIM = 256 -CROPPED_DIM = 227 - -# Load the imagenet mean file -IMAGENET_MEAN = np.load( - os.path.join(os.path.dirname(__file__), 'ilsvrc_2012_mean.npy')) - - -def oversample(image, center_only=False): - """ - Oversamples an image. Currently the indices are hard coded to the - 4 corners and the center of the image, as well as their flipped ones, - a total of 10 images. - - Input: - image: an image of size (256 x 256 x 3) and has data type uint8. - center_only: if True, only return the center image. - Output: - images: the output of size (10 x 3 x 227 x 227) - """ - indices = [0, IMAGE_DIM - CROPPED_DIM] - center = int(indices[1] / 2) - if center_only: - return np.ascontiguousarray( - image[np.newaxis, :, center:center + CROPPED_DIM, - center:center + CROPPED_DIM], - dtype=np.float32) - else: - images = np.empty((10, 3, CROPPED_DIM, CROPPED_DIM), dtype=np.float32) - curr = 0 - for i in indices: - for j in indices: - images[curr] = image[:, i:i + CROPPED_DIM, j:j + CROPPED_DIM] - curr += 1 - images[4] = image[:, center:center + CROPPED_DIM, - center:center + CROPPED_DIM] - # flipped version - images[5:] = images[:5, :, :, ::-1] - return images - - -def prepare_image(filename, center_only=False): - img = io.imread(filename) - if img.ndim == 2: - img = np.tile(img[:, :, np.newaxis], (1, 1, 3)) - elif img.shape[2] == 4: - img = img[:, :, :3] - # Resize, convert to BGR, and permute axes to caffe order - img_reshape = (transform.resize(img, (IMAGE_DIM,IMAGE_DIM)) * 255)[:, :, ::-1] - img_reshape = img_reshape.swapaxes(1, 2).swapaxes(0, 1) - # subtract main - img_reshape -= IMAGENET_MEAN - return oversample(img_reshape, center_only) - - -class ImageNetClassifier(object): - """ - The ImageNetClassifier is a wrapper class to perform easier deployment - of models trained on imagenet. - """ - def __init__(self, model_def_file, pretrained_model, center_only=False, - num_output=1000): - if center_only: - num = 1 - else: - num = 10 - self.caffenet = caffe.Net(model_def_file, pretrained_model) - self._output_blobs = [np.empty((num, num_output, 1, 1), dtype=np.float32)] - self._center_only = center_only - - def predict(self, filename): - input_blob = [prepare_image(filename, self._center_only)] - self.caffenet.Forward(input_blob, self._output_blobs) - return self._output_blobs[0].mean(0).flatten() - - -def main(argv): - """ - The main function will carry out classification. - """ - import gflags - import glob - import time - gflags.DEFINE_string("root", "", "The folder that contains images.") - gflags.DEFINE_string("ext", "JPEG", "The image extension.") - gflags.DEFINE_string("model_def", "", "The model definition file.") - gflags.DEFINE_string("pretrained_model", "", "The pretrained model.") - gflags.DEFINE_string("output", "", "The output numpy file.") - gflags.DEFINE_boolean("gpu", True, "use gpu for computation") - FLAGS = gflags.FLAGS - FLAGS(argv) - - net = ImageNetClassifier(FLAGS.model_def, FLAGS.pretrained_model) - - if FLAGS.gpu: - print 'Use gpu.' - net.caffenet.set_mode_gpu() - - files = glob.glob(os.path.join(FLAGS.root, "*." + FLAGS.ext)) - files.sort() - print 'A total of %d files' % len(files) - output = np.empty((len(files), net._output_blobs[0].shape[1]), - dtype=np.float32) - start = time.time() - for i, f in enumerate(files): - output[i] = net.predict(f) - if i % 1000 == 0 and i > 0: - print 'Processed %d files, elapsed %.2f s' % (i, time.time() - start) - # Finally, write the results - np.save(FLAGS.output, output) - print 'Done. Saved to %s.' % FLAGS.output - - -if __name__ == "__main__": - import sys - main(sys.argv) diff --git a/python/classify.py b/python/classify.py new file mode 100755 index 0000000..fdaeeb0 --- /dev/null +++ b/python/classify.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +""" +classify.py is an out-of-the-box image classifer callable from the command line. + +By default it configures and runs the Caffe reference ImageNet model. +""" +import numpy as np +import os +import sys +import argparse +import glob +import time + +import caffe + + +def main(argv): + pycaffe_dir = os.path.dirname(__file__) + + parser = argparse.ArgumentParser() + # Required arguments: input and output files. + parser.add_argument( + "input_file", + help="Input image, directory, or npy." + ) + parser.add_argument( + "output_file", + help="Output npy filename." + ) + # Optional arguments. + parser.add_argument( + "--model_def", + default=os.path.join(pycaffe_dir, + "../examples/imagenet/imagenet_deploy.prototxt"), + help="Model definition file." + ) + parser.add_argument( + "--pretrained_model", + default=os.path.join(pycaffe_dir, + "../examples/imagenet/caffe_reference_imagenet_model"), + help="Trained model weights file." + ) + parser.add_argument( + "--gpu", + action='store_true', + help="Switch for gpu computation." + ) + parser.add_argument( + "--center_only", + action='store_true', + help="Switch for prediction from center crop alone instead of " + + "averaging predictions across crops (default)." + ) + parser.add_argument( + "--images_dim", + default='256,256', + help="Canonical 'height,width' dimensions of input images." + ) + parser.add_argument( + "--mean_file", + default=os.path.join(pycaffe_dir, + 'caffe/imagenet/ilsvrc_2012_mean.npy'), + help="Data set image mean of H x W x K dimensions (numpy array). " + + "Set to '' for no mean subtraction." + ) + parser.add_argument( + "--input_scale", + type=float, + default=255, + help="Multiply input features by this scale before input to net" + ) + parser.add_argument( + "--channel_swap", + default='2,1,0', + help="Order to permute input channels. The default converts " + + "RGB -> BGR since BGR is the Caffe default by way of OpenCV." + + ) + parser.add_argument( + "--ext", + default='jpg', + help="Image file extension to take as input when a directory " + + "is given as the input file." + ) + args = parser.parse_args() + + image_dims = [int(s) for s in args.images_dim.split(',')] + channel_swap = [int(s) for s in args.channel_swap.split(',')] + + # Make classifier. + classifier = caffe.Classifier(args.model_def, args.pretrained_model, + image_dims=image_dims, gpu=args.gpu, mean_file=args.mean_file, + input_scale=args.input_scale, channel_swap=channel_swap) + + if args.gpu: + print 'GPU mode' + + # Load numpy array (.npy), directory glob (*.jpg), or image file. + args.input_file = os.path.expanduser(args.input_file) + if args.input_file.endswith('npy'): + inputs = np.load(args.input_file) + elif os.path.isdir(args.input_file): + inputs =[caffe.io.load_image(im_f) + for im_f in glob.glob(args.input_file + '/*.' + args.ext)] + else: + inputs = [caffe.io.load_image(args.input_file)] + + print "Classifying %d inputs." % len(inputs) + + # Classify. + start = time.time() + predictions = classifier.predict(inputs, not args.center_only) + print "Done in %.2f s." % (time.time() - start) + + # Save + np.save(args.output_file, predictions) + + +if __name__ == '__main__': + main(sys.argv) -- 2.7.4