From f54e9779ef5b99d21df93b5df394d9aefa125c06 Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Fri, 15 Nov 2013 08:02:48 -0800 Subject: [PATCH] wrapper update --- python/caffe/imagenet/wrapper.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/caffe/imagenet/wrapper.py b/python/caffe/imagenet/wrapper.py index ec1d409..8fc22d5 100644 --- a/python/caffe/imagenet/wrapper.py +++ b/python/caffe/imagenet/wrapper.py @@ -70,14 +70,15 @@ class ImageNetClassifier(object): """The ImageNetClassifier is a wrapper class to perform easier deployment of models trained on imagenet. """ - def __init__(self, model_def_file, pretrained_model, center_only = False): + def __init__(self, model_def_file, pretrained_model, center_only = False, + num_output=1000): if center_only: num = 1 else: num = 10 self.caffenet = caffe.CaffeNet(model_def_file, pretrained_model, [num, 3, CROPPED_DIM, CROPPED_DIM]) - self._output_blobs = [np.empty((num, 1000, 1, 1), dtype=np.float32)] + self._output_blobs = [np.empty((num, num_output, 1, 1), dtype=np.float32)] self._center_only = center_only def predict(self, filename): @@ -106,7 +107,8 @@ def main(argv): files = glob.glob(os.path.join(FLAGS.root, "*." + FLAGS.ext)) files.sort() print 'A total of %d files' % len(files) - output = np.empty((len(files), 1000), dtype=np.float32) + output = np.empty((len(files), self._output_blobs[0].shape[1]), + dtype=np.float32) start = time.time() for i, f in enumerate(files): output[i] = net.predict(f) -- 2.7.4