[Relay][OP] Support NMSv4 ingestion from TF. (#6085)
authorChris Sullivan <csullivan@octoml.ai>
Mon, 27 Jul 2020 17:38:52 +0000 (10:38 -0700)
committerGitHub <noreply@github.com>
Mon, 27 Jul 2020 17:38:52 +0000 (10:38 -0700)
python/tvm/relay/frontend/tensorflow.py
tests/python/frontend/tensorflow/test_forward.py

index 5f52553..aa62702 100644 (file)
@@ -637,10 +637,11 @@ def _nms():
         iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0]
         # score_threshold was introduced from V3
         score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0
+        pad_output = 'pad_to_max_output_size'
 
         # Generate data with shape (1, num_anchors, 5)
         scores = AttrCvt(op_name="expand_dims",
-                         ignores=['T_threshold'],
+                         ignores=['T_threshold', pad_output],
                          extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr)
         data = get_relay_op('concatenate')([scores, inputs[0]], -1)
         data = get_relay_op('expand_dims')(data, 0, 1)
@@ -667,6 +668,8 @@ def _nms():
                                                       return_indices=True,
                                                       invalid_to_bottom=False)
 
+        if pad_output in attr and attr[pad_output]:
+            return nms_ret
         # squeeze it, TF NMS is not batched
         size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
         data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
@@ -2152,6 +2155,7 @@ _convert_map = {
     'Neg'                               : AttrCvt('negative'),
     'NonMaxSuppressionV2'               : _nms(),
     'NonMaxSuppressionV3'               : _nms(),
+    'NonMaxSuppressionV4'               : _nms(),
     'NoOp'                              : _no_op(),
     'NotEqual'                          : _broadcast('not_equal'),
     'OneHot'                            : _one_hot(),
index 5c6bd6f..62829df 100644 (file)
@@ -2031,12 +2031,31 @@ def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold,
     compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
                         'nms/NonMaxSuppressionV3:0', mode='debug')
 
-def test_forward_nms_v3():
-    """ NonMaxSuppressionV3 """
-    _test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5)
-    _test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10)
-    _test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000)
-    _test_forward_nms_v3((2000, 4), (2000,), 0.4, 0.6, 7)
+def _test_forward_nms_v4(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"):
+    boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype)
+    scores = np.random.uniform(size=score_shape).astype(dtype)
+    max_output_size = np.int32(out_size)
+    tf.reset_default_graph()
+    in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1")
+    in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2")
+    in_data_3 = tf.placeholder(tf.int32, name="in_data_3")
+    indices_padded, num_valid = tf.image.non_max_suppression_padded(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3,
+                                 iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms", pad_to_max_output_size=True)
+    num_valid = tf.reshape(num_valid,shape=(-1,))
+    indices_padded = tf.reshape(indices_padded, shape=(-1,))
+    tf.slice(indices_padded, tf.constant([0]), num_valid, name="SlicedIndices")
+    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
+                        ['nms/NonMaxSuppressionV4:1', "SlicedIndices:0"], mode='vm')
+    compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'],
+                        ['nms/NonMaxSuppressionV4:1',  "SlicedIndices:0"], mode='debug')
+
+def test_forward_nms():
+    """ NonMaxSuppressionV3,4 """
+    for _test_forward_nms in [_test_forward_nms_v3, _test_forward_nms_v4]:
+        _test_forward_nms((5, 4), (5,), 0.7, 0.5, 5)
+        _test_forward_nms((20, 4), (20,), 0.5, 0.6, 10)
+        _test_forward_nms((1000, 4), (1000,), 0.3, 0.7, 1000)
+        _test_forward_nms((2000, 4), (2000,), 0.4, 0.6, 7)
 
 
 #######################################################################
@@ -3867,7 +3886,7 @@ if __name__ == '__main__':
     test_forward_truncatemod()
     test_forward_one_hot()
     test_forward_atan2()
-    test_forward_nms_v3()
+    test_forward_nms()
 
     # Activations
     test_forward_sigmoid()