Add fx2trt pass for removing duplicate output args (#64433)
authorKefei Lu <kefeilu@fb.com>
Thu, 2 Sep 2021 17:38:43 +0000 (10:38 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 17:40:37 +0000 (10:40 -0700)
commit6da7552a8eaae6b85e271bf3edac2fa2ae9f1148
tree7d044e5b757c28a435f0269274d12913497345a4
parentaeafcde087fa76618708b8a2841c450e7e184761
Add fx2trt pass for removing duplicate output args (#64433)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64433

Fx2TRT does not support duplicate nodes in the output args tuple.

This pass removes duplicate output args from the target subnets and fixes their uses in the top level module where the subnets are called. This pass must be called after acc split on the top-level net and subsequent calls to the acc trace on the subnets.

This pass will change both the subnets and top level module.

Test Plan:
Run:

```
buck run mode/opt -c python.package_style=inplace //caffe2/torch/fb/fx2trt/tests/passes/:test_remove_duplicate_output_args

```

Reviewed By: 842974287

Differential Revision: D30468409

fbshipit-source-id: b4d91b76ab5d8a5275d68dd48d1327a44c22568e
torch/fx/experimental/fx2trt/passes/remove_duplicate_output_args.py [new file with mode: 0644]