From: Chris Sullivan Date: Mon, 27 Jul 2020 17:38:52 +0000 (-0700) Subject: [Relay][OP] Support NMSv4 ingestion from TF. (#6085) X-Git-Tag: upstream/0.7.0~351 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=abc52aae75bf12a8839cc509fe2547d1b4629bd0;p=platform%2Fupstream%2Ftvm.git [Relay][OP] Support NMSv4 ingestion from TF. (#6085) --- diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 5f52553..aa62702 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -637,10 +637,11 @@ def _nms(): iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0] # score_threshold was introduced from V3 score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0 + pad_output = 'pad_to_max_output_size' # Generate data with shape (1, num_anchors, 5) scores = AttrCvt(op_name="expand_dims", - ignores=['T_threshold'], + ignores=['T_threshold', pad_output], extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr) data = get_relay_op('concatenate')([scores, inputs[0]], -1) data = get_relay_op('expand_dims')(data, 0, 1) @@ -667,6 +668,8 @@ def _nms(): return_indices=True, invalid_to_bottom=False) + if pad_output in attr and attr[pad_output]: + return nms_ret # squeeze it, TF NMS is not batched size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) @@ -2152,6 +2155,7 @@ _convert_map = { 'Neg' : AttrCvt('negative'), 'NonMaxSuppressionV2' : _nms(), 'NonMaxSuppressionV3' : _nms(), + 'NonMaxSuppressionV4' : _nms(), 'NoOp' : _no_op(), 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 5c6bd6f..62829df 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2031,12 +2031,31 @@ def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], 'nms/NonMaxSuppressionV3:0', mode='debug') -def test_forward_nms_v3(): - """ NonMaxSuppressionV3 """ - _test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5) - _test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10) - _test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000) - _test_forward_nms_v3((2000, 4), (2000,), 0.4, 0.6, 7) +def _test_forward_nms_v4(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"): + boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype) + scores = np.random.uniform(size=score_shape).astype(dtype) + max_output_size = np.int32(out_size) + tf.reset_default_graph() + in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1") + in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2") + in_data_3 = tf.placeholder(tf.int32, name="in_data_3") + indices_padded, num_valid = tf.image.non_max_suppression_padded(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3, + iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms", pad_to_max_output_size=True) + num_valid = tf.reshape(num_valid,shape=(-1,)) + indices_padded = tf.reshape(indices_padded, shape=(-1,)) + tf.slice(indices_padded, tf.constant([0]), num_valid, name="SlicedIndices") + compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], + ['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='vm') + compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], + ['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='debug') + +def test_forward_nms(): + """ NonMaxSuppressionV3,4 """ + for _test_forward_nms in [_test_forward_nms_v3, _test_forward_nms_v4]: + _test_forward_nms((5, 4), (5,), 0.7, 0.5, 5) + _test_forward_nms((20, 4), (20,), 0.5, 0.6, 10) + _test_forward_nms((1000, 4), (1000,), 0.3, 0.7, 1000) + _test_forward_nms((2000, 4), (2000,), 0.4, 0.6, 7) ####################################################################### @@ -3867,7 +3886,7 @@ if __name__ == '__main__': test_forward_truncatemod() test_forward_one_hot() test_forward_atan2() - test_forward_nms_v3() + test_forward_nms() # Activations test_forward_sigmoid()