scripted_fn = torch.jit._script_pdt(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
-
- class TestForwardWithNoneType(torch.nn.Module):
- def forward(self, a):
- count = 0
- for i, val in enumerate(a):
- if val is None:
- count += 1
- return count
-
- make_global(TestForwardWithNoneType)
- pdt_model = TestForwardWithNoneType()
-
- # Test List[Optional[float]] as input
- scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2.9, ], )])
- self.assertEqual(scripted_model([2.8, 6.7, 3.8, None, ]), pdt_model([2.8, 6.7, 3.8, None, ]))
-
- # Test Tuple[Optional[int]] as input
- scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((5.1, ), ), ((None, ), ), ])
- self.assertEqual(scripted_model((6.2, None, 10.6, 80.1, None, )), pdt_model((6.2, None, 10.6, 80.1, None, )))
-
- # Test List[Optional[int]] as input
- scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([2, ], )])
- self.assertEqual(scripted_model([2, None, 6, 8, ]), pdt_model([2, None, 6, 8, ]))
-
- # Test Tuple[Optional[int]] as input
- scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )])
- self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, )))
-
- # Test Tuple[Optional[float]] as input
- scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[((None, ), ), ((5, ), )])
- self.assertEqual(scripted_model((2, None, 6, 8)), pdt_model((2, None, 6, 8, )))
-
- # Test Tuple[Optional[torch.Tensor]] as input
- scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[(((torch.ones(1), ), (None, ), ), )])
- self.assertEqual(scripted_model((torch.ones(1), torch.ones(1), None)),
- pdt_model((torch.ones(1), torch.ones(1), None)))
-
- # Test List[Optional[torch.Tensor]] as input
- scripted_model = torch.jit._script_pdt(pdt_model, example_inputs=[([None, ], ), ([torch.ones(1), ], )])
- self.assertEqual(scripted_model([torch.ones(1), torch.ones(1), None]),
- pdt_model([torch.ones(1), torch.ones(1), None]))
import inspect
import typing
import pathlib
-import torch
from typing import Optional, Iterable, List, Dict
from collections import defaultdict
from types import CodeType
except ImportError:
_IS_MONKEYTYPE_INSTALLED = False
-def get_optional_of_element_type(types: str):
+def get_type(type):
+ """
+    Helper function which converts the given type to a TorchScript acceptable format.
+ """
+ if isinstance(type, str):
+ return type
+ elif inspect.getmodule(type) == typing:
+ # If the type is a type imported from typing
+ # like Tuple, List, Dict then replace `typing.`
+    # with an empty string. This needs to be done since
+ # typing.List is not accepted by TorchScript.
+ type_to_string = str(type)
+ return type_to_string.replace(type.__module__ + '.', '')
+ elif type.__module__.startswith('torch'):
+ # If the type is a subtype of torch module, then TorchScript expects a fully qualified name
+ # for the type which is obtained by combining the module name and type name.
+ return type.__module__ + '.' + type.__name__
+ else:
+ # For all other types use the name for the type.
+ return type.__name__
+
+def get_optional_of_element_type(types):
"""
    Helper function to extract the type of the element to be annotated to Optional
from the list of consolidated types and returns `Optional[element type]`.
-
TODO: To remove this check once Union support lands.
"""
- elements = types.split(",")
- elem_type = elements[0] if 'NoneType' in elements[1] else elements[1]
-
- # If the type is from typing module, then extract the element type
- start = elem_type.find("[")
- end = elem_type.rfind("]")
- if start != -1 and end != -1:
- return elem_type[:start + 1] + 'Optional[' + elem_type[start + 1: end] + ']]'
-
- # Else return Optional[element type]
- if elem_type == 'Tensor':
- elem_type = 'torch.Tensor'
+ elem_type = types[1] if type(None) == types[0] else types[0]
+ elem_type = get_type(elem_type)
+
+ # Optional type is internally converted to Union[type, NoneType], which
+    # is not supported yet in TorchScript. Hence, representing the optional type as a string.
return 'Optional[' + elem_type + ']'
def get_qualified_name(func):
# then consolidate the type to `Any` and replace the entry
# by type `Any`.
for arg, types in all_args.items():
- _all_type = " "
- for _type in types:
- # If the type is a type imported from typing
- # like Tuple, List, Dict then replace "typing."
- # with a null string.
- if inspect.getmodule(_type) == typing:
- _type_to_string = str(_type)
- _all_type += _type_to_string.replace('typing.', '') + ','
- elif _type is torch.nn.parameter.Parameter:
- # Check if the type is torch.nn.parameter.Parameter,
- # use the entire quaalified name `torch.nn.parameter.Parameter`
- # for type
- _all_type += 'torch.nn.parameter.Parameter' + ','
- else:
- _all_type += _type.__name__ + ','
- _all_type = _all_type.lstrip(" ") # Remove any trailing spaces
-
- if len(types) == 2 and 'NoneType' in _all_type:
+ types = list(types)
+ type_length = len(types)
+ if type_length == 2 and type(None) in types:
        # TODO: To remove this check once Union support in TorchScript lands.
- all_args[arg] = {get_optional_of_element_type(_all_type)}
- elif len(types) > 1:
- all_args[arg] = {'Any'}
- else:
- all_args[arg] = {_all_type[:-1]}
+ all_args[arg] = get_optional_of_element_type(types)
+ elif type_length > 1:
+ all_args[arg] = 'Any'
+ elif type_length == 1:
+ all_args[arg] = get_type(types[0])
return all_args
def get_args_types(self, qualified_name: str) -> Dict:
The custom CodeFilter is required while scripting a FX Traced forward calls.
FX Traced forward calls have `code.co_filename` start with '<' which is used
to exclude tracing of stdlib and site-packages in the default code filter.
-
Since we need all forward calls to be traced, this custom code filter
checks for code.co_name to be 'forward' and enables tracing for all such calls.
The code filter is similar to default code filter for monkeytype and
raise NotSupportedError(ctx_range, _vararg_kwarg_err)
# List of Tuple of args and type as inferred by profile directed typing
- arg_and_types = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None)
+ arg_and_types = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg]) else None)
for arg in py_args.args]
- arg_and_types_kwonlyargs = [(arg, next(iter(pdt_arg_types[arg.arg])) if pdt_arg_types and bool(pdt_arg_types[arg.arg])
+ arg_and_types_kwonlyargs = [(arg, pdt_arg_types[arg.arg] if pdt_arg_types and bool(pdt_arg_types[arg.arg])
else None) for arg in py_args.kwonlyargs]
result = [build_param(ctx, arg, self_name, kwarg_only=False, pdt_arg_type=arg_type)