From 11f3927c58e6ad54b33b8e4cd43a8eae8af52849 Mon Sep 17 00:00:00 2001 From: "marina.kolpakova" Date: Mon, 21 Jan 2013 15:53:25 +0400 Subject: [PATCH] allow multiple detectors --- apps/sft/misk/roc_test.py | 72 ++++++++++++++++++++++++++--------------------- apps/sft/misk/sft.py | 19 ++++++------- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/apps/sft/misk/roc_test.py b/apps/sft/misk/roc_test.py index 3dbf94d..8682119 100755 --- a/apps/sft/misk/roc_test.py +++ b/apps/sft/misk/roc_test.py @@ -7,6 +7,8 @@ import sys, os, os.path, glob, math, cv2 from datetime import datetime import numpy +plot_colors = ['b', 'r', 'g', 'c', 'm'] + # "key" : ( b, g, r) bgr = { "red" : ( 0, 0, 255), "green" : ( 0, 255, 0), @@ -19,7 +21,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description = 'Plot ROC curve using Caltech mathod of per image detection performance estimation.') # positional - parser.add_argument("cascade", help = "Path to the tested detector.") + parser.add_argument("cascade", help = "Path to the tested detector.", nargs='+') parser.add_argument("input", help = "Image sequence pattern.") parser.add_argument("annotations", help = "Path to the annotations.") @@ -34,47 +36,53 @@ if __name__ == "__main__": args = parser.parse_args() - # parse annotations + print args.cascade + # # parse annotations + sft.initPlot() samples = call_parser(args.anttn_format, args.annotations) - cascade = sft.cascade(args.min_scale, args.max_scale, args.nscales, args.cascade) - pattern = args.input - camera = cv2.VideoCapture(pattern) + for idx, each in enumerate(args.cascade): + print each + cascade = sft.cascade(args.min_scale, args.max_scale, args.nscales, each) + pattern = args.input + camera = cv2.VideoCapture(pattern) + + # for plotting over dataset + nannotated = 0 + nframes = 0 - # for plotting over dataset - nannotated = 0 - nframes = 0 + confidenses = [] + tp = [] - confidenses = [] - tp = [] + while True: + ret, img = camera.read() + if not ret: + break; - while True: - ret, img = camera.read() - if not ret: - break; + name = pattern % (nframes,) + _, tail = os.path.split(name) - name = pattern % (nframes,) - _, tail = os.path.split(name) + boxes = samples[tail] + boxes = sft.norm_acpect_ratio(boxes, 0.5) - boxes = samples[tail] - boxes = sft.norm_acpect_ratio(boxes, 0.5) + nannotated = nannotated + len(boxes) + nframes = nframes + 1 + rects, confs = cascade.detect(img, rois = None) - nannotated = nannotated + len(boxes) - nframes = nframes + 1 - rects, confs = cascade.detect(img, rois = None) + if confs is None: + continue - if confs is None: - continue + dts = sft.convert2detections(rects, confs) - dts = sft.convert2detections(rects, confs) + confs = confs.tolist()[0] + confs.sort(lambda x, y : -1 if (x - y) > 0 else 1) + confidenses = confidenses + confs - confs = confs.tolist()[0] - confs.sort(lambda x, y : -1 if (x - y) > 0 else 1) - confidenses = confidenses + confs + matched = sft.match(boxes, dts) + tp = tp + matched - matched = sft.match(boxes, dts) - tp = tp + matched + print nframes, nannotated - print nframes, nannotated + fppi, miss_rate = sft.computeROC(confidenses, tp, nannotated, nframes) + sft.plotLogLog(fppi, miss_rate, plot_colors[idx]) - fppi, miss_rate = sft.computeROC(confidenses, tp, nannotated, nframes) - sft.plotLogLog(fppi, miss_rate) \ No newline at end of file + sft.showPlot("roc_curve.png") \ No newline at end of file diff --git a/apps/sft/misk/sft.py b/apps/sft/misk/sft.py index cd6ca0d..8c7f4b4 100644 --- a/apps/sft/misk/sft.py +++ b/apps/sft/misk/sft.py @@ -55,7 +55,7 @@ def crop_rect(rect, factor): # -def plotLogLog(fppi, miss_rate): +def initPlot(): fig, ax = plt.subplots() fig.canvas.draw() @@ -63,22 +63,19 @@ def plotLogLog(fppi, miss_rate): plt.xlabel("fppi") plt.ylabel("miss rate") plt.title("ROC curve Bahnhof") - - # plt.yticks( [0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.64, 0.80]) - # ylabels = [item.get_text() for item in ax.get_yticklabels()] - # ax.set_yticklabels( ylabels ) plt.grid(True) - - # plt.xticks( [pow(10, -4), pow(10, -3), pow(10, -2), pow(10, -1), pow(10, 0), pow(10, 1)]) - # xlabels = [item.get_text() for item in ax.get_xticklabels()] - # ax.set_xticklabels( xlabels ) plt.xscale('log') plt.yscale('log') - plt.semilogy(fppi, miss_rate, color='m', linewidth=2) - +def showPlot(name): + plt.savefig(name) plt.show() + +def plotLogLog(fppi, miss_rate, c): + plt.semilogy(fppi, miss_rate, color = c, linewidth = 2) + + def draw_rects(img, rects, color, l = lambda x, y : x + y): if rects is not None: for x1, y1, x2, y2 in rects: -- 2.7.4