[pycaffe] align web demo with #1728 and #1902
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 7 Mar 2015 03:36:29 +0000 (19:36 -0800)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 8 Mar 2015 08:16:27 +0000 (00:16 -0800)
examples/web_demo/app.py

index e456526..208b1da 100644 (file)
@@ -10,12 +10,13 @@ import tornado.wsgi
 import tornado.httpserver
 import numpy as np
 import pandas as pd
-from PIL import Image as PILImage
+import Image
 import cStringIO as StringIO
 import urllib
-import caffe
 import exifutil
 
+import caffe
+
 REPO_DIRNAME = os.path.abspath(os.path.dirname(__file__) + '/../..')
 UPLOAD_FOLDER = '/tmp/caffe_demos_uploads'
 ALLOWED_IMAGE_EXTENSIONS = set(['png', 'bmp', 'jpg', 'jpe', 'jpeg', 'gif'])
@@ -80,7 +81,7 @@ def classify_upload():
 
 def embed_image_html(image):
     """Creates an image embedded in HTML base64 format."""
-    image_pil = PILImage.fromarray((255 * image).astype('uint8'))
+    image_pil = Image.fromarray((255 * image).astype('uint8'))
     image_pil = image_pil.resize((256, 256))
     string_buf = StringIO.StringIO()
     image_pil.save(string_buf, format='png')
@@ -114,15 +115,18 @@ class ImagenetClassifier(object):
                 "File for {} is missing. Should be at: {}".format(key, val))
     default_args['image_dim'] = 256
     default_args['raw_scale'] = 255.
-    default_args['gpu_mode'] = False
 
     def __init__(self, model_def_file, pretrained_model_file, mean_file,
                  raw_scale, class_labels_file, bet_file, image_dim, gpu_mode):
         logging.info('Loading net and associated files...')
+        if gpu_mode:
+            caffe.set_mode_gpu()
+        else:
+            caffe.set_mode_cpu()
         self.net = caffe.Classifier(
             model_def_file, pretrained_model_file,
             image_dims=(image_dim, image_dim), raw_scale=raw_scale,
-            mean=np.load(mean_file), channel_swap=(2, 1, 0), gpu=gpu_mode
+            mean=np.load(mean_file).mean(1).mean(1), channel_swap=(2, 1, 0)
         )
 
         with open(class_labels_file) as f: