Fix TRTModule not adding outputs in order (#64418)
authorKefei Lu <kefeilu@fb.com>
Thu, 2 Sep 2021 08:17:56 +0000 (01:17 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 08:36:23 +0000 (01:36 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64418

In T99368564, we found that when running TRT lowered module, the output tensors are out-of-order, as compared to the output from the original, non-lowered module. It turns out that in `TRTModule.forward()`, we cannot rely on `ICudaEngine` bindings natural order indices to create the output tensors, but rather, we should explicitly construct the output tensor from the bindings' names, in an ordered that we supply.

Test Plan:
* Arc lint
* Run CI/sandcastle tests
* Run GPU lowering using commands and code changes in D30171741 and ensure we don't observe out-of-order outputs

Reviewed By: yinghai

Differential Revision: D30693545

fbshipit-source-id: 32a894ceeb148fcf4e8d279be3835c7d1f1aa2ba

torch/fx/experimental/fx2trt/fx2trt.py

index f1d17e7..4c0b44c 100644 (file)
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, NamedTuple, Iterable, Any, Optional, Tuple
+from typing import List, NamedTuple, Iterable, Any, Optional, Tuple, Sequence
 
 import tensorrt as trt
 import torch
@@ -53,6 +53,12 @@ class TRTModule(torch.nn.Module):
         # Indicate output is in fp16
         self.fp16_output = fp16_output
 
+        # Indices of outputs into the CUDA engine bindings, in the order as they are
+        # in the fx graph's `output` node.
+        self.output_indices_in_order: Sequence[int] = [
+            self.engine.get_binding_index(name) for name in self.output_names
+        ]
+
     def _on_state_dict(self, state_dict, prefix, local_metadata):
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
@@ -96,7 +102,7 @@ class TRTModule(torch.nn.Module):
 
         # create output tensors
         outputs: List[torch.Tensor] = []
-        for idx in range(len(inputs), len(inputs) + len(self.output_names)):
+        for idx in self.output_indices_in_order:
             dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
 
             if self.engine.has_implicit_batch_dimension: