+
+import torch
+
import inspect
import typing
import pathlib
+import sys
from typing import Optional, Iterable, List, Dict
from collections import defaultdict
from types import CodeType
except ImportError:
_IS_MONKEYTYPE_INSTALLED = False
+# Checks whether a class is defind in `torch.*` modules
+def is_torch_native_class(cls):
+ if not hasattr(cls, '__module__'):
+ return False
+
+ parent_modules = cls.__module__.split('.')
+ if not parent_modules:
+ return False
+
+ root_module = sys.modules.get(parent_modules[0])
+ return root_module is torch
+
def get_type(type):
"""
Helper function which converts the given type to a torchScript acceptable format.
# typing.List is not accepted by TorchScript.
type_to_string = str(type)
return type_to_string.replace(type.__module__ + '.', '')
- elif type.__module__.startswith('torch'):
+ elif is_torch_native_class(type):
# If the type is a subtype of torch module, then TorchScript expects a fully qualified name
# for the type which is obtained by combining the module name and type name.
return type.__module__ + '.' + type.__name__