From: gmagogsfm Date: Fri, 27 Aug 2021 15:49:54 +0000 (-0700) Subject: More robust check of whether a class is defined in torch (#64083) X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~655 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ad8eddbd808a97ac518ffd5b51d2c925803a1a3f;p=platform%2Fupstream%2Fpytorch.git More robust check of whether a class is defined in torch (#64083) Summary: This would prevent bugs for classes that 1) Is defined in a module that happens to start with `torch`, say `torchvision` 2) Is defined in torch but with an import alias like `import torch as th` Pull Request resolved: https://github.com/pytorch/pytorch/pull/64083 Reviewed By: soulitzer Differential Revision: D30598369 Pulled By: gmagogsfm fbshipit-source-id: 9d3a7135737b2339c9bd32195e4e69a9c07549d4 --- diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index f0e4613..9957541 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -1,6 +1,10 @@ + +import torch + import inspect import typing import pathlib +import sys from typing import Optional, Iterable, List, Dict from collections import defaultdict from types import CodeType @@ -15,6 +19,18 @@ try: except ImportError: _IS_MONKEYTYPE_INSTALLED = False +# Checks whether a class is defind in `torch.*` modules +def is_torch_native_class(cls): + if not hasattr(cls, '__module__'): + return False + + parent_modules = cls.__module__.split('.') + if not parent_modules: + return False + + root_module = sys.modules.get(parent_modules[0]) + return root_module is torch + def get_type(type): """ Helper function which converts the given type to a torchScript acceptable format. @@ -28,7 +44,7 @@ def get_type(type): # typing.List is not accepted by TorchScript. type_to_string = str(type) return type_to_string.replace(type.__module__ + '.', '') - elif type.__module__.startswith('torch'): + elif is_torch_native_class(type): # If the type is a subtype of torch module, then TorchScript expects a fully qualified name # for the type which is obtained by combining the module name and type name. return type.__module__ + '.' + type.__name__