Clean up related to type refinements (#62444)
authornikithamalgi <nikithamalgi@devvm146.prn0.facebook.com>
Thu, 26 Aug 2021 04:47:50 +0000 (21:47 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 04:53:00 +0000 (21:53 -0700)
Summary:
Creates a helper function to refine the types into a torchScript compatible format in the monkeytype config for profile directed typing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62444

Reviewed By: malfet

Differential Revision: D30548159

Pulled By: nikithamalgifb

fbshipit-source-id: 7c09ce5f5e043d069313b87112837d7e226ade1f

test/jit/test_pdt.py
torch/jit/_monkeytype_config.py
torch/jit/frontend.py

index b04a66e..57cd74f 100644 (file)
@@ -454,44 +454,3 @@ class TestPDT(JitTestCase):
 
         scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
         self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
-
-        class TestForwardWithNoneType(torch.nn.Module):
-            def forward(self, a):
-                count = 0
-                for i, val in enumerate(a):
-                    if val is None:
-                        count += 1
-                return count
-
-        make_global(TestForwardWithNoneType)
-        pdt_model = TestForwardWithNoneType()
-
-        # Test List[Optional[float]] as input
-        scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2.9, ], )])
-        self.assertEqual(scripted_model([2.8, 6.7, 3.8, None, ]), pdt_model([2.8, 6.7, 3.8, None, ]))
-
-        # Test Tuple[Optional[int]] as input
-        scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((5.1, ), ), ((None, ), ), ])
-        self.assertEqual(scripted_model((6.2, None, 10.6, 80.1, None, )), pdt_model((6.2, None, 10.6, 80.1, None, )))
-
-        # Test List[Optional[int]] as input
-        scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2, ], )])
-        self.assertEqual(scripted_model([2, None, 6, 8, ]), pdt_model([2, None, 6, 8, ]))
-
-        # Test Tuple[Optional[int]] as input
-        scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )])
-        self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, )))
-
-        # Test Tuple[Optional[float]] as input
-        scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )])
-        self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, )))
-
-        # Test Tuple[Optional[torch.Tensor]] as input
-        scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[(((torch.ones(1), ), (None, ), ), )])
-        self.assertEqual(scripted_model((torch.ones(1), torch.ones(1), None)),
-                         pdt_model((torch.ones(1), torch.ones(1), None)))
-
-        # Test List[Optional[torch.Tensor]] as input
-        scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([torch.ones(1), ], )])
-        self.assertEqual(scripted_model([torch.ones(1), torch.ones(1), None]),
-                         pdt_model([torch.ones(1), torch.ones(1), None]))
index b5a698e..f0e4613 100644 (file)
@@ -1,7 +1,6 @@
 import inspect
 import typing
 import pathlib
-import torch
 from typing import Optional, Iterable, List, Dict
 from collections import defaultdict
 from types import CodeType
@@ -16,25 +15,38 @@ try:
 except ImportError:
     _IS_MONKEYTYPE_INSTALLED = False
 
-def get_optional_of_element_type(types: str):
+def get_type(type):
+    """
+    Helper function which converts the given type to a torchScript acceptable format.
+    """
+    if isinstance(type, str):
+        return type
+    elif inspect.getmodule(type) == typing:
+        # If the type is a type imported from typing
+        # like Tuple, List, Dict then replace `typing.`
+        # with a null string. This needs to be done since
+        # typing.List is not accepted by TorchScript.
+        type_to_string = str(type)
+        return type_to_string.replace(type.__module__ + '.', '')
+    elif type.__module__.startswith('torch'):
+        # If the type is a subtype of torch module, then TorchScript expects a fully qualified name
+        # for the type which is obtained by combining the module name and type name.
+        return type.__module__ + '.' + type.__name__
+    else:
+        # For all other types use the name for the type.
+        return type.__name__
+
+def get_optional_of_element_type(types):
     """
     Helper function to extracts the type of the element to be annotated to Optional
     from the list of consolidated types and returns `Optional[element type]`.
-
     TODO: To remove this check once Union support lands.
     """
-    elements = types.split(",")
-    elem_type = elements[0] if 'NoneType' in elements[1] else elements[1]
-
-    # If the type is from typing module, then extract the element type
-    start = elem_type.find("[")
-    end = elem_type.rfind("]")
-    if start != -1 and end != -1:
-        return elem_type[:start + 1] + 'Optional[' + elem_type[start + 1: end] + ']]'
-
-    # Else return Optional[element type]
-    if elem_type == 'Tensor':
-        elem_type = 'torch.Tensor'
+    elem_type = types[1] if type(None) == types[0] else types[0]
+    elem_type = get_type(elem_type)
+
+    # Optional type is internally converted to Union[type, NoneType], which
+    # is not supported yet in TorchScript. Hence, representing the optional type as string.
     return 'Optional[' + elem_type + ']'
 
 def get_qualified_name(func):
@@ -88,30 +100,15 @@ if _IS_MONKEYTYPE_INSTALLED:
             # then consolidate the type to `Any` and replace the entry
             # by type `Any`.
             for arg, types in all_args.items():
-                _all_type = " "
-                for _type in types:
-                    # If the type is a type imported from typing
-                    # like Tuple, List, Dict then replace "typing."
-                    # with a null string.
-                    if inspect.getmodule(_type) == typing:
-                        _type_to_string = str(_type)
-                        _all_type += _type_to_string.replace('typing.', '') + ','
-                    elif _type is torch.nn.parameter.Parameter:
-                        # Check if the type is torch.nn.parameter.Parameter,
-                        # use the entire quaalified name `torch.nn.parameter.Parameter`
-                        # for type
-                        _all_type += 'torch.nn.parameter.Parameter' + ','
-                    else:
-                        _all_type += _type.__name__ + ','
-                _all_type = _all_type.lstrip(" ")  # Remove any trailing spaces
-
-                if len(types) == 2 and 'NoneType' in _all_type:
+                types = list(types)
+                type_length = len(types)
+                if type_length == 2 and type(None) in types:
                     # TODO: To remove this check once Union suppport in TorchScript lands.
-                    all_args[arg] = {get_optional_of_element_type(_all_type)}
-                elif len(types) > 1:
-                    all_args[arg] = {'Any'}
-                else:
-                    all_args[arg] = {_all_type[:-1]}
+                    all_args[arg] = get_optional_of_element_type(types)
+                elif type_length > 1:
+                    all_args[arg] = 'Any'
+                elif type_length == 1:
+                    all_args[arg] = get_type(types[0])
             return all_args
 
         def get_args_types(self, qualified_name: str) -> Dict:
@@ -157,7 +154,6 @@ def jit_code_filter(code: CodeType) -> bool:
     The custom CodeFilter is required while scripting a FX Traced forward calls.
     FX Traced forward calls have `code.co_filename` start with '<' which is used
     to exclude tracing of stdlib and site-packages in the default code filter.
-
     Since we need all forward calls to be traced, this custom code filter
     checks for code.co_name to be 'forward' and enables tracing for all such calls.
     The code filter is similar to default code filter for monkeytype and
index b0228b1..0928106 100644 (file)
@@ -337,9 +337,9 @@ def build_param_list(ctx, py_args, self_name, pdt_arg_types=None):
                 raise NotSupportedError(ctx_range, _vararg_kwarg_err)
 
     # List of Tuple of args and type as inferred by profile directed typing
-    arg_and_types = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None)
+    arg_and_types = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None)
                      for arg in py_args.args]
-    arg_and_types_kwonlyargs = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg])
+    arg_and_types_kwonlyargs = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg])
                                 else None) for arg in py_args.kwonlyargs]
 
     result = [build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type)