From 469eeea37065c8c7fe82e4117587120711750151 Mon Sep 17 00:00:00 2001 From: "marina.kolpakova" Date: Mon, 21 Jan 2013 02:36:23 +0400 Subject: [PATCH] add ROC estimation in the same way as Dallar's matlab toolbox does --- apps/sft/misk/roc_test.py | 42 +++++++++++++++------------ apps/sft/misk/sft.py | 73 ++++++++++++++++++++++++----------------------- 2 files changed, 61 insertions(+), 54 deletions(-) diff --git a/apps/sft/misk/roc_test.py b/apps/sft/misk/roc_test.py index 5531486..3dbf94d 100755 --- a/apps/sft/misk/roc_test.py +++ b/apps/sft/misk/roc_test.py @@ -34,43 +34,47 @@ if __name__ == "__main__": args = parser.parse_args() + # parse annotations samples = call_parser(args.anttn_format, args.annotations) - - # where we use nms cv::SCascade::DOLLAR == 2 - cascade = cv2.SCascade(args.min_scale, args.max_scale, args.nscales, 2) - xml = cv2.FileStorage(args.cascade, 0) - dom = xml.getFirstTopLevelNode() - assert cascade.load(dom) - + cascade = sft.cascade(args.min_scale, args.max_scale, args.nscales, args.cascade) pattern = args.input camera = cv2.VideoCapture(pattern) - frame = 0 + # for plotting over dataset + nannotated = 0 + nframes = 0 + + confidenses = [] + tp = [] + while True: ret, img = camera.read() if not ret: break; - name = pattern % (frame,) - qq = pattern.format(frame) + name = pattern % (nframes,) _, tail = os.path.split(name) boxes = samples[tail] boxes = sft.norm_acpect_ratio(boxes, 0.5) - frame = frame + 1 + nannotated = nannotated + len(boxes) + nframes = nframes + 1 rects, confs = cascade.detect(img, rois = None) + if confs is None: + continue + dts = sft.convert2detections(rects, confs) - sft.draw_dt(img, dts, bgr["green"]) - fp, fn = sft.match(boxes, dts) - print "fp and fn", fp, fn + confs = confs.tolist()[0] + confs.sort(lambda x, y : -1 if (x - y) > 0 else 1) + confidenses = confidenses + confs + matched = sft.match(boxes, dts) + tp = tp + matched - sft.draw_rects(img, boxes, bgr["blue"], lambda x, y : y) - cv2.imshow("result", img); - if (cv2.waitKey (0) == 27): - break; + print nframes, nannotated - # sft.plot_curve() \ No newline at end of file + fppi, miss_rate = sft.computeROC(confidenses, tp, nannotated, nframes) + sft.plotLogLog(fppi, miss_rate) \ No newline at end of file diff --git a/apps/sft/misk/sft.py b/apps/sft/misk/sft.py index b0ee081..cd6ca0d 100644 --- a/apps/sft/misk/sft.py +++ b/apps/sft/misk/sft.py @@ -19,6 +19,34 @@ def convert2detections(rects, confs, crop_factor = 0.125): return dts +def cascade(min_scale, max_scale, nscales, f): + # where we use nms cv::SCascade::DOLLAR == 2 + c = cv2.SCascade(min_scale, max_scale, nscales, 2) + xml = cv2.FileStorage(f, 0) + dom = xml.getFirstTopLevelNode() + assert c.load(dom) + return c + +def cumsum(n): + cum = [] + y = 0 + for i in n: + y += i + cum.append(y) + return cum + +def computeROC(confidenses, tp, nannotated, nframes): + confidenses, tp = zip(*sorted(zip(confidenses, tp), reverse = True)) + + fp = [(1 - x) for x in tp] + fp = cumsum(fp) + tp = cumsum(tp) + miss_rate = [(1 - x / (nannotated + 0.000001)) for x in tp] + fppi = [x / float(nframes) for x in fp] + + return fppi, miss_rate + + def crop_rect(rect, factor): val_x = factor * float(rect[2]) val_y = factor * float(rect[3]) @@ -27,29 +55,28 @@ def crop_rect(rect, factor): # -def plot_curve(): +def plotLogLog(fppi, miss_rate): fig, ax = plt.subplots() fig.canvas.draw() - x = np.linspace(pow(10,-4), pow(10,1), 101) - y = 1 - x - - plt.semilogy(x,y,color='m',linewidth=2) plt.xlabel("fppi") plt.ylabel("miss rate") plt.title("ROC curve Bahnhof") - plt.yticks( [0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.64, 0.80]) - ylabels = [item.get_text() for item in ax.get_yticklabels()] - ax.set_yticklabels( ylabels ) + # plt.yticks( [0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.64, 0.80]) + # ylabels = [item.get_text() for item in ax.get_yticklabels()] + # ax.set_yticklabels( ylabels ) plt.grid(True) - # plt.xticks( [pow(10, -4), pow(10, -3), pow(10, -2), pow(10, -1), pow(10, 0), pow(10, 0)]) + # plt.xticks( [pow(10, -4), pow(10, -3), pow(10, -2), pow(10, -1), pow(10, 0), pow(10, 1)]) # xlabels = [item.get_text() for item in ax.get_xticklabels()] # ax.set_xticklabels( xlabels ) - plt.xscale('log') + plt.yscale('log') + + plt.semilogy(fppi, miss_rate, color='m', linewidth=2) + plt.show() def draw_rects(img, rects, color, l = lambda x, y : x + y): @@ -81,7 +108,6 @@ class Detection: # we use rect-stype for dt and box style for gt. ToDo: fix it def overlap(self, b): - print self.bb, "vs", b a = self.bb w = min( a[0] + a[2], b[2]) - max(a[0], b[0]); h = min( a[1] + a[3], b[3]) - max(a[1], b[1]); @@ -135,39 +161,16 @@ def norm_acpect_ratio(boxes, ratio): def match(gts, dts): - - for dt in dts: - print dt.bb, - - print - - for gt in gts: - print gt - - # Cartesian product for each detection BB_dt with each BB_gt overlaps = [[dt.overlap(gt) for gt in gts]for dt in dts] - print overlaps matches_gt = [0]*len(gts) - print matches_gt - matches_dt = [0]*len(dts) - print matches_dt for idx, row in enumerate(overlaps): - print idx, row - imax = row.index(max(row)) if (matches_gt[imax] == 0 and row[imax] > 0.5): matches_gt[imax] = 1 matches_dt[idx] = 1 - - print matches_gt - print matches_dt - - fp = sum(1 for x in matches_dt if x == 0) - fn = sum(1 for x in matches_gt if x == 0) - - return fp, fn \ No newline at end of file + return matches_dt \ No newline at end of file -- 2.7.4