From 6db8f7a70920f91418078fe09477eed0b0adefdb Mon Sep 17 00:00:00 2001
From: Kefei Lu
Date: Thu, 2 Sep 2021 01:17:56 -0700
Subject: [PATCH] Fix TRTModule not adding outputs in order (#64418)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64418

In T99368564, we found that when running a TRT-lowered module, the output
tensors are out of order compared to the output from the original,
non-lowered module. It turns out that in `TRTModule.forward()`, we cannot
rely on the `ICudaEngine` bindings' natural order indices to create the
output tensors; rather, we should explicitly construct the output tensors
from the bindings' names, in an order that we supply.

Test Plan:
* Arc lint
* Run CI/sandcastle tests
* Run GPU lowering using commands and code changes in D30171741 and ensure
  we don't observe out-of-order outputs

Reviewed By: yinghai

Differential Revision: D30693545

fbshipit-source-id: 32a894ceeb148fcf4e8d279be3835c7d1f1aa2ba
---
 torch/fx/experimental/fx2trt/fx2trt.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py
index f1d17e7..4c0b44c 100644
--- a/torch/fx/experimental/fx2trt/fx2trt.py
+++ b/torch/fx/experimental/fx2trt/fx2trt.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, NamedTuple, Iterable, Any, Optional, Tuple
+from typing import List, NamedTuple, Iterable, Any, Optional, Tuple, Sequence
 
 import tensorrt as trt
 import torch
@@ -53,6 +53,12 @@ class TRTModule(torch.nn.Module):
         # Indicate output is in fp16
         self.fp16_output = fp16_output
 
+        # Indices of outputs into the CUDA engine bindings, in the order as they are
+        # in the fx graph's `output` node.
+        self.output_indices_in_order: Sequence[int] = [
+            self.engine.get_binding_index(name) for name in self.output_names
+        ]
+
     def _on_state_dict(self, state_dict, prefix, local_metadata):
         state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
         state_dict[prefix + "input_names"] = self.input_names
@@ -96,7 +102,7 @@ class TRTModule(torch.nn.Module):
 
         # create output tensors
         outputs: List[torch.Tensor] = []
-        for idx in range(len(inputs), len(inputs) + len(self.output_names)):
+        for idx in self.output_indices_in_order:
             dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
 
             if self.engine.has_implicit_batch_dimension:
-- 
2.7.4
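
A minimal sketch of the idea behind the change, kept outside the patch itself: output binding indices are resolved by name (via TensorRT's `ICudaEngine.get_binding_index`) so that output tensors follow the fx graph's `output` node order rather than the engine's natural binding order. The standalone function form and names below are illustrative only and are not part of `fx2trt.py`.

    import tensorrt as trt  # assumed available, as in fx2trt.py
    import torch

    def ordered_output_indices(engine, output_names):
        # Resolve each output's binding index by its name, preserving the
        # order in which the fx graph's `output` node returns results.
        return [engine.get_binding_index(name) for name in output_names]

    def gather_outputs(engine, binding_buffers, output_names):
        # Illustrative: pick output buffers in graph order instead of
        # assuming outputs occupy the binding slots right after the inputs.
        # `binding_buffers` is a hypothetical list of torch.Tensors, one per
        # engine binding slot.
        return [binding_buffers[idx]
                for idx in ordered_output_indices(engine, output_names)]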