Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65135
Opportunistically adding type annotation as I work through fx2trt code base.
Test Plan: run linter and CI
Reviewed By: houseroad,
842974287
Differential Revision:
D30903185
fbshipit-source-id:
3f700b57f4433f2d312c1ff2e6b99948e3c8845c
from torch.fx.node import _get_qualified_name
+TRTInterpreterResult = Tuple[Any, Sequence[str], Sequence[str]]
+
+
# Borrowed from torch2trt
def torch_dtype_to_trt(dtype):
if trt.__version__ >= "7.0" and dtype == torch.bool:
fp16_mode=True,
int8_mode=False,
strict_type_constraints=True,
- ):
+ ) -> TRTInterpreterResult:
# TODO hack, should check contents of args and remove fp16_mode probably
self.fp16_mode = fp16_mode