Change load data method 56/257056/1
authorKwang Son <k.son@samsung.com>
Tue, 13 Apr 2021 07:36:45 +0000 (16:36 +0900)
committerKwang Son <k.son@samsung.com>
Mon, 19 Apr 2021 04:58:41 +0000 (13:58 +0900)
 - Change method to download single data of dataset from NAS server

Change-Id: I69c46a9ab1761bdf80e3ff4913dad93de8afcaef
Signed-off-by: Kwang Son <k.son@samsung.com>
script/mvqa/db.py
script/mvqa/run.py
script/nas_config.json [new file with mode: 0644]

index 2394f0d..f7c35b5 100644 (file)
@@ -1,13 +1,12 @@
 import sqlite3
 import os
-import glob
 import subprocess
+import json
 
 DB_PATH = 'mvqa.db'
-WIDER_FACE_PREFIX = '/opt/mvdata/WIDER_face/WIDER_val/images/'
 
-dataset = [('Artelab', 430),
-           ('WIDER_face', 3226),
+dataset = [('Artelab',),
+           ('WIDER_face',),
            ]
 
 testset = [('barcode detect',),
@@ -28,9 +27,8 @@ def create():
     c.execute(
         '''CREATE TABLE 'dataset'(
             id integer PRIMARY KEY,
-            name text NOT NULL,
-            size integer NOT NULL)''')
-    c.executemany('INSERT INTO dataset VALUES (null,?,?)', dataset)
+            name text NOT NULL)''')
+    c.executemany('INSERT INTO dataset VALUES (null,?)', dataset)
     conn.commit()
 
     c.execute(
@@ -67,7 +65,7 @@ def print_dataset_list():
     c = conn.cursor()
     ret = c.execute('SELECT * FROM dataset')
     for row in ret:
-        print(row[:2])
+        print(row)
     conn.close()
 
 
@@ -107,38 +105,16 @@ def insert_performance(bench_id, time, data, label, text):
     conn.close()
 
 
-def get_all_item(dataset_id):
-    items = []
-    if (dataset_id is 1):
-        images = glob.glob('/opt/mvdata/ArteLab1D/BarcodeDatasets/*/*.jpg')
-        for image in images:
-            items.append((image, image + '.txt'))
-        return items
-    elif (dataset_id is 2):
-        with open('/opt/mvdata/WIDER_face/wider_face_split/wider_face_val_bbx_gt.txt') as f:
-            bbox = f.readlines()
-        idx = 0
-        while(idx < len(bbox)):
-            mdata = bbox[idx][:-1]
-            idx += 1
-            cnt = int(bbox[idx])
-            idx += 1
-            mlabel = bbox[idx:idx+cnt]
-            slabel = ''
-            for ele in mlabel:
-                slabel += ele
-            idx += cnt
-            items.append((mdata, slabel))
-        return items
-    else:
-        raise NotImplementedError
-
-
 class Session:
     def __init__(self, dataset_id, testset_id):
         self.dataset_id = dataset_id
         self.testset_id = testset_id
+        assert(os.path.exists('nas_config.json'))
+        with open('nas_config.json') as f:
+            nas_config = json.load(f)
+        self.nas_config = nas_config
         os.system('sdb shell rm -rf /opt/mvdata')
+        os.system('rm -f db_meta.json')
 
     def load(self):
         pass
@@ -152,27 +128,50 @@ class Session:
     def verify(self, result, label):
         pass
 
+    def get_db_meta(self):
+        self.remote_db = 'wget --user=' + self.nas_config['user'] + ' --password=' + self.nas_config['password'] + \
+            ' ftp://' + self.nas_config['ip'] + ':' + \
+            self.nas_config['db_index_path'] + '/' + \
+            dataset[self.dataset_id - 1][0] + '/'
+        os.system(self.remote_db + 'db_meta.json')
+        with open('db_meta.json') as f:
+            meta = json.load(f)
+        os.system('rm db_meta.json')
+        return meta['data']
+
 
 class Barcode(Session):
     def __init__(self, dataset_id, testset_id):
         super().__init__(dataset_id, testset_id)
 
     def load(self, item):
-        os.system('sdb push ' + item + ' ' + item)
+        item = item.replace('(', '\(')
+        item = item.replace(')', '\)')
+        file_path = os.path.basename(item)
+        os.system(self.remote_db + item)
+        os.system('sdb push ' + file_path + ' /tmp/' + file_path)
 
     def run(self, item):
-        return subprocess.check_output('sdb shell mv_barcode_assessment ' + item, shell=True)
+        item = item.replace('(', '\(')
+        item = item.replace(')', '\)')
+        file_path = os.path.basename(item)
+        return subprocess.check_output('sdb shell mv_barcode_assessment "/tmp/' + file_path + '"', shell=True)
 
     def clean(self, item):
-        os.system('sdb shell rm ' + item)
+        item = item.replace('(', '\(')
+        item = item.replace(')', '\)')
+        file_path = os.path.basename(item)
+        os.system('sdb shell rm "/tmp/' + file_path + '"')
+        os.system('rm ' + file_path)
 
     def verify(self, result, label):
         reform = result.decode('utf-8').split()
-        with open(label) as f:
-            ans = f.read().split()
-        print('result : {} label : {} reform : {} ans : {}'.format(
-            result, label, reform, ans))
-        if(len(ans) != int(reform[0])):
+        ans = label.split()
+        try:
+            cnt = int(reform[0])
+        except:
+            return int(reform[-1][:-2]), 'Fatal'
+        if(len(ans) != cnt):
             return int(reform[-1][:-2]), 'Fail'
         else:
             for idx, code in enumerate(ans):
@@ -184,27 +183,25 @@ class Barcode(Session):
 class FaceCascadeDetection(Session):
     def __init__(self, dataset_id, testset_id):
         super().__init__(dataset_id, testset_id)
-        dirs = glob.glob('/opt/mvdata/WIDER_face/WIDER_val/images/*')
-        for dir in dirs:
-            os.system('sdb shell mkdir -p ' + dir)
 
     def load(self, item):
-        os.system('sdb push ' + WIDER_FACE_PREFIX +
-                  item + ' ' + WIDER_FACE_PREFIX + item)
+        file_path = os.path.basename(item)
+        os.system(self.remote_db + item)
+        os.system('sdb push ' + file_path + ' /tmp/' + file_path)
 
     def run(self, item):
-        return subprocess.check_output('sdb shell mv_face_assessment ' + WIDER_FACE_PREFIX + item + ' 0', shell=True)
+        file_path = os.path.basename(item)
+        return subprocess.check_output('sdb shell mv_face_assessment /tmp/' + file_path + ' 0', shell=True)
 
     def clean(self, item):
-        os.system('sdb shell rm ' + WIDER_FACE_PREFIX + item)
+        file_path = os.path.basename(item)
+        os.system('sdb shell rm /tmp/' + file_path)
+        os.system('rm ' + file_path)
 
     def verify(self, result, label):
-        """
-        https://github.com/Cartucho/mAP
-        """
         rv = result.decode('utf-8')
         reform = rv.split()
-        task_time = reform[-1][:-2]
+        task_time = int(reform[-1][:-2])
         return (task_time, rv)
 
 
index 80254d9..e8dc2eb 100644 (file)
@@ -12,11 +12,11 @@ def run_bench(dataset_id, testset_id):
         'sdb devices | tail -1 | cut -f 3', shell=True)
     # check dataset is fit with testset
     # check dataset is ready
-    bench_id = db.insert_benchmark(dataset_id, testset_id, target_type)
     sess = db.get_session(dataset_id, testset_id)
-    for data, label in db.get_all_item(dataset_id):
+    bench_id = db.insert_benchmark(dataset_id, testset_id, target_type)
+    for data, label in sess.get_db_meta():
         sess.load(data)
         result = sess.run(data)
         sess.clean(data)
         ms, save = sess.verify(result, label)
-        db.insert_performance(bench_id, ms, data, label, save)
+        db.insert_performance(bench_id, ms, data, label, save)
diff --git a/script/nas_config.json b/script/nas_config.json
new file mode 100644 (file)
index 0000000..77bf042
--- /dev/null
@@ -0,0 +1,6 @@
+{
+    "user": "",
+    "password": "",
+    "ip": "",
+    "db_index_path": ""
+}
\ No newline at end of file