[fx2trt] add unsqueeze converter (#63355)
authorShiyan Deng <dsy842974287@fb.com>
Mon, 16 Aug 2021 22:16:51 +0000 (15:16 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 22:18:43 +0000 (15:18 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63355

Added a converter for acc_ops.unsqueeze. Needed for the IG model.

Didn't add support for inputs with more than one dynamic dim. This isn't needed right now, and it should be a rare case.
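
For illustration, a minimal standalone sketch of the shape math the converter performs (plain Python, no TensorRT required; `unsqueeze_shape` is a hypothetical helper written for this note, not part of the change):

    def unsqueeze_shape(shape, dim):
        # Inserting a size-1 axis at `dim` mirrors torch.unsqueeze semantics.
        shape = tuple(shape)
        return shape[:dim] + (1,) + shape[dim:]

    assert unsqueeze_shape((2, 3), 1) == (2, 1, 3)
    assert unsqueeze_shape((2, 3), 0) == (1, 2, 3)

The converter builds exactly this target shape and hands it to a TensorRT shuffle layer as reshape_dims.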

Test Plan: unit test

Reviewed By: yinghai

Differential Revision: D30138293

fbshipit-source-id: 899fe8eb68387de83195a2f6e199618d96f09a9e

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
torch/fx/experimental/fx2trt/fx2trt.py

index 6be0e6f..2433421 100644 (file)
@@ -698,6 +698,31 @@ def acc_ops_min_two_tensors_input(network, target, args, kwargs, name):
     )
 
 
+@tensorrt_converter(acc_ops.unsqueeze)
+def acc_ops_unsqueeze(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"unsqueeze received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    dim = kwargs["dim"]
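+    # In implicit batch mode, TensorRT shapes exclude the batch axis, so
+    # shift dim down by one; unsqueezing the batch dim itself isn't supported.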
+    if network.has_implicit_batch_dimension:
+        assert dim != 0, "Can't unsqueeze the implicit batch dimension."
+        dim -= 1
+
+    assert len(get_dynamic_dims(input_val.shape)) <= 1, \
+        "Currently we don't support unsqueeze with more than one dynamic dim."
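+    # unsqueeze maps to a TensorRT shuffle (reshape) layer that inserts a
+    # size-1 axis at position dim in the input shape.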
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
+    layer.name = name
+    return layer.get_output(0)
+
+
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
 def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
     input_val = kwargs["input"]
index 6054586..160b4a7 100644 (file)
@@ -197,7 +197,11 @@ def create_inputs_from_specs(input_specs):
         elif not has_batch_dim:
             shape = (1,) + tuple(shape)
 
-        inputs.append(torch.empty(shape, dtype=dtype, device=device))
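+        # torch.randn fills the tensor with real values (torch.empty returns
+        # uninitialized memory); cast to the requested dtype and device.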
+        inputs.append(
+            torch.randn(shape).to(dtype=dtype, device=device)
+        )
 
     return inputs