"""The ImageNetClassifier is a wrapper class to perform easier deployment
of models trained on imagenet.
"""
- def __init__(self, model_def_file, pretrained_model, center_only = False):
+ def __init__(self, model_def_file, pretrained_model, center_only = False,
+ num_output=1000):
if center_only:
num = 1
else:
num = 10
self.caffenet = caffe.CaffeNet(model_def_file, pretrained_model,
[num, 3, CROPPED_DIM, CROPPED_DIM])
- self._output_blobs = [np.empty((num, 1000, 1, 1), dtype=np.float32)]
+ self._output_blobs = [np.empty((num, num_output, 1, 1), dtype=np.float32)]
self._center_only = center_only
def predict(self, filename):
files = glob.glob(os.path.join(FLAGS.root, "*." + FLAGS.ext))
files.sort()
print 'A total of %d files' % len(files)
- output = np.empty((len(files), 1000), dtype=np.float32)
+ output = np.empty((len(files), self._output_blobs[0].shape[1]),
+ dtype=np.float32)
start = time.time()
for i, f in enumerate(files):
output[i] = net.predict(f)