[Relay] Support for PyTorch Non-Maximum Suppression (#6314)
author Yong Wu <ywu118@alumni.jh.edu>
Mon, 24 Aug 2020 13:19:57 +0000 (21:19 +0800)
committer GitHub <noreply@github.com>
Mon, 24 Aug 2020 13:19:57 +0000 (22:19 +0900)
* [Relay] Support for PyTorch Non-Maximum Suppression

* fix comment

* add verify_model_vm

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index 7237403..21cf9c3 100644 (file)
@@ -32,7 +32,7 @@ from .. import op as _op
 from ..ty import TupleType, TensorType, Any
 from ..loops import while_loop
 from .. import transform
-from .common import get_relay_op
+from .common import AttrCvt, get_relay_op
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
@@ -1811,6 +1811,53 @@ def _meshgrid():
         return _op.meshgrid(data, indexing="ij")
     return _impl
 
+
+def _nms(prelude):
+    def _impl(inputs, input_types):
+        boxes = inputs[0]
+        scores = inputs[1]
+        iou_threshold = inputs[2]
+
+        # Generate data with shape (1, num_anchors, 5)
+        scores = AttrCvt(op_name="expand_dims",
+                         extras={'axis': -1, 'num_newaxis': 1})([scores], {})
+
+        # Prepare input data for get_valid_counts
+        data = _op.concatenate([scores, boxes], -1)
+        data = _op.expand_dims(data, 0, 1)
+        # Leverage get_valid_counts to sort the data and clear invalid boxes
+        ct, data, indices = get_relay_op('get_valid_counts')(data,
+                                                             score_threshold=-1.0,
+                                                             id_index=-1,
+                                                             score_index=0)
+
+        # Perform Non-Maximum Suppression,
+        # PyTorch NMS doesn't have parameter top_k and max_output_size
+        score_index = 0
+        top_k = max_out_size = -1
+        nms_ret = get_relay_op('non_max_suppression')(data=data,
+                                                      valid_count=ct,
+                                                      indices=indices,
+                                                      max_output_size=max_out_size,
+                                                      iou_threshold=iou_threshold,
+                                                      force_suppress=True,
+                                                      top_k=top_k,
+                                                      coord_start=1,
+                                                      score_index=score_index,
+                                                      id_index=-1,
+                                                      return_indices=True,
+                                                      invalid_to_bottom=False)
+
+        # squeeze the two outputs of nms for strided_slice
+        size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
+        data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
+
+        # strided slice to get the dynamic result
+        return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
+                                             end=size, slice_mode="size")
+    return _impl
+
+
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
     import torch
@@ -2111,6 +2158,7 @@ def _get_convert_map(prelude):
         "aten::gather"                          : _gather(),
         "aten::index_select"                    : _select(),
         "aten::index"                           : _index(),
+        "torchvision::nms"                      : _nms(prelude),
     }
     return convert_map
 
index ab0a4b0..946712d 100644 (file)
@@ -1428,6 +1428,31 @@ def test_forward_upsample3d():
     verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True).eval(), inp)
 
 
+def test_forward_nms():
+    """dynamic Non-Maximum Suppression"""
+    torch.set_grad_enabled(False)
+    class NonMaxSupression(Module):
+        def __init__(self, iou_thres):
+            super().__init__()
+            self.iou_threshold = iou_thres
+
+        def forward(self, *args):
+            return torchvision.ops.nms(args[0], args[1], self.iou_threshold)
+
+    # Generate random input data
+    def _gen_rand_inputs(num_boxes):
+        box_len = 4
+        boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5
+        boxes[:, 2] += boxes[:, 0]
+        boxes[:, 3] += boxes[:, 1]
+        scores = torch.rand(num_boxes, dtype=torch.float)
+        return boxes, scores
+
+    for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
+        in_boxes, in_scores = _gen_rand_inputs(num_boxes)
+        verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores])
+
+
 def test_conv3d():
     for ishape in [(1, 32, 16, 16, 16),
                    (1, 32, 9, 15, 15),
@@ -1577,32 +1602,43 @@ def test_3d_models():
 
 def verify_script_model(pt_model, ishapes):
     script_module = torch.jit.script(pt_model)
+    verify_model_vm(script_module, ishapes)
 
-    input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
-    input_shapes = list(zip(input_names, ishapes))
 
-    inputs = [torch.randn(shape, dtype=torch.float)
-              for shape in ishapes]
+def verify_trace_model(pt_model, idata):
+    traced_model = torch.jit.trace(pt_model, idata)
+    ishapes = [data.shape for data in idata]
+    verify_model_vm(traced_model, ishapes, idata=idata)
 
-    mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None):
+    input_model = imodel
+    input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
+    input_shapes = list(zip(input_names, ishapes))
+    input_data = idata if idata else [torch.randn(shape, dtype=idtype)
+                                      for shape in ishapes]
+    # Compile via VM
+    mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
 
     executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
                                      target="llvm")
     evaluator = executor.evaluate()
 
-    for name, inp in zip(input_names, inputs):
+    # Inference
+    for name, inp in zip(input_names, input_data):
         params[name] = inp.numpy()
+    vm_res = evaluator(**params)
 
-    op_res = evaluator(**params)
-
+    # Baseline result
     with torch.no_grad():
-        pt_result = pt_model(*inputs)
+        pt_result = input_model(*input_data)
 
+    # Verify the accuracy
     if not isinstance(pt_result, torch.Tensor):
-        tvm_res = op_res.asnumpy().item()
+        tvm_res = vm_res.asnumpy().item()
         assert pt_result == tvm_res
     else:
-        tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
+        tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
                                     rtol=1e-5, atol=1e-5)
 
 
@@ -2863,6 +2899,7 @@ if __name__ == "__main__":
     test_forward_gather()
     test_upsample()
     test_forward_upsample3d()
+    test_forward_nms()
     test_to()
     test_type_as()
     test_forward_functional_pad()