From 6f7034c886f730ea24658bfa5e1e8dc21f9da688 Mon Sep 17 00:00:00 2001 From: Kwang Son Date: Mon, 12 Apr 2021 10:46:38 +0900 Subject: [PATCH] [MVQA] Add barcode and face detection Change-Id: Icb8be5d0eed217b5b06861cd1c21d61030ad79a2 Signed-off-by: Kwang Son --- script/mvqa/db.py | 132 +++++++++++++++++++++++++++++++++++++++++++++++++---- script/mvqa/run.py | 10 ++-- 2 files changed, 130 insertions(+), 12 deletions(-) diff --git a/script/mvqa/db.py b/script/mvqa/db.py index 6c58ffd..2394f0d 100644 --- a/script/mvqa/db.py +++ b/script/mvqa/db.py @@ -1,16 +1,17 @@ import sqlite3 import os +import glob +import subprocess DB_PATH = 'mvqa.db' +WIDER_FACE_PREFIX = '/opt/mvdata/WIDER_face/WIDER_val/images/' -dataset = [('Artelab', 1000), - ('ILSVRC2012', 1000), - ('COCO', 500), - ('openimages', 500), +dataset = [('Artelab', 430), + ('WIDER_face', 3226), ] -testset = [('facedetection - CPU',), - ('facedetection - GPU',), +testset = [('barcode detect',), + ('facedetection(cascade) - CPU',), ] @@ -51,7 +52,10 @@ def create(): '''CREATE TABLE 'performance'( id integer PRIMARY KEY, benchmark_id integer NOT NULL, - data text)''') + millisecond integer NOT NULL, + data text NOT NULL, + label text NOT NULL, + note text)''') conn.commit() conn.close() @@ -92,12 +96,122 @@ def insert_benchmark(dataset_id, testset_id, target): return id -def insert_performance(bench_id, text): +def insert_performance(bench_id, time, data, label, text): if not exist(): create() conn = sqlite3.connect(DB_PATH) c = conn.cursor() c.execute( - 'INSERT INTO performance VALUES (null,?,?)', (bench_id, text)) + 'INSERT INTO performance VALUES (null,?,?,?,?,?)', (bench_id, time, data, label, text)) conn.commit() conn.close() + + +def get_all_item(dataset_id): + items = [] + if (dataset_id is 1): + images = glob.glob('/opt/mvdata/ArteLab1D/BarcodeDatasets/*/*.jpg') + for image in images: + items.append((image, image + '.txt')) + return items + elif (dataset_id is 2): + with open('/opt/mvdata/WIDER_face/wider_face_split/wider_face_val_bbx_gt.txt') as f: + bbox = f.readlines() + idx = 0 + while(idx < len(bbox)): + mdata = bbox[idx][:-1] + idx += 1 + cnt = int(bbox[idx]) + idx += 1 + mlabel = bbox[idx:idx+cnt] + slabel = '' + for ele in mlabel: + slabel += ele + idx += cnt + items.append((mdata, slabel)) + return items + else: + raise NotImplementedError + + +class Session: + def __init__(self, dataset_id, testset_id): + self.dataset_id = dataset_id + self.testset_id = testset_id + os.system('sdb shell rm -rf /opt/mvdata') + + def load(self): + pass + + def run(self, item): + pass + + def clean(self, item): + pass + + def verify(self, result, label): + pass + + +class Barcode(Session): + def __init__(self, dataset_id, testset_id): + super().__init__(dataset_id, testset_id) + + def load(self, item): + os.system('sdb push ' + item + ' ' + item) + + def run(self, item): + return subprocess.check_output('sdb shell mv_barcode_assessment ' + item, shell=True) + + def clean(self, item): + os.system('sdb shell rm ' + item) + + def verify(self, result, label): + reform = result.decode('utf-8').split() + with open(label) as f: + ans = f.read().split() + print('result : {} label : {} reform : {} ans : {}'.format( + result, label, reform, ans)) + if(len(ans) != int(reform[0])): + return int(reform[-1][:-2]), 'Fail' + else: + for idx, code in enumerate(ans): + if(reform[idx + 1] != code): + return int(reform[-1][:-2]), 'Fail' + return int(reform[-1][:-2]), 'Pass' + + +class FaceCascadeDetection(Session): + def __init__(self, dataset_id, testset_id): + super().__init__(dataset_id, testset_id) + dirs = glob.glob('/opt/mvdata/WIDER_face/WIDER_val/images/*') + for dir in dirs: + os.system('sdb shell mkdir -p ' + dir) + + def load(self, item): + os.system('sdb push ' + WIDER_FACE_PREFIX + + item + ' ' + WIDER_FACE_PREFIX + item) + + def run(self, item): + return subprocess.check_output('sdb shell mv_face_assessment ' + WIDER_FACE_PREFIX + item + ' 0', shell=True) + + def clean(self, item): + os.system('sdb shell rm ' + WIDER_FACE_PREFIX + item) + + def verify(self, result, label): + """ + https://github.com/Cartucho/mAP + """ + rv = result.decode('utf-8') + reform = rv.split() + task_time = reform[-1][:-2] + return (task_time, rv) + + +def get_session(dataset_id, testset_id): + if(dataset_id == 1): + return Barcode(dataset_id, testset_id) + elif(dataset_id == 2): + return FaceCascadeDetection(dataset_id, testset_id) + else: + raise NotImplementedError diff --git a/script/mvqa/run.py b/script/mvqa/run.py index 8742baa..80254d9 100644 --- a/script/mvqa/run.py +++ b/script/mvqa/run.py @@ -13,6 +13,10 @@ def run_bench(dataset_id, testset_id): # check dataset is fit with testset # check dataset is ready bench_id = db.insert_benchmark(dataset_id, testset_id, target_type) - # insert result - for _ in range(5): - db.insert_performance(bench_id, "Pass, 1.5MB, 10s, ...") + sess = db.get_session(dataset_id, testset_id) + for data, label in db.get_all_item(dataset_id): + sess.load(data) + result = sess.run(data) + sess.clean(data) + ms, save = sess.verify(result, label) + # db.insert_performance(bench_id, ms, data, label, save) -- 2.7.4