Refactor _jit_internal (#16058)
authorDavid Riazati <davidriazati@fb.com>
Thu, 17 Jan 2019 21:39:07 +0000 (13:39 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 21:56:50 +0000 (13:56 -0800)
Summary:
Use qualified names in `jit/__init__.py` to avoid polluting that namespace
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16058

Differential Revision: D13718745

Pulled By: driazati

fbshipit-source-id: 19d150569c8374541250a961f24f70c3f523de03

test/test_jit.py
torch/_jit_internal.py
torch/jit/__init__.py

index a635041..2afad2a 100644 (file)
@@ -11690,7 +11690,7 @@ def add_nn_module_test(*args, **kwargs):
     module_name = name.split("_")[0]
 
     module = getattr(torch.nn, module_name, None)
-    if module is None or torch._jit_internal._weak_types.get(module) is None:
+    if module is None or torch._jit_internal.weak_types.get(module) is None:
         return
 
     if 'desc' in kwargs and 'eval' in kwargs['desc']:
index 45800c8..2ca6d1a 100644 (file)
@@ -9,20 +9,20 @@ import inspect
 from torch._six import builtins
 
 # Tracks standalone weak script functions
-_compiled_weak_fns = weakref.WeakKeyDictionary()
+compiled_weak_fns = weakref.WeakKeyDictionary()
 
 # Tracks which methods should be converted to strong methods
-_weak_script_methods = weakref.WeakKeyDictionary()
+weak_script_methods = weakref.WeakKeyDictionary()
 
 # Converted modules and their corresponding WeakScriptModuleProxy objects
-_weak_modules = weakref.WeakKeyDictionary()
+weak_modules = weakref.WeakKeyDictionary()
 
 # Types that have been declared as weak modules
-_weak_types = weakref.WeakKeyDictionary()
+weak_types = weakref.WeakKeyDictionary()
 
 # Wrapper functions that can call either of 2 functions depending on a boolean
 # argument
-_boolean_dispatched = weakref.WeakKeyDictionary()
+boolean_dispatched = weakref.WeakKeyDictionary()
 
 COMPILATION_PENDING = object()
 COMPILED = object()
@@ -84,7 +84,7 @@ def weak_script(fn, _frames_up=0):
     inlined in the graph. When not used in a script function, the weak script
     annotation has no effect.
     """
-    _compiled_weak_fns[fn] = {
+    compiled_weak_fns[fn] = {
         "status": COMPILATION_PENDING,
         "compiled_fn": None,
         "rcb": createResolutionCallback(_frames_up + 1)
@@ -93,14 +93,14 @@ def weak_script(fn, _frames_up=0):
 
 
 def weak_module(cls):
-    _weak_types[cls] = {
+    weak_types[cls] = {
         "method_stubs": None
     }
     return cls
 
 
 def weak_script_method(fn):
-    _weak_script_methods[fn] = {
+    weak_script_methods[fn] = {
         "rcb": createResolutionCallback(frames_up=2),
         "original_method": fn
     }
@@ -113,7 +113,7 @@ def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
     In TorchScript, the boolean argument must be constant so that the correct
     function to use can be determined at compile time.
     """
-    if _compiled_weak_fns.get(if_true) is None or _compiled_weak_fns.get(if_false) is None:
+    if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None:
         raise RuntimeError("both functions must be weak script")
 
     def fn(*args, **kwargs):
@@ -141,7 +141,7 @@ def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
         raise RuntimeError("only one function can have a docstring")
     fn.__doc__ = doc
 
-    _boolean_dispatched[fn] = {
+    boolean_dispatched[fn] = {
         "if_true": if_true,
         "if_false": if_false,
         "index": arg_index,
index fd91f2e..317fe7e 100644 (file)
@@ -6,11 +6,9 @@ from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
 from torch.jit.frontend import get_jit_ast, get_default_args
 import torch.backends.cudnn as cudnn
 import torch.jit.annotations
+import torch._jit_internal as _jit_internal
 from torch._six import raise_from, with_metaclass, get_function_from_type, \
     string_classes
-from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
-    _weak_script_methods, _weak_modules, _weak_types, COMPILED, \
-    COMPILATION_PENDING, _boolean_dispatched
 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
     _list_with_default
 import torch.testing
@@ -656,7 +654,7 @@ class CompilationUnit(object):
 
     def define(self, lang, rcb=None, _frames_up=0):
         if not rcb:
-            rcb = createResolutionCallback(_frames_up + 1)
+            rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
         self.module._define(lang, rcb, False)
 
     def __getattr__(self, attr):
@@ -666,18 +664,18 @@ class CompilationUnit(object):
 def _try_get_dispatched_fn(fn):
     if not callable(fn):
         return None
-    return _boolean_dispatched.get(fn)
+    return _jit_internal.boolean_dispatched.get(fn)
 
 
 def _try_compile_weak_script(fn):
-    entry = _compiled_weak_fns.get(fn)
+    entry = _jit_internal.compiled_weak_fns.get(fn)
     if entry is None:
         return None
-    if entry["status"] == COMPILATION_PENDING:
+    if entry["status"] == _jit_internal.COMPILATION_PENDING:
         compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
         del entry["rcb"]
-        _compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
-        entry["status"] = COMPILED
+        _jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
+        entry["status"] = _jit_internal.COMPILED
         return compiled_fn
     else:
         return entry["compiled_fn"]
@@ -687,7 +685,7 @@ def script(fn, optimize=True, _frames_up=0, _rcb=None):
     if not _enabled:
         return fn
     if _rcb is None:
-        _rcb = createResolutionCallback(_frames_up + 1)
+        _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
     ast = get_jit_ast(fn, is_method=False)
     mod = ScriptModule()
     _jit_script_compile(mod, ast, _rcb, get_default_args(fn))
@@ -714,7 +712,7 @@ def script_method(fn, _rcb=None):
     # createResolutionCallback internally adds 1 to get us to the scope of this
     # function (the calling function). Adding 2 gets us to the proper surrounding scope.
     if _rcb is None:
-        _rcb = createResolutionCallback(frames_up=2)
+        _rcb = _jit_internal.createResolutionCallback(frames_up=2)
     ast = get_jit_ast(fn, is_method=True)
     return ScriptMethodStub(_rcb, ast, fn)
 
@@ -725,14 +723,14 @@ def _try_get_weak_module(mod):
     """
     if not isinstance(mod, Module):
         return None
-    return _weak_modules.get(mod)
+    return _jit_internal.weak_modules.get(mod)
 
 
 def _is_weak_type(cls):
     """
     Check if a type has been annotated with `weak_module`
     """
-    return cls in _weak_types
+    return cls in _jit_internal.weak_types
 
 
 def batch(batch_size=1, optimize=True, _frames_up=0):
@@ -1139,7 +1137,7 @@ if _enabled:
             #
             # createResolutionCallback internally adds 1 to get us to our frame, then
             # we add 1 to get to the proper surrounding scope.
-            rcb = createResolutionCallback(frames_up=1)
+            rcb = _jit_internal.createResolutionCallback(frames_up=1)
             self._define(lang, rcb, True)
 
         def copy(self):
@@ -1229,8 +1227,8 @@ def _get_weak_stubs(cls):
     stubs = []
     for name in dir(cls):
         func = get_function_from_type(cls, name)
-        if func in _weak_script_methods:
-            entry = _weak_script_methods[func]
+        if func in _jit_internal.weak_script_methods:
+            entry = _jit_internal.weak_script_methods[func]
             stub = script_method(entry["original_method"], entry["rcb"])
             stubs.append(stub)
     return stubs
@@ -1240,21 +1238,21 @@ def _make_strong(mod):
     """
     Converts a weak module into a subclass of ScriptModule
     """
-    if mod in _weak_modules:
-        return _weak_modules[mod]
+    if mod in _jit_internal.weak_modules:
+        return _jit_internal.weak_modules[mod]
 
-    stubs = _weak_types.get(type(mod))["method_stubs"]
+    stubs = _jit_internal.weak_types.get(type(mod))["method_stubs"]
 
     if stubs is None:
-        # Generate stubs and and store on _weak_types in case this type is
+        # Generate stubs and and store on weak_types in case this type is
         # used again
         stubs = _get_weak_stubs(type(mod))
-        _weak_types[type(mod)]["method_stubs"] = stubs
+        _jit_internal.weak_types[type(mod)]["method_stubs"] = stubs
 
     # Create proxy with stubs
     proxy = WeakScriptModuleProxy(mod, stubs)
 
-    _weak_modules[mod] = proxy
+    _jit_internal.weak_modules[mod] = proxy
 
     return proxy