[MVQA] Add barcode and face detection 17/256717/1
authorKwang Son <k.son@samsung.com>
Mon, 12 Apr 2021 01:46:38 +0000 (10:46 +0900)
committerKwang Son <k.son@samsung.com>
Mon, 12 Apr 2021 01:46:38 +0000 (10:46 +0900)
Change-Id: Icb8be5d0eed217b5b06861cd1c21d61030ad79a2
Signed-off-by: Kwang Son <k.son@samsung.com>
script/mvqa/db.py
script/mvqa/run.py

index 6c58ffd..2394f0d 100644 (file)
@@ -1,16 +1,17 @@
 import sqlite3
 import os
+import glob
+import subprocess
 
 DB_PATH = 'mvqa.db'
+WIDER_FACE_PREFIX = '/opt/mvdata/WIDER_face/WIDER_val/images/'
 
-dataset = [('Artelab', 1000),
-           ('ILSVRC2012', 1000),
-           ('COCO', 500),
-           ('openimages', 500),
+dataset = [('Artelab', 430),
+           ('WIDER_face', 3226),
            ]
 
-testset = [('facedetection - CPU',),
-           ('facedetection - GPU',),
+testset = [('barcode detect',),
+           ('facedetection(cascade) - CPU',),
            ]
 
 
@@ -51,7 +52,10 @@ def create():
         '''CREATE TABLE 'performance'(
             id integer PRIMARY KEY,
             benchmark_id integer NOT NULL,
-            data text)''')
+            millisecond integer NOT NULL,
+            data text NOT NULL,
+            label text NOT NULL,
+            note text)''')
     conn.commit()
     conn.close()
 
@@ -92,12 +96,122 @@ def insert_benchmark(dataset_id, testset_id, target):
     return id
 
 
-def insert_performance(bench_id, text):
+def insert_performance(bench_id, time, data, label, text):
     if not exist():
         create()
     conn = sqlite3.connect(DB_PATH)
     c = conn.cursor()
     c.execute(
-        'INSERT INTO performance VALUES (null,?,?)', (bench_id, text))
+        'INSERT INTO performance VALUES (null,?,?,?,?,?)', (bench_id, time, data, label, text))
     conn.commit()
     conn.close()
+
+
+def get_all_item(dataset_id):
+    items = []
+    if (dataset_id is 1):
+        images = glob.glob('/opt/mvdata/ArteLab1D/BarcodeDatasets/*/*.jpg')
+        for image in images:
+            items.append((image, image + '.txt'))
+        return items
+    elif (dataset_id is 2):
+        with open('/opt/mvdata/WIDER_face/wider_face_split/wider_face_val_bbx_gt.txt') as f:
+            bbox = f.readlines()
+        idx = 0
+        while(idx < len(bbox)):
+            mdata = bbox[idx][:-1]
+            idx += 1
+            cnt = int(bbox[idx])
+            idx += 1
+            mlabel = bbox[idx:idx+cnt]
+            slabel = ''
+            for ele in mlabel:
+                slabel += ele
+            idx += cnt
+            items.append((mdata, slabel))
+        return items
+    else:
+        raise NotImplementedError
+
+
+class Session:
+    def __init__(self, dataset_id, testset_id):
+        self.dataset_id = dataset_id
+        self.testset_id = testset_id
+        os.system('sdb shell rm -rf /opt/mvdata')
+
+    def load(self):
+        pass
+
+    def run(self, item):
+        pass
+
+    def clean(self, item):
+        pass
+
+    def verify(self, result, label):
+        pass
+
+
+class Barcode(Session):
+    def __init__(self, dataset_id, testset_id):
+        super().__init__(dataset_id, testset_id)
+
+    def load(self, item):
+        os.system('sdb push ' + item + ' ' + item)
+
+    def run(self, item):
+        return subprocess.check_output('sdb shell mv_barcode_assessment ' + item, shell=True)
+
+    def clean(self, item):
+        os.system('sdb shell rm ' + item)
+
+    def verify(self, result, label):
+        reform = result.decode('utf-8').split()
+        with open(label) as f:
+            ans = f.read().split()
+        print('result : {} label : {} reform : {} ans : {}'.format(
+            result, label, reform, ans))
+        if(len(ans) != int(reform[0])):
+            return int(reform[-1][:-2]), 'Fail'
+        else:
+            for idx, code in enumerate(ans):
+                if(reform[idx + 1] != code):
+                    return int(reform[-1][:-2]), 'Fail'
+            return int(reform[-1][:-2]), 'Pass'
+
+
+class FaceCascadeDetection(Session):
+    def __init__(self, dataset_id, testset_id):
+        super().__init__(dataset_id, testset_id)
+        dirs = glob.glob('/opt/mvdata/WIDER_face/WIDER_val/images/*')
+        for dir in dirs:
+            os.system('sdb shell mkdir -p ' + dir)
+
+    def load(self, item):
+        os.system('sdb push ' + WIDER_FACE_PREFIX +
+                  item + ' ' + WIDER_FACE_PREFIX + item)
+
+    def run(self, item):
+        return subprocess.check_output('sdb shell mv_face_assessment ' + WIDER_FACE_PREFIX + item + ' 0', shell=True)
+
+    def clean(self, item):
+        os.system('sdb shell rm ' + WIDER_FACE_PREFIX + item)
+
+    def verify(self, result, label):
+        """
+        https://github.com/Cartucho/mAP
+        """
+        rv = result.decode('utf-8')
+        reform = rv.split()
+        task_time = reform[-1][:-2]
+        return (task_time, rv)
+
+
+def get_session(dataset_id, testset_id):
+    if(dataset_id == 1):
+        return Barcode(dataset_id, testset_id)
+    elif(dataset_id == 2):
+        return FaceCascadeDetection(dataset_id, testset_id)
+    else:
+        raise NotImplementedError
index 8742baa..80254d9 100644 (file)
@@ -13,6 +13,10 @@ def run_bench(dataset_id, testset_id):
     # check dataset is fit with testset
     # check dataset is ready
     bench_id = db.insert_benchmark(dataset_id, testset_id, target_type)
-    # insert result
-    for _ in range(5):
-        db.insert_performance(bench_id, "Pass, 1.5MB, 10s, ...")
+    sess = db.get_session(dataset_id, testset_id)
+    for data, label in db.get_all_item(dataset_id):
+        sess.load(data)
+        result = sess.run(data)
+        sess.clean(data)
+        ms, save = sess.verify(result, label)
+        # db.insert_performance(bench_id, ms, data, label, save)