from ..ty import TupleType, TensorType, Any
from ..loops import while_loop
from .. import transform
-from .common import get_relay_op
+from .common import AttrCvt, get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
return _op.meshgrid(data, indexing="ij")
return _impl
+
+def _nms(prelude):
+ def _impl(inputs, input_types):
+ boxes = inputs[0]
+ scores = inputs[1]
+ iou_threshold = inputs[2]
+
+ # Generate data with shape (1, num_anchors, 5)
+ scores = AttrCvt(op_name="expand_dims",
+ extras={'axis': -1, 'num_newaxis': 1})([scores], {})
+
+ # Prepare input data for get_valid_counts
+ data = _op.concatenate([scores, boxes], -1)
+ data = _op.expand_dims(data, 0, 1)
+ # Leverage get_valid_counts to sort the data and clear invalid boxes
+ ct, data, indices = get_relay_op('get_valid_counts')(data,
+ score_threshold=-1.0,
+ id_index=-1,
+ score_index=0)
+
+ # Perform Non-Maximum Suppression,
+ # PyTorch NMS doesn't have parameter top_k and max_output_size
+ score_index = 0
+ top_k = max_out_size = -1
+ nms_ret = get_relay_op('non_max_suppression')(data=data,
+ valid_count=ct,
+ indices=indices,
+ max_output_size=max_out_size,
+ iou_threshold=iou_threshold,
+ force_suppress=True,
+ top_k=top_k,
+ coord_start=1,
+ score_index=score_index,
+ id_index=-1,
+ return_indices=True,
+ invalid_to_bottom=False)
+
+ # squeeze the two outputs of nms for strided_slice
+ size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
+ data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])
+
+ # strided slice to get the dynamic result
+ return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
+ end=size, slice_mode="size")
+ return _impl
+
+
def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
import torch
"aten::gather" : _gather(),
"aten::index_select" : _select(),
"aten::index" : _index(),
+ "torchvision::nms" : _nms(prelude),
}
return convert_map
verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True).eval(), inp)
+def test_forward_nms():
+ """dynamic Non-Maximum Suppression"""
+ torch.set_grad_enabled(False)
+ class NonMaxSupression(Module):
+ def __init__(self, iou_thres):
+ super().__init__()
+ self.iou_threshold = iou_thres
+
+ def forward(self, *args):
+ return torchvision.ops.nms(args[0], args[1], self.iou_threshold)
+
+ # Generate random input data
+ def _gen_rand_inputs(num_boxes):
+ box_len = 4
+ boxes = torch.rand(num_boxes, box_len, dtype=torch.float) * 0.5
+ boxes[:, 2] += boxes[:, 0]
+ boxes[:, 3] += boxes[:, 1]
+ scores = torch.rand(num_boxes, dtype=torch.float)
+ return boxes, scores
+
+ for num_boxes, iou_thres in [(10, 0.3), (100, 0.5), (500, 0.9)]:
+ in_boxes, in_scores = _gen_rand_inputs(num_boxes)
+ verify_trace_model(NonMaxSupression(iou_thres), [in_boxes, in_scores])
+
+
def test_conv3d():
for ishape in [(1, 32, 16, 16, 16),
(1, 32, 9, 15, 15),
def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model)
+ verify_model_vm(script_module, ishapes)
- input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
- input_shapes = list(zip(input_names, ishapes))
- inputs = [torch.randn(shape, dtype=torch.float)
- for shape in ishapes]
+def verify_trace_model(pt_model, idata):
+ traced_model = torch.jit.trace(pt_model, idata)
+ ishapes = [data.shape for data in idata]
+ verify_model_vm(traced_model, ishapes, idata=idata)
- mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+def verify_model_vm(imodel, ishapes, idtype=torch.float, idata=None):
+ input_model = imodel
+ input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
+ input_shapes = list(zip(input_names, ishapes))
+ input_data = idata if idata else [torch.randn(shape, dtype=idtype)
+ for shape in ishapes]
+ # Compile via VM
+ mod, params = relay.frontend.from_pytorch(input_model, input_shapes)
executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
target="llvm")
evaluator = executor.evaluate()
- for name, inp in zip(input_names, inputs):
+ # Inference
+ for name, inp in zip(input_names, input_data):
params[name] = inp.numpy()
+ vm_res = evaluator(**params)
- op_res = evaluator(**params)
-
+ # Baseline result
with torch.no_grad():
- pt_result = pt_model(*inputs)
+ pt_result = input_model(*input_data)
+ # Verify the accuracy
if not isinstance(pt_result, torch.Tensor):
- tvm_res = op_res.asnumpy().item()
+ tvm_res = vm_res.asnumpy().item()
assert pt_result == tvm_res
else:
- tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
+ tvm.testing.assert_allclose(vm_res.asnumpy(), pt_result.numpy(),
rtol=1e-5, atol=1e-5)
test_forward_gather()
test_upsample()
test_forward_upsample3d()
+ test_forward_nms()
test_to()
test_type_as()
test_forward_functional_pad()