From f4f2c1231a4a429182d12b58e16b4a42fcae68a5 Mon Sep 17 00:00:00 2001
From: Shiyan Deng
Date: Mon, 16 Aug 2021 15:16:51 -0700
Subject: [PATCH] [fx2trt] add unsqueeze converter (#63355)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63355

Added a converter for acc_ops.unsqueeze. Needed for the ig model. Didn't add support for inputs with more than one dynamic dim. This is not needed right now and should be a rare case.

Test Plan: unit test

Reviewed By: yinghai

Differential Revision: D30138293

fbshipit-source-id: 899fe8eb68387de83195a2f6e199618d96f09a9e
---
 .../fx2trt/converters/acc_ops_converters.py | 20 ++++++++++++++++++++
 torch/fx/experimental/fx2trt/fx2trt.py      |  4 +++-
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index 6be0e6f..2433421 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -698,6 +698,26 @@ def acc_ops_min_two_tensors_input(network, target, args, kwargs, name):
     )
 
 
+@tensorrt_converter(acc_ops.unsqueeze)
+def acc_ops_unsqueeze(network, target, args, kwargs, name):
+    input_val = kwargs["input"]
+
+    if not isinstance(input_val, trt.tensorrt.ITensor):
+        raise RuntimeError(f"unsqueeze received input {input_val} that is not part "
+                           "of the TensorRT region!")
+
+    dim = kwargs["dim"]
+    if network.has_implicit_batch_dimension:
+        assert dim != 0
+        dim -= 1
+
+    assert len(get_dynamic_dims(input_val.shape)) <= 1, "Currently we don't support unsqueeze with more than one dynamic dims."
+    layer = network.add_shuffle(input_val)
+    layer.reshape_dims = tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:]
+    layer.name = name
+    return layer.get_output(0)
+
+
 @tensorrt_converter(acc_ops.adaptive_avg_pool2d)
 def acc_ops_adaptive_avg_pool2d(network, target, args, kwargs, name):
     input_val = kwargs["input"]
diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py
index 6054586..160b4a7 100644
--- a/torch/fx/experimental/fx2trt/fx2trt.py
+++ b/torch/fx/experimental/fx2trt/fx2trt.py
@@ -197,7 +197,9 @@ def create_inputs_from_specs(input_specs):
         elif not has_batch_dim:
             shape = (1,) + tuple(shape)
 
-        inputs.append(torch.empty(shape, dtype=dtype, device=device))
+        inputs.append(
+            torch.randn(shape).to(dtype=dtype, device=device)
+        )
 
     return inputs
 
-- 
2.7.4
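
Editor's note (not part of the patch): below is a minimal sketch of the shape arithmetic the new converter performs. It inserts a 1 at position dim of the input shape, which is what layer.reshape_dims encodes in acc_ops_unsqueeze. The helper name unsqueeze_reshape_dims is hypothetical, and the check uses plain PyTorch only, so it runs without TensorRT installed.

# Illustrative sketch only; the helper name is hypothetical. It mirrors the
# reshape_dims computation in the converter and checks it against
# torch.unsqueeze for a concrete static shape.
import torch

def unsqueeze_reshape_dims(shape, dim):
    # Insert a 1 at position `dim`, as layer.reshape_dims does in the converter.
    return tuple(shape)[:dim] + (1,) + tuple(shape)[dim:]

x = torch.randn(2, 3, 4)
dim = 1
reshaped = x.reshape(unsqueeze_reshape_dims(x.shape, dim))
assert reshaped.shape == torch.unsqueeze(x, dim).shape  # torch.Size([2, 1, 3, 4])

When the network has an implicit batch dimension, the converter first rejects dim == 0 and shifts dim down by one, since the batch axis is not visible in input_val.shape; the same shape arithmetic then applies to the remaining axes.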