From: Pavel Esir Date: Thu, 12 Nov 2020 12:04:07 +0000 (+0300) Subject: add getting nms_threshold / iou_threshold from RetinaNet (#3075) X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8c89d8d7331ed4ebbc0ad69578fd6e31f0b668d1;p=platform%2Fupstream%2Fdldt.git add getting nms_threshold / iou_threshold from RetinaNet (#3075) * added getting nms_threshold/iou_threshold from original TF RetinaNet model * iou_threshold definition added * fixed getting iou_threshold for TF NMS V2, some minor corrections * added box_encoding to NMS extractors --- diff --git a/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py b/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py index 3914e4c..c214393 100644 --- a/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py +++ b/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py @@ -79,11 +79,10 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr end.out_port(0).connect(shape_part_for_tiling.in_port(2)) stride.out_port(0).connect(shape_part_for_tiling.in_port(3)) - concat_value = Const(graph, {'value': np.array([4])}).create_node() - shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2, - 'axis': np.array(0)}).create_node() - shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0)) - concat_value.out_port(0).connect(shape_concat.in_port(1)) + shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]), + {'name': name + '/shape_for_tiling', 'in_ports_count': 2, + 'axis': int64_array(0)}, + shape_part_for_tiling) variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node() tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node() @@ -246,9 +245,19 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr applied_width_height_regressions_node) detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes) + # get nms from the original network + iou_threshold = None + nms_nodes = graph.get_op_nodes(op='NonMaxSuppression') + if len(nms_nodes) > 0: + # it is highly unlikely that for different classes NMS has different + # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold) + iou_threshold = nms_nodes[0].in_node(3).value + if iou_threshold is None: + raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id)) + detection_output_node = detection_output_op.create_node( [reshape_regression_node, reshape_classes_node, priors], - dict(name=detection_output_op.attrs['type'], clip_after_nms=1, normalized=1, variance_encoded_in_target=0, - background_label_id=1000)) + dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1, + variance_encoded_in_target=0, background_label_id=1000)) return {'detection_output_node': detection_output_node} diff --git a/model-optimizer/extensions/front/tf/non_max_suppression_ext.py b/model-optimizer/extensions/front/tf/non_max_suppression_ext.py index 2e3930f..f56b5ba 100644 --- a/model-optimizer/extensions/front/tf/non_max_suppression_ext.py +++ b/model-optimizer/extensions/front/tf/non_max_suppression_ext.py @@ -21,13 +21,24 @@ from extensions.ops.non_max_suppression import NonMaxSuppression from mo.front.extractor import FrontExtractorOp +class NonMaxSuppressionV2Extractor(FrontExtractorOp): + op = 'NonMaxSuppressionV2' + enabled = True + + @classmethod + def extract(cls, node): + attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32} + NonMaxSuppression.update_node_stat(node, attrs) + return cls.enabled + + class NonMaxSuppressionV3Extractor(FrontExtractorOp): op = 'NonMaxSuppressionV3' enabled = True @classmethod def extract(cls, node): - attrs = {'sort_result_descending': 1, 'center_point_box': 0, 'output_type': np.int32} + attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32} NonMaxSuppression.update_node_stat(node, attrs) return cls.enabled diff --git a/model-optimizer/extensions/front/tf/retinanet.json b/model-optimizer/extensions/front/tf/retinanet.json index 4687e6a..0254fe9 100644 --- a/model-optimizer/extensions/front/tf/retinanet.json +++ b/model-optimizer/extensions/front/tf/retinanet.json @@ -7,7 +7,6 @@ "confidence_threshold": 0.05, "top_k": 6000, "keep_top_k": 300, - "nms_threshold": 0.5, "variance": [0.2, 0.2, 0.2, 0.2] }, "include_inputs_to_sub_graph": true, diff --git a/model-optimizer/extensions/ops/non_max_suppression.py b/model-optimizer/extensions/ops/non_max_suppression.py index b73b059..6777bcf 100644 --- a/model-optimizer/extensions/ops/non_max_suppression.py +++ b/model-optimizer/extensions/ops/non_max_suppression.py @@ -34,7 +34,6 @@ class NonMaxSuppression(Op): 'version': 'opset5', 'infer': self.infer, 'output_type': np.int64, - 'center_point_box': 0, 'box_encoding': 'corner', 'in_ports_count': 5, 'sort_result_descending': 1,