Bugfix for path issues (#3038)
authorSiju <sijusamuel@gmail.com>
Thu, 18 Apr 2019 22:20:11 +0000 (03:50 +0530)
committerTianqi Chen <tqchen@users.noreply.github.com>
Thu, 18 Apr 2019 22:20:11 +0000 (15:20 -0700)
nnvm/python/nnvm/testing/yolo_detection.py
nnvm/tutorials/from_darknet.py

index 9ecb49ae04f0f78787e3e3ba8062fc5689807238..bdf9efe62de42ebdcd4fada43028055413bfc7d7 100644 (file)
@@ -165,7 +165,7 @@ def do_nms_sort(dets, classes, thresh):
                 if _box_iou(a, b) > thresh:
                     dets[j]['prob'][k] = 0
 
-def draw_detections(im, dets, thresh, names, classes):
+def draw_detections(font_path, im, dets, thresh, names, classes):
     "Draw the markings around the detected region"
     for det in dets:
         labelstr = []
@@ -198,7 +198,7 @@ def draw_detections(im, dets, thresh, names, classes):
             if bot > imh-1:
                 bot = imh-1
             _draw_box_width(im, left, top, right, bot, width, red, green, blue)
-            label = _get_label(''.join(labelstr), rgb)
+            label = _get_label(font_path, ''.join(labelstr), rgb)
             _draw_label(im, top + width, left, label, rgb)
 
 def _get_pixel(im, x, y, c):
@@ -223,7 +223,7 @@ def _draw_label(im, r, c, label, rgb):
                         val = _get_pixel(label, i, j, k)
                         _set_pixel(im, i+c, j+r, k, val)#rgb[k] * val)
 
-def _get_label(labelstr, rgb):
+def _get_label(font_path, labelstr, rgb):
     from PIL import Image
     from PIL import ImageDraw
     from PIL import ImageFont
@@ -231,7 +231,7 @@ def _get_label(labelstr, rgb):
     text = labelstr
     colorText = "black"
     testDraw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
-    font = ImageFont.truetype("arial.ttf", 25)
+    font = ImageFont.truetype(font_path, 25)
     width, height = testDraw.textsize(labelstr, font=font)
     img = Image.new('RGB', (width, height), color=(int(rgb[0]*255), int(rgb[1]*255),
                                                    int(rgb[2]*255)))
index 607af103862886ec577a65c9975776729ea03c5d..857ef46015cd00a7665a5f133d30fd11ea801fd3 100644 (file)
@@ -153,7 +153,7 @@ elif MODEL_NAME == 'yolov3':
 # do the detection and bring up the bounding boxes
 thresh = 0.5
 nms_thresh = 0.45
-img = nnvm.testing.darknet.load_image_color(test_image)
+img = nnvm.testing.darknet.load_image_color(img_path)
 _, im_h, im_w = img.shape
 dets = nnvm.testing.yolo_detection.fill_network_boxes((netw, neth), (im_w, im_h), thresh,
                                                       1, tvm_out)
@@ -172,6 +172,6 @@ with open(coco_path) as f:
 
 names = [x.strip() for x in content]
 
-nnvm.testing.yolo_detection.draw_detections(img, dets, thresh, names, last_layer.classes)
+nnvm.testing.yolo_detection.draw_detections(font_path, img, dets, thresh, names, last_layer.classes)
 plt.imshow(img.transpose(1, 2, 0))
 plt.show()