import warnings
-from typing import List, NamedTuple, Iterable, Any, Optional, Tuple
+from typing import List, NamedTuple, Iterable, Any, Optional, Tuple, Sequence
import tensorrt as trt
import torch
# Indicate output is in fp16
self.fp16_output = fp16_output
+ # Indices of outputs into the CUDA engine bindings, in the order as they are
+ # in the fx graph's `output` node.
+ self.output_indices_in_order: Sequence[int] = [
+ self.engine.get_binding_index(name) for name in self.output_names
+ ]
+
def _on_state_dict(self, state_dict, prefix, local_metadata):
state_dict[prefix + "engine"] = bytearray(self.engine.serialize())
state_dict[prefix + "input_names"] = self.input_names
# create output tensors
outputs: List[torch.Tensor] = []
- for idx in range(len(inputs), len(inputs) + len(self.output_names)):
+ for idx in self.output_indices_in_order:
dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
if self.engine.has_implicit_batch_dimension: