wrapper update
authorYangqing Jia <jiayq84@gmail.com>
Fri, 15 Nov 2013 16:02:48 +0000 (08:02 -0800)
committerYangqing Jia <jiayq84@gmail.com>
Fri, 15 Nov 2013 16:02:48 +0000 (08:02 -0800)
python/caffe/imagenet/wrapper.py

index ec1d409..8fc22d5 100644 (file)
@@ -70,14 +70,15 @@ class ImageNetClassifier(object):
   """The ImageNetClassifier is a wrapper class to perform easier deployment
   of models trained on imagenet.
   """
-  def __init__(self, model_def_file, pretrained_model, center_only = False):
+  def __init__(self, model_def_file, pretrained_model, center_only = False,
+      num_output=1000):
     if center_only:
       num = 1
     else:
       num = 10
     self.caffenet = caffe.CaffeNet(model_def_file, pretrained_model,
         [num, 3, CROPPED_DIM, CROPPED_DIM])
-    self._output_blobs = [np.empty((num, 1000, 1, 1), dtype=np.float32)]
+    self._output_blobs = [np.empty((num, num_output, 1, 1), dtype=np.float32)]
     self._center_only = center_only
 
   def predict(self, filename):
@@ -106,7 +107,8 @@ def main(argv):
   files = glob.glob(os.path.join(FLAGS.root, "*." + FLAGS.ext))
   files.sort()
   print 'A total of %d files' % len(files)
-  output = np.empty((len(files), 1000), dtype=np.float32)
+  output = np.empty((len(files), self._output_blobs[0].shape[1]),
+      dtype=np.float32)
   start = time.time()
   for i, f in enumerate(files):
     output[i] = net.predict(f)