From f1ce64a58e5b68c9a8eb274e271242c60ee5afba Mon Sep 17 00:00:00 2001
From: Samuel Salas
Date: Wed, 15 Sep 2021 11:49:28 -0700
Subject: [PATCH] Starter Task 1 (#64927)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64927

Mypy error corrections

Test Plan: Corrected mypy errors to make the code less prone to bugs by adjusting type annotations or adding guards against undesired special cases, e.g. asserting that a variable is not None.

Reviewed By: wushirong

Differential Revision: D30901654

fbshipit-source-id: daae8692603b8b38203a98f673c455749c2fb855
---
 .../fx2trt/converters/acc_ops_converters.py        |  1 -
 .../experimental/fx2trt/example/fx2trt_example.py  |  3 ++-
 .../fx2trt/example/quantized_resnet_test.py        |  4 ++--
 torch/fx/experimental/fx_acc/acc_normalizer.py     | 21 ++++++++++++---------
 torch/fx/experimental/fx_acc/acc_ops.py            | 20 ++++++++++++++++----
 torch/fx/experimental/fx_acc/acc_tracer.py         |  8 ++++----
 6 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
index f555456..47fe26d 100644
--- a/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
+++ b/torch/fx/experimental/fx2trt/converters/acc_ops_converters.py
@@ -1,4 +1,3 @@
-# type: ignore[attr-defined]
 import math
 import operator
 
diff --git a/torch/fx/experimental/fx2trt/example/fx2trt_example.py b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
index 76bf69a..526f02d 100644
--- a/torch/fx/experimental/fx2trt/example/fx2trt_example.py
+++ b/torch/fx/experimental/fx2trt/example/fx2trt_example.py
@@ -232,7 +232,8 @@ if __name__ == "__main__":
     """
 
     # We want to lower _run_on_acc_0 to TensorRT.
-    split_mod._run_on_acc_0 = lower_mod_to_trt(split_mod._run_on_acc_0, (x,))  # type: ignore[arg-type]
+    assert isinstance(split_mod._run_on_acc_0, torch.fx.GraphModule)
+    split_mod._run_on_acc_0 = lower_mod_to_trt(split_mod._run_on_acc_0, (x,))
 
     # Assert results are equal with the original model.
     rn18 = rn18.cuda()
diff --git a/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
index 415ce2d..982cff7 100644
--- a/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
+++ b/torch/fx/experimental/fx2trt/example/quantized_resnet_test.py
@@ -11,7 +11,7 @@ rn18 = models.resnet18().eval()
 
 def build_fp16_trt(rn18):
     rn18 = copy.deepcopy(rn18)
-    rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])  # type: ignore[attr-defined]
+    rn18 = acc_tracer.trace(rn18, [torch.randn(1, 3, 224, 224)])
     interp = TRTInterpreter(
         rn18, [InputTensorSpec(torch.Size([3, 224, 224]), torch.float, has_batch_dim=False)])
     engine, input_names, output_names = interp.run(fp16_mode=True)
@@ -37,7 +37,7 @@ def build_int8_trt(rn18):
     ref_res = quantized_rn18(data)
     print("quantized model:", quantized_rn18)
 
-    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[attr-defined]
+    quantized_rn18 = acc_tracer.trace(quantized_rn18, [data])  # type: ignore[assignment]
     interp = TRTInterpreter(
         quantized_rn18,
         [InputTensorSpec(torch.Size([-1, *data.shape[1:]]), torch.float,
diff --git a/torch/fx/experimental/fx_acc/acc_normalizer.py b/torch/fx/experimental/fx_acc/acc_normalizer.py
index 66e83a5..1c61740 100644
--- a/torch/fx/experimental/fx_acc/acc_normalizer.py
+++ b/torch/fx/experimental/fx_acc/acc_normalizer.py
@@ -1,4 +1,3 @@
-# type: ignore[]
 import inspect
 import re
 from typing import NamedTuple, Optional, Callable, Dict, List, Tuple, Union, Any, Set
@@ -114,7 +113,7 @@ def _insert_fun(
         assert op_and_target not in _normalization_dict.keys()
 
     norm_info = NormalizationInfo(
-        new_fn_target=new_fn_target,
+        new_fn_target=new_fn_target,  # type: ignore[arg-type]
         arg_replacement_tuples=final_arg_replacement_tuples,
         custom_mapping_fn=custom_mapping_fn,
         kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty,
@@ -128,8 +127,8 @@ def _insert_fun(
     # "" in order to allow for whatever mangling index is used.
     if allow_normalize_from_torch_package:
         torch_package_op_and_target = (
-            op_and_target[0],
-            f".{_get_qualified_name(op_and_target[1])}",
+            op_and_target[0],  # type: ignore[]
+            f".{_get_qualified_name(op_and_target[1])}",  # type: ignore[arg-type]
         )
         _normalization_dict[torch_package_op_and_target] = norm_info
@@ -156,7 +155,7 @@ def register_acc_op(acc_op: Callable):
 def register_acc_op_mapping(
     op_and_target: Tuple[str, Union[str, Callable]],
     arg_replacement_tuples: Optional[
-        List[Tuple[Union[str, Tuple[str, ...]], str]]
+        List[Union[Tuple[Union[str, Tuple[str, ...]], str], Tuple[Union[str, Tuple[str, ...]], str, bool]]]
     ] = None,
     kwargs_to_move_to_acc_out_ty: Optional[List[Tuple[str, str]]] = None,
 ):
@@ -175,12 +174,12 @@ def register_acc_op_mapping(
         if arg_replacement_tuples is None:
             final_arg_replacement_tuples = _get_dup_signature_tuples(new_fn_target)
         else:
-            final_arg_replacement_tuples = arg_replacement_tuples
+            final_arg_replacement_tuples = arg_replacement_tuples  # type: ignore[assignment]
 
         _insert_fun(
             op_and_target=op_and_target,
             new_fn_target=new_fn_target,
-            arg_replacement_tuples=final_arg_replacement_tuples,
+            arg_replacement_tuples=final_arg_replacement_tuples,  # type: ignore[arg-type]
             kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty,
         )
         return new_fn_target
@@ -190,7 +189,7 @@ def register_acc_op_mapping(
 
 def register_custom_acc_mapper_fn(
     op_and_target: Tuple[str, Union[str, Callable]],
-    arg_replacement_tuples: List[Tuple[Union[str, Tuple[str, ...]], str]],
+    arg_replacement_tuples: List[Union[Tuple[Union[str, Tuple[str, ...]], str], Tuple[Union[str, Tuple[str, ...]], str, bool]]],
     needs_shapes_for_normalization=False,
     allow_normalize_from_torch_package=False,
 ):
@@ -198,7 +197,7 @@ def register_custom_acc_mapper_fn(
         _insert_fun(
             op_and_target=op_and_target,
             custom_mapping_fn=custom_mapping_fn,
-            arg_replacement_tuples=arg_replacement_tuples,
+            arg_replacement_tuples=arg_replacement_tuples,  # type: ignore[arg-type]
             needs_shapes_for_normalization=needs_shapes_for_normalization,
             allow_normalize_from_torch_package=allow_normalize_from_torch_package,
         )
@@ -216,12 +215,15 @@ def move_kwargs_to_acc_out_ty(
     a node to fetch NormalizationInfo for, check if kwargs_to_move_to_acc_out_ty
     exists in the NormalizationInfo, and if so perform the move of kwargs to
     acc_out_ty.
     """
+
     if isinstance(node_or_normalization_info, torch.fx.Node):
         node = node_or_normalization_info
         normalization_info = _normalization_dict.get((node.op, node.target))
     else:
+        assert isinstance(node_or_normalization_info, NormalizationInfo)
         normalization_info = node_or_normalization_info
 
+    assert normalization_info is not None
     if normalization_info.kwargs_to_move_to_acc_out_ty is None:
         return
@@ -367,6 +369,7 @@ def normalize(mod: torch.fx.GraphModule, expect_nodes_have_shapes: bool = False)
             # Get the normalized kwargs to be used by normalize_to_acc_op below. If
             # normalization_info.arg_replacement_tuples is empty then assume the function
             # signature must be left as is.
+            assert normalization_info.arg_replacement_tuples is not None
             if len(normalization_info.arg_replacement_tuples) == 0:
                 normalized_args = node.args
                 normalized_kwargs = node.kwargs
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index 1aa0d42..3d19a2c 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -1,5 +1,4 @@
 # encoding: utf-8
-# type: ignore[]
 import operator
 
 import torch  # isort:skip
@@ -12,6 +11,8 @@ from torch.fx.experimental.fx_acc.acc_normalizer import (
 )
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 
+from typing import Sequence, Optional, List
+
 this_arg_is_optional = True
 
@@ -119,6 +120,7 @@ def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     # Have to use args here since getattr forces positional args.
     input_obj = node.args[0]
     attr_name = node.args[1]
+    assert isinstance(input_obj, torch.fx.Node)
     assert (
         input_obj.meta["type"] == torch.Tensor
     ), f"Expected torch.Tensor type for {input_obj.meta['type']}"
@@ -189,6 +191,7 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
     with node.graph.inserting_before(node):
         inputs = node.kwargs["tensors"]
         unsqueeze_nodes = []
+        assert isinstance(inputs, Sequence)
         for i, t in enumerate(inputs):
             new_node = node.graph.create_node(
                 "call_function",
@@ -279,6 +282,7 @@ def addmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
         new_input_node = node.graph.create_node(
             "call_function", mul, kwargs=mul_kwargs, name=f"{node.name}_input_mul"
         )
+        assert isinstance(input_node, torch.fx.Node)
         new_input_node.meta = input_node.meta.copy()
         input_node = new_input_node
@@ -863,9 +867,11 @@ def torch_split_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
             new_node.meta = node.meta.copy()
             return new_node
 
+    assert isinstance(split_size_or_sections, Sequence)
     start = 0
     slice_nodes = []
     for i in split_size_or_sections:
+        assert isinstance(i, int)
         new_kwargs = {
             "input": node.kwargs["input"],
             "dims": (node.kwargs["dim"],),
@@ -1022,7 +1028,7 @@ def getitem(*, input, idx):
 
 @register_acc_op
 def slice_tensor(*, input, dims, starts, stops, steps):
-    slices = [None for _ in range(input.dim())]
+    slices: List[Optional[slice]] = [None for _ in range(input.dim())]
 
     # For all provided dims, extract out a slice for starts/stops/steps.
     for idx, dim in enumerate(dims):
@@ -1055,6 +1061,7 @@ def slice_tensor(*, input, dims, starts, stops, steps):
     ],
 )
 def custom_narrow_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
+    assert isinstance(node.kwargs["start"], int) and isinstance(node.kwargs["length"], int)
     kwargs = {
         "input": node.kwargs["input"],
         "dims": (node.kwargs["dim"],),
@@ -1109,8 +1116,9 @@ def custom_tensor_reshape_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.
     input_node = node.kwargs["input"]
     shape = node.kwargs["shape"]
 
-    if isinstance(shape[0], (tuple, list)):
-        shape = shape[0]
+    assert isinstance(shape, Sequence)
+    if isinstance(shape[0], (tuple, list)):  # type: ignore[index]
+        shape = shape[0]  # type: ignore[index]
 
     with node.graph.inserting_before(node):
         new_node = node.graph.call_function(
@@ -1214,6 +1222,7 @@ def packed_quantized_linear_mapper(
     Mapping from quantized_linear module to acc_op.linear. We unpack weight and bias
     in this mapper and pass them directly to linear node.
""" + assert isinstance(node.target, str) linear_module = dict(mod.named_modules())[node.target] prefix = node.target.replace(".", "_") weight_name = f"{prefix}_weight" @@ -1262,6 +1271,7 @@ def packed_quantized_conv2d_mapper( Mapping from quantzed Conv2d module to acc_op.conv. We unpack all the parameters in this mapper and pass them directly to conv2d node. """ + assert isinstance(node.target, str) conv_module = dict(mod.named_modules())[node.target] prefix = node.target.replace(".", "_") weight_name = f"{prefix}_weight" @@ -1371,6 +1381,7 @@ def packed_quantized_convrelu2d_mapper( ] ) def custom_dequantize_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node: + assert isinstance(node.kwargs["input"], torch.fx.Node) assert "tensor_meta" in node.kwargs["input"].meta new_kwargs = {"input": node.kwargs["input"], "input_tensor_meta": node.kwargs["input"].meta["tensor_meta"]} # `input_tensor_meta` contains quantization parameters that can be used to lower @@ -1379,5 +1390,6 @@ def custom_dequantize_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.No new_node = node.graph.create_node( "call_function", dequantize, kwargs=new_kwargs, name=node.name ) + assert isinstance(node, torch.fx.Node) new_node.meta = node.meta return new_node diff --git a/torch/fx/experimental/fx_acc/acc_tracer.py b/torch/fx/experimental/fx_acc/acc_tracer.py index 21fade9..ec8d693 100644 --- a/torch/fx/experimental/fx_acc/acc_tracer.py +++ b/torch/fx/experimental/fx_acc/acc_tracer.py @@ -1,4 +1,3 @@ -# type: ignore[] import ast import builtins import copy @@ -111,7 +110,7 @@ class Acc_Rewriter(ast.NodeTransformer): exc_msg = _reuse_loc(ast.Constant(None)) elif isinstance(node_for_exc, ast.Call): # E.g. `raise AssertionError("error message")` - name_node_of_exc = node_for_exc.func + name_node_of_exc = node_for_exc.func # type: ignore[assignment] if not isinstance(name_node_of_exc, ast.Name): return if_node # Most assertions just take a single string arg, but some may not; skip @@ -251,7 +250,7 @@ def _rewrite(mod_to_rewrite: nn.Module, allow_list: Optional[Set] = None) -> nn. # functions that are attrs of this moodule. Return the new, rewritten module # hierarchy. def rewrite_module(m: nn.Module): - base_class = type(m) + base_class : Type[nn.Module] = type(m) # Keep track of all the ConditionalExceptionWrappers that the # Acc_Rewriter calls into in this module so we can add them in init @@ -259,7 +258,7 @@ def _rewrite(mod_to_rewrite: nn.Module, allow_list: Optional[Set] = None) -> nn. all_added_wrappers: Set[Type[Exception]] = set() # Note: Make this a subclass of our base class. - class RewrittenModule(base_class): + class RewrittenModule(base_class): # type: ignore[valid-type, misc] # Keep track of the base_class so that symbolic tracing can # determine what kind of module this originally was later on. _base_class_origin = base_class @@ -279,6 +278,7 @@ def _rewrite(mod_to_rewrite: nn.Module, allow_list: Optional[Set] = None) -> nn. continue # Only rewrite those Modules explicitly in the allow_list. + assert allow_list is not None if base_class not in allow_list: vars()[method_name] = method else: -- 2.7.4