Samples DNN: tf_text_graph_sd.py loads box coder variance and box NMS params from...
authorLorenzo Lucignano <lorenzo.lucignano@tiitoo.com>
Wed, 20 Nov 2019 09:45:57 +0000 (10:45 +0100)
committerLorenzo Lucignano <lorenzo.lucignano@tiitoo.com>
Wed, 20 Nov 2019 09:45:57 +0000 (10:45 +0100)
samples/dnn/tf_text_graph_ssd.py

index e6017b2..905f751 100644 (file)
@@ -283,6 +283,9 @@ def createSSDGraph(modelPath, configPath, outputPath):
 
     # Add layers that generate anchors (bounding boxes proposals).
     priorBoxes = []
+    boxCoder = config['box_coder'][0]
+    fasterRcnnBoxCoder = boxCoder['faster_rcnn_box_coder'][0]
+    boxCoderVariance = [1.0/float(fasterRcnnBoxCoder['x_scale'][0]), 1.0/float(fasterRcnnBoxCoder['y_scale'][0]), 1.0/float(fasterRcnnBoxCoder['width_scale'][0]), 1.0/float(fasterRcnnBoxCoder['height_scale'][0])]
     for i in range(num_layers):
         priorBox = NodeDef()
         priorBox.name = 'PriorBox_%d' % i
@@ -303,7 +306,7 @@ def createSSDGraph(modelPath, configPath, outputPath):
 
         priorBox.addAttr('width', widths)
         priorBox.addAttr('height', heights)
-        priorBox.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
+        priorBox.addAttr('variance', boxCoderVariance)
 
         graph_def.node.extend([priorBox])
         priorBoxes.append(priorBox.name)
@@ -336,11 +339,31 @@ def createSSDGraph(modelPath, configPath, outputPath):
     detectionOut.addAttr('num_classes', num_classes + 1)
     detectionOut.addAttr('share_location', True)
     detectionOut.addAttr('background_label_id', 0)
-    detectionOut.addAttr('nms_threshold', 0.6)
-    detectionOut.addAttr('top_k', 100)
+
+    postProcessing = config['post_processing'][0]
+    batchNMS = postProcessing['batch_non_max_suppression'][0]
+
+    if 'iou_threshold' in batchNMS:
+        detectionOut.addAttr('nms_threshold', float(batchNMS['iou_threshold'][0]))
+    else:
+        detectionOut.addAttr('nms_threshold', 0.6)
+
+    if 'score_threshold' in batchNMS:
+        detectionOut.addAttr('confidence_threshold', float(batchNMS['score_threshold'][0]))
+    else:
+        detectionOut.addAttr('confidence_threshold', 0.01)
+
+    if 'max_detections_per_class' in batchNMS:
+        detectionOut.addAttr('top_k', int(batchNMS['max_detections_per_class'][0]))
+    else:
+        detectionOut.addAttr('top_k', 100)
+
+    if 'max_total_detections' in batchNMS:
+        detectionOut.addAttr('keep_top_k', int(batchNMS['max_total_detections'][0]))
+    else:
+        detectionOut.addAttr('keep_top_k', 100)
+
     detectionOut.addAttr('code_type', "CENTER_SIZE")
-    detectionOut.addAttr('keep_top_k', 100)
-    detectionOut.addAttr('confidence_threshold', 0.01)
 
     graph_def.node.extend([detectionOut])