More robust check of whether a class is defined in torch (#64083)
authorgmagogsfm <gmagogsfm@gmail.com>
Fri, 27 Aug 2021 15:49:54 +0000 (08:49 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 15:55:35 +0000 (08:55 -0700)
Summary:
This would prevent bugs for classes that
1) Is defined in a module that happens to start with `torch`, say `torchvision`
2) Is defined in torch but with an import alias like `import torch as th`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64083

Reviewed By: soulitzer

Differential Revision: D30598369

Pulled By: gmagogsfm

fbshipit-source-id: 9d3a7135737b2339c9bd32195e4e69a9c07549d4

torch/jit/_monkeytype_config.py

index f0e4613..9957541 100644 (file)
@@ -1,6 +1,10 @@
+
+import torch
+
 import inspect
 import typing
 import pathlib
+import sys
 from typing import Optional, Iterable, List, Dict
 from collections import defaultdict
 from types import CodeType
@@ -15,6 +19,18 @@ try:
 except ImportError:
     _IS_MONKEYTYPE_INSTALLED = False
 
+# Checks whether a class is defind in `torch.*` modules
+def is_torch_native_class(cls):
+    if not hasattr(cls, '__module__'):
+        return False
+
+    parent_modules = cls.__module__.split('.')
+    if not parent_modules:
+        return False
+
+    root_module = sys.modules.get(parent_modules[0])
+    return root_module is torch
+
 def get_type(type):
     """
     Helper function which converts the given type to a torchScript acceptable format.
@@ -28,7 +44,7 @@ def get_type(type):
         # typing.List is not accepted by TorchScript.
         type_to_string = str(type)
         return type_to_string.replace(type.__module__ + '.', '')
-    elif type.__module__.startswith('torch'):
+    elif is_torch_native_class(type):
         # If the type is a subtype of torch module, then TorchScript expects a fully qualified name
         # for the type which is obtained by combining the module name and type name.
         return type.__module__ + '.' + type.__name__