From: Kefei Lu
Date: Thu, 16 Sep 2021 20:14:12 +0000 (-0700)
Subject: Add type annotation for `TRTInterpreter.run` (#65135)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~144
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f9c341fdf2a8b90017ad696b89cebbacf0c8fb95;p=platform%2Fupstream%2Fpytorch.git

Add type annotation for `TRTInterpreter.run` (#65135)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65135

Opportunistically adding a type annotation as I work through the fx2trt code base.

Test Plan: run linter and CI

Reviewed By: houseroad, 842974287

Differential Revision: D30903185

fbshipit-source-id: 3f700b57f4433f2d312c1ff2e6b99948e3c8845c
---

diff --git a/torch/fx/experimental/fx2trt/fx2trt.py b/torch/fx/experimental/fx2trt/fx2trt.py
index af6daa8..513ede2 100644
--- a/torch/fx/experimental/fx2trt/fx2trt.py
+++ b/torch/fx/experimental/fx2trt/fx2trt.py
@@ -7,6 +7,9 @@ import torch.fx
 from torch.fx.node import _get_qualified_name
 
+TRTInterpreterResult = Tuple[Any, Sequence[str], Sequence[str]]
+
+
 # Borrowed from torch2trt
 def torch_dtype_to_trt(dtype):
     if trt.__version__ >= "7.0" and dtype == torch.bool:
@@ -333,7 +336,7 @@ class TRTInterpreter(torch.fx.Interpreter):
         fp16_mode=True,
         int8_mode=False,
         strict_type_constraints=True,
-    ):
+    ) -> TRTInterpreterResult:
         # TODO hack, should check contents of args and remove fp16_mode probably
         self.fp16_mode = fp16_mode
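
Note (not part of the commit): a minimal sketch of how a caller might consume the newly annotated return value. It assumes, based only on the element types of the `TRTInterpreterResult` alias, that the three tuple items are the built engine followed by the input and output tensor names; the helper `consume_trt_result` and the commented call site are hypothetical, not taken from the fx2trt API.

    # Sketch under the assumptions stated above; not part of the diff.
    from typing import Any, Sequence, Tuple

    # Mirrors the alias added by the commit.
    TRTInterpreterResult = Tuple[Any, Sequence[str], Sequence[str]]

    def consume_trt_result(result: TRTInterpreterResult) -> None:
        # Unpack the annotated return value of TRTInterpreter.run().
        # Assumed ordering: (engine, input_names, output_names).
        engine, input_names, output_names = result
        print(f"engine type: {type(engine).__name__}")
        print(f"inputs:  {list(input_names)}")
        print(f"outputs: {list(output_names)}")

    # Hypothetical call site (interpreter construction omitted):
    #   result = interpreter.run(fp16_mode=True)
    #   consume_trt_result(result)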