Add type annotation for `TRTInterpreter.run` (#65135)
authorKefei Lu <kefeilu@fb.com>
Thu, 16 Sep 2021 20:14:12 +0000 (13:14 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 16 Sep 2021 20:16:06 +0000 (13:16 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65135

Opportunistically adding type annotation as I work through fx2trt code base.

Test Plan: run linter and CI

Reviewed By: houseroad, 842974287

Differential Revision: D30903185

fbshipit-source-id: 3f700b57f4433f2d312c1ff2e6b99948e3c8845c

torch/fx/experimental/fx2trt/fx2trt.py

index af6daa8..513ede2 100644 (file)
@@ -7,6 +7,9 @@ import torch.fx
 from torch.fx.node import _get_qualified_name
 
 
+TRTInterpreterResult = Tuple[Any, Sequence[str], Sequence[str]]
+
+
 # Borrowed from torch2trt
 def torch_dtype_to_trt(dtype):
     if trt.__version__ >= "7.0" and dtype == torch.bool:
@@ -333,7 +336,7 @@ class TRTInterpreter(torch.fx.Interpreter):
         fp16_mode=True,
         int8_mode=False,
         strict_type_constraints=True,
-    ):
+    ) -> TRTInterpreterResult:
         # TODO hack, should check contents of args and remove fp16_mode probably
         self.fp16_mode = fp16_mode