end.out_port(0).connect(shape_part_for_tiling.in_port(2))
stride.out_port(0).connect(shape_part_for_tiling.in_port(3))
- concat_value = Const(graph, {'value': np.array([4])}).create_node()
- shape_concat = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
- 'axis': np.array(0)}).create_node()
- shape_part_for_tiling.out_port(0).connect(shape_concat.in_port(0))
- concat_value.out_port(0).connect(shape_concat.in_port(1))
+ shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
+ {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
+ 'axis': int64_array(0)},
+ shape_part_for_tiling)
variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
applied_width_height_regressions_node)
detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
+ # get nms from the original network
+ iou_threshold = None
+ nms_nodes = graph.get_op_nodes(op='NonMaxSuppression')
+ if len(nms_nodes) > 0:
+ # it is highly unlikely that for different classes NMS has different
+ # moreover DetectionOutput accepts only scalar values for iou_threshold (nms_threshold)
+ iou_threshold = nms_nodes[0].in_node(3).value
+ if iou_threshold is None:
+ raise Error('During {} `iou_threshold` was not retrieved from RetinaNet graph'.format(self.replacement_id))
+
detection_output_node = detection_output_op.create_node(
[reshape_regression_node, reshape_classes_node, priors],
- dict(name=detection_output_op.attrs['type'], clip_after_nms=1, normalized=1, variance_encoded_in_target=0,
- background_label_id=1000))
+ dict(name=detection_output_op.attrs['type'], nms_threshold=iou_threshold, clip_after_nms=1, normalized=1,
+ variance_encoded_in_target=0, background_label_id=1000))
return {'detection_output_node': detection_output_node}
from mo.front.extractor import FrontExtractorOp
+class NonMaxSuppressionV2Extractor(FrontExtractorOp):
+ op = 'NonMaxSuppressionV2'
+ enabled = True
+
+ @classmethod
+ def extract(cls, node):
+ attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32}
+ NonMaxSuppression.update_node_stat(node, attrs)
+ return cls.enabled
+
+
class NonMaxSuppressionV3Extractor(FrontExtractorOp):
op = 'NonMaxSuppressionV3'
enabled = True
@classmethod
def extract(cls, node):
- attrs = {'sort_result_descending': 1, 'center_point_box': 0, 'output_type': np.int32}
+ attrs = {'sort_result_descending': 1, 'box_encoding': 'corner', 'output_type': np.int32}
NonMaxSuppression.update_node_stat(node, attrs)
return cls.enabled