from torch._six import builtins
# Tracks standalone weak script functions
-_compiled_weak_fns = weakref.WeakKeyDictionary()
+compiled_weak_fns = weakref.WeakKeyDictionary()
# Tracks which methods should be converted to strong methods
-_weak_script_methods = weakref.WeakKeyDictionary()
+weak_script_methods = weakref.WeakKeyDictionary()
# Converted modules and their corresponding WeakScriptModuleProxy objects
-_weak_modules = weakref.WeakKeyDictionary()
+weak_modules = weakref.WeakKeyDictionary()
# Types that have been declared as weak modules
-_weak_types = weakref.WeakKeyDictionary()
+weak_types = weakref.WeakKeyDictionary()
# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
-_boolean_dispatched = weakref.WeakKeyDictionary()
+boolean_dispatched = weakref.WeakKeyDictionary()
COMPILATION_PENDING = object()
COMPILED = object()
inlined in the graph. When not used in a script function, the weak script
annotation has no effect.
"""
- _compiled_weak_fns[fn] = {
+ compiled_weak_fns[fn] = {
"status": COMPILATION_PENDING,
"compiled_fn": None,
"rcb": createResolutionCallback(_frames_up + 1)
def weak_module(cls):
- _weak_types[cls] = {
+ weak_types[cls] = {
"method_stubs": None
}
return cls
def weak_script_method(fn):
- _weak_script_methods[fn] = {
+ weak_script_methods[fn] = {
"rcb": createResolutionCallback(frames_up=2),
"original_method": fn
}
In TorchScript, the boolean argument must be constant so that the correct
function to use can be determined at compile time.
"""
- if _compiled_weak_fns.get(if_true) is None or _compiled_weak_fns.get(if_false) is None:
+ if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None:
raise RuntimeError("both functions must be weak script")
def fn(*args, **kwargs):
raise RuntimeError("only one function can have a docstring")
fn.__doc__ = doc
- _boolean_dispatched[fn] = {
+ boolean_dispatched[fn] = {
"if_true": if_true,
"if_false": if_false,
"index": arg_index,
from torch.jit.frontend import get_jit_ast, get_default_args
import torch.backends.cudnn as cudnn
import torch.jit.annotations
+import torch._jit_internal as _jit_internal
from torch._six import raise_from, with_metaclass, get_function_from_type, \
string_classes
-from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
- _weak_script_methods, _weak_modules, _weak_types, COMPILED, \
- COMPILATION_PENDING, _boolean_dispatched
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
_list_with_default
import torch.testing
def define(self, lang, rcb=None, _frames_up=0):
if not rcb:
- rcb = createResolutionCallback(_frames_up + 1)
+ rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
self.module._define(lang, rcb, False)
def __getattr__(self, attr):
def _try_get_dispatched_fn(fn):
if not callable(fn):
return None
- return _boolean_dispatched.get(fn)
+ return _jit_internal.boolean_dispatched.get(fn)
def _try_compile_weak_script(fn):
- entry = _compiled_weak_fns.get(fn)
+ entry = _jit_internal.compiled_weak_fns.get(fn)
if entry is None:
return None
- if entry["status"] == COMPILATION_PENDING:
+ if entry["status"] == _jit_internal.COMPILATION_PENDING:
compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
del entry["rcb"]
- _compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
- entry["status"] = COMPILED
+ _jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
+ entry["status"] = _jit_internal.COMPILED
return compiled_fn
else:
return entry["compiled_fn"]
if not _enabled:
return fn
if _rcb is None:
- _rcb = createResolutionCallback(_frames_up + 1)
+ _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
ast = get_jit_ast(fn, is_method=False)
mod = ScriptModule()
_jit_script_compile(mod, ast, _rcb, get_default_args(fn))
# createResolutionCallback internally adds 1 to get us to the scope of this
# function (the calling function). Adding 2 gets us to the proper surrounding scope.
if _rcb is None:
- _rcb = createResolutionCallback(frames_up=2)
+ _rcb = _jit_internal.createResolutionCallback(frames_up=2)
ast = get_jit_ast(fn, is_method=True)
return ScriptMethodStub(_rcb, ast, fn)
"""
if not isinstance(mod, Module):
return None
- return _weak_modules.get(mod)
+ return _jit_internal.weak_modules.get(mod)
def _is_weak_type(cls):
"""
Check if a type has been annotated with `weak_module`
"""
- return cls in _weak_types
+ return cls in _jit_internal.weak_types
def batch(batch_size=1, optimize=True, _frames_up=0):
#
# createResolutionCallback internally adds 1 to get us to our frame, then
# we add 1 to get to the proper surrounding scope.
- rcb = createResolutionCallback(frames_up=1)
+ rcb = _jit_internal.createResolutionCallback(frames_up=1)
self._define(lang, rcb, True)
def copy(self):
stubs = []
for name in dir(cls):
func = get_function_from_type(cls, name)
- if func in _weak_script_methods:
- entry = _weak_script_methods[func]
+ if func in _jit_internal.weak_script_methods:
+ entry = _jit_internal.weak_script_methods[func]
stub = script_method(entry["original_method"], entry["rcb"])
stubs.append(stub)
return stubs
"""
Converts a weak module into a subclass of ScriptModule
"""
- if mod in _weak_modules:
- return _weak_modules[mod]
+ if mod in _jit_internal.weak_modules:
+ return _jit_internal.weak_modules[mod]
- stubs = _weak_types.get(type(mod))["method_stubs"]
+ stubs = _jit_internal.weak_types.get(type(mod))["method_stubs"]
if stubs is None:
- # Generate stubs and and store on _weak_types in case this type is
+ # Generate stubs and and store on weak_types in case this type is
# used again
stubs = _get_weak_stubs(type(mod))
- _weak_types[type(mod)]["method_stubs"] = stubs
+ _jit_internal.weak_types[type(mod)]["method_stubs"] = stubs
# Create proxy with stubs
proxy = WeakScriptModuleProxy(mod, stubs)
- _weak_modules[mod] = proxy
+ _jit_internal.weak_modules[mod] = proxy
return proxy