From 49c8fbc92f70d6d78e02e2b7944de59d9348db37 Mon Sep 17 00:00:00 2001 From: nikithamalgi Date: Wed, 25 Aug 2021 21:47:50 -0700 Subject: [PATCH] Clean up related to type refinements (#62444) Summary: Creates a helper function to refine the types into a torchScript compatible format in the monkeytype config for profile directed typing Pull Request resolved: https://github.com/pytorch/pytorch/pull/62444 Reviewed By: malfet Differential Revision: D30548159 Pulled By: nikithamalgifb fbshipit-source-id: 7c09ce5f5e043d069313b87112837d7e226ade1f --- test/jit/test_pdt.py | 41 ----------------------- torch/jit/_monkeytype_config.py | 74 +++++++++++++++++++---------------------- torch/jit/frontend.py | 4 +-- 3 files changed, 37 insertions(+), 82 deletions(-) diff --git a/test/jit/test_pdt.py b/test/jit/test_pdt.py index b04a66e..57cd74f 100644 --- a/test/jit/test_pdt.py +++ b/test/jit/test_pdt.py @@ -454,44 +454,3 @@ class TestPDT(JitTestCase): scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )]) self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), )) - - class TestForwardWithNoneType(torch.nn.Module): - def forward(self, a): - count = 0 - for i, val in enumerate(a): - if val is None: - count += 1 - return count - - make_global(TestForwardWithNoneType) - pdt_model = TestForwardWithNoneType() - - # Test List[Optional[float]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2.9, ], )]) - self.assertEqual(scripted_model([2.8, 6.7, 3.8, None, ]), pdt_model([2.8, 6.7, 3.8, None, ])) - - # Test Tuple[Optional[int]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((5.1, ), ), ((None, ), ), ]) - self.assertEqual(scripted_model((6.2, None, 10.6, 80.1, None, )), pdt_model((6.2, None, 10.6, 80.1, None, ))) - - # Test List[Optional[int]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2, ], )]) - self.assertEqual(scripted_model([2, None, 6, 8, ]), pdt_model([2, None, 6, 8, ])) - - # Test Tuple[Optional[int]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )]) - self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, ))) - - # Test Tuple[Optional[float]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )]) - self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, ))) - - # Test Tuple[Optional[torch.Tensor]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[(((torch.ones(1), ), (None, ), ), )]) - self.assertEqual(scripted_model((torch.ones(1), torch.ones(1), None)), - pdt_model((torch.ones(1), torch.ones(1), None))) - - # Test List[Optional[torch.Tensor]] as input - scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([torch.ones(1), ], )]) - self.assertEqual(scripted_model([torch.ones(1), torch.ones(1), None]), - pdt_model([torch.ones(1), torch.ones(1), None])) diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index b5a698e..f0e4613 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -1,7 +1,6 @@ import inspect import typing import pathlib -import torch from typing import Optional, Iterable, List, Dict from collections import defaultdict from types import CodeType @@ -16,25 +15,38 @@ try: except ImportError: _IS_MONKEYTYPE_INSTALLED = False -def get_optional_of_element_type(types: str): +def get_type(type): + """ + Helper function which converts the given type to a torchScript acceptable format. + """ + if isinstance(type, str): + return type + elif inspect.getmodule(type) == typing: + # If the type is a type imported from typing + # like Tuple, List, Dict then replace `typing.` + # with a null string. This needs to be done since + # typing.List is not accepted by TorchScript. + type_to_string = str(type) + return type_to_string.replace(type.__module__ + '.', '') + elif type.__module__.startswith('torch'): + # If the type is a subtype of torch module, then TorchScript expects a fully qualified name + # for the type which is obtained by combining the module name and type name. + return type.__module__ + '.' + type.__name__ + else: + # For all other types use the name for the type. + return type.__name__ + +def get_optional_of_element_type(types): """ Helper function to extracts the type of the element to be annotated to Optional from the list of consolidated types and returns `Optional[element type]`. - TODO: To remove this check once Union support lands. """ - elements = types.split(",") - elem_type = elements[0] if 'NoneType' in elements[1] else elements[1] - - # If the type is from typing module, then extract the element type - start = elem_type.find("[") - end = elem_type.rfind("]") - if start != -1 and end != -1: - return elem_type[:start + 1] + 'Optional[' + elem_type[start + 1: end] + ']]' - - # Else return Optional[element type] - if elem_type == 'Tensor': - elem_type = 'torch.Tensor' + elem_type = types[1] if type(None) == types[0] else types[0] + elem_type = get_type(elem_type) + + # Optional type is internally converted to Union[type, NoneType], which + # is not supported yet in TorchScript. Hence, representing the optional type as string. return 'Optional[' + elem_type + ']' def get_qualified_name(func): @@ -88,30 +100,15 @@ if _IS_MONKEYTYPE_INSTALLED: # then consolidate the type to `Any` and replace the entry # by type `Any`. for arg, types in all_args.items(): - _all_type = " " - for _type in types: - # If the type is a type imported from typing - # like Tuple, List, Dict then replace "typing." - # with a null string. - if inspect.getmodule(_type) == typing: - _type_to_string = str(_type) - _all_type += _type_to_string.replace('typing.', '') + ',' - elif _type is torch.nn.parameter.Parameter: - # Check if the type is torch.nn.parameter.Parameter, - # use the entire quaalified name `torch.nn.parameter.Parameter` - # for type - _all_type += 'torch.nn.parameter.Parameter' + ',' - else: - _all_type += _type.__name__ + ',' - _all_type = _all_type.lstrip(" ") # Remove any trailing spaces - - if len(types) == 2 and 'NoneType' in _all_type: + types = list(types) + type_length = len(types) + if type_length == 2 and type(None) in types: # TODO: To remove this check once Union suppport in TorchScript lands. - all_args[arg] = {get_optional_of_element_type(_all_type)} - elif len(types) > 1: - all_args[arg] = {'Any'} - else: - all_args[arg] = {_all_type[:-1]} + all_args[arg] = get_optional_of_element_type(types) + elif type_length > 1: + all_args[arg] = 'Any' + elif type_length == 1: + all_args[arg] = get_type(types[0]) return all_args def get_args_types(self, qualified_name: str) -> Dict: @@ -157,7 +154,6 @@ def jit_code_filter(code: CodeType) -> bool: The custom CodeFilter is required while scripting a FX Traced forward calls. FX Traced forward calls have `code.co_filename` start with '<' which is used to exclude tracing of stdlib and site-packages in the default code filter. - Since we need all forward calls to be traced, this custom code filter checks for code.co_name to be 'forward' and enables tracing for all such calls. The code filter is similar to default code filter for monkeytype and diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index b0228b1..0928106 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -337,9 +337,9 @@ def build_param_list(ctx, py_args, self_name, pdt_arg_types=None): raise NotSupportedError(ctx_range, _vararg_kwarg_err) # List of Tuple of args and type as inferred by profile directed typing - arg_and_types = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None) + arg_and_types = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None) for arg in py_args.args] - arg_and_types_kwonlyargs = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) + arg_and_types_kwonlyargs = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None) for arg in py_args.kwonlyargs] result = [build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type) -- 2.7.4