add plotting function (matplotlib required)
authormarina.kolpakova <marina.kolpakova@itseez.com>
Fri, 18 Jan 2013 08:22:03 +0000 (12:22 +0400)
committermarina.kolpakova <marina.kolpakova@itseez.com>
Fri, 1 Feb 2013 10:35:28 +0000 (14:35 +0400)
apps/sft/misk/roc_test.py
apps/sft/misk/sft.py

index 7aae82e..2c5386f 100755 (executable)
@@ -63,5 +63,7 @@ if __name__ == "__main__":
             sft.draw_rects(img, rects[0], (0, 255, 0))
 
         cv2.imshow("result", img);
-        if (cv2.waitKey (5) != -1):
-            break;
\ No newline at end of file
+        if (cv2.waitKey (0) == 27):
+            break;
+
+    # sft.plot_curve()
\ No newline at end of file
index 3e9ee88..2d3782c 100644 (file)
@@ -2,33 +2,67 @@
 
 import cv2, re, glob
 import numpy as np
+import matplotlib.pyplot as plt
+
+def plot_curve():
+
+    fig, ax = plt.subplots()
+    fig.canvas.draw()
+
+    x = np.linspace(pow(10,-4), pow(10,1), 101)
+    y = 1 - x
+
+    plt.semilogy(x,y,color='m',linewidth=2)
+    plt.xlabel("fppi")
+    plt.ylabel("miss rate")
+    plt.title("ROC curve Bahnhof")
+
+    plt.yticks( [0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.64, 0.80])
+    ylabels = [item.get_text() for item in ax.get_yticklabels()]
+    ax.set_yticklabels( ylabels )
+    plt.grid(True)
+
+    # plt.xticks( [pow(10, -4), pow(10, -3), pow(10, -2), pow(10, -1), pow(10, 0), pow(10, 0)])
+    # xlabels = [item.get_text() for item in ax.get_xticklabels()]
+    # ax.set_xticklabels( xlabels )
+
+    plt.xscale('log')
+    plt.show()
+
+
 
 def draw_rects(img, rects, color, l = lambda x, y : x + y):
     if rects is not None:
         for x1, y1, x2, y2 in rects:
             cv2.rectangle(img, (x1, y1), (l(x1, x2), l(y1, y2)), color, 2)
 
-class Sample:
-    def __init__(self, bbs, img):
-        self.image = img
-        self.bbs = bb
+class Annotation:
+    def __init__(self, bb):
+        self.bb = bb
 
 class Detection:
     def __init__(self, bb, conf):
         self.bb = bb
         self.conf = conf
+        self.matched = False
+
+    # def crop(self):
+    #     rel_scale = self.bb[1] / 128
+
 
     # we use rect-stype for dt and box style for gt. ToDo: fix it
     def overlap(self, b):
         a = self.bb
-        print "HERE:", a, b
-        w = min( a[0] + a[2], b[0] + b[2]) - max(a[0], b[0]);
-        h = min( a[1] + a[3], b[1] + b[3]) - max(a[1], b[1]);
+        w = min( a[0] + a[2], b[2]) - max(a[0], b[0]);
+        h = min( a[1] + a[3], b[3]) - max(a[1], b[1]);
 
         cross_area = 0.0 if (w < 0 or h < 0) else float(w * h)
         union_area = (a[2] * a[3]) + ((b[2] - b[0]) * (b[3] - b[1])) - cross_area;
 
-        return cross_area / union_area;
+        return cross_area / union_area
+
+    def mark_matched(self):
+        self.matched = True
 
 
 def parse_inria(ipath, f):
@@ -64,11 +98,37 @@ def match(gts, rects, confs):
     if rects is None:
         return 0
 
+    fp = 0
+    fn = 0
+
     dts = zip(*[rects.tolist(), confs.tolist()])
     dts = zip(dts[0][0], dts[0][1])
     dts = [Detection(r,c) for r, c in dts]
 
-    for dt in dts:
-        for gt in gts:
+    for gt in gts:
+
+        # exclude small
+        if gt[2] - gt[0] < 27:
+            continue
+
+        matched = False
+
+        for dt in dts:
+            # dt.crop()
             overlap =  dt.overlap(gt)
-            print overlap
\ No newline at end of file
+            print dt.bb,  "vs", gt, overlap
+            if overlap > 0.5:
+                dt.mark_matched()
+                matched = True
+                print "matched ", dt.bb, gt
+
+        if not matched:
+            fn = fn + 1
+
+    print "fn", fn
+
+    for dt in dts:
+        if not dt.matched:
+            fp = fp + 1
+
+    print "fp", fp