add getting nms_threshold / iou_threshold from RetinaNet (#3075)
authorPavel Esir <pavel.esir@intel.com>
Thu, 12 Nov 2020 12:04:07 +0000 (15:04 +0300)
committerGitHub <noreply@github.com>
Thu, 12 Nov 2020 12:04:07 +0000 (15:04 +0300)
* added getting nms_threshold/iou_threshold from original TF RetinaNet model

* iou_threshold definition added

* fixed getting iou_threshold for TF NMS V2, some minor corrections

* added box_encoding to NMS extractors

model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py
model-optimizer/extensions/front/tf/non_max_suppression_ext.py
model-optimizer/extensions/front/tf/retinanet.json
model-optimizer/extensions/ops/non_max_suppression.py

index 3914e4c..c214393 100644 (file)
@@ -79,11 +79,10 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
         end.out_port(0).connect(shape_part_for_tiling.in_port(2))
         stride.out_port(0).connect(shape_part_for_tiling.in_port(3))
 
-        concat_value = Const(graph, {'value': np.array([4])}).create_node()
-        shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
-                                      'axis': np.array(0)}).create_node()
-        shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
-        concat_value.out_port(0).connect(shape_concat.in_port(1))
+        shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
+                                                        {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
+                                                         'axis': int64_array(0)},
+                                                        shape_part_for_tiling)
 
         variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
         tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
@@ -246,9 +245,19 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
                                                                    applied_width_height_regressions_node)
 
         detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
+        # get nms from the original network
+        iou_threshold = None
+        nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
+        if len(nms_nodes) > 0:
+            # it is highly unlikely that for different classes NMS has different
+            # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
+            iou_threshold = nms_nodes[0].in_node(3).value
+        if iou_threshold is None:
+            raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))
+
         detection_output_node = detection_output_op.create_node(
             [reshape_regression_node, reshape_classes_node, priors],
-            dict(name=detection_output_op.attrs['type'], clip_after_nms=1, normalized=1, variance_encoded_in_target=0,
-                 background_label_id=1000))
+            dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
+                 variance_encoded_in_target=0, background_label_id=1000))
 
         return {'detection_output_node': detection_output_node}
index 2e3930f..f56b5ba 100644 (file)
@@ -21,13 +21,24 @@ from extensions.ops.non_max_suppression import NonMaxSuppression
 from mo.front.extractor import FrontExtractorOp
 
 
+class NonMaxSuppressionV2Extractor(FrontExtractorOp):
+    op = 'NonMaxSuppressionV2'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32}
+        NonMaxSuppression.update_node_stat(node, attrs)
+        return cls.enabled
+
+
 class NonMaxSuppressionV3Extractor(FrontExtractorOp):
     op = 'NonMaxSuppressionV3'
     enabled = True
 
     @classmethod
     def extract(cls, node):
-        attrs = {'sort_result_descending': 1, 'center_point_box': 0, 'output_type': np.int32}
+        attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32}
         NonMaxSuppression.update_node_stat(node, attrs)
         return cls.enabled
 
index 4687e6a..0254fe9 100644 (file)
@@ -7,7 +7,6 @@
             "confidence_threshold": 0.05,
             "top_k": 6000,
             "keep_top_k": 300,
-            "nms_threshold": 0.5,
             "variance": [0.2, 0.2, 0.2, 0.2]
         },
         "include_inputs_to_sub_graph": true,
index b73b059..6777bcf 100644 (file)
@@ -34,7 +34,6 @@ class NonMaxSuppression(Op):
             'version': 'opset5',
             'infer': self.infer,
             'output_type': np.int64,
-            'center_point_box': 0,
             'box_encoding': 'corner',
             'in_ports_count': 5,
             'sort_result_descending': 1,