Add script standard library documentation + cleanup (#14912)
authorDavid Riazati <davidriazati@fb.com>
Wed, 12 Dec 2018 20:25:40 +0000 (12:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 20:30:13 +0000 (12:30 -0800)
Summary:
Documents what is supported in the script standard library.

* Adds `my_script_module._get_method('forward').schema()` method to get function schema from a `ScriptModule`
* Removes `torch.nn.functional` from the list of builtins. The only functions not supported are `nn.functional.fold` and `nn.functional.unfold`, but those currently just dispatch to their corresponding aten ops, so from a user's perspective it looks like they work.
* Allow printing of `IValue::Device` by getting its string representation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14912

Differential Revision: D13385928

Pulled By: driazati

fbshipit-source-id: e391691b2f87dba6e13be05d4aa3ed2f004e31da

docs/source/jit.rst
torch/csrc/jit/pybind_utils.h
torch/csrc/jit/script/init.cpp
torch/jit/__init__.py
torch/jit/supported_ops.py

index 3511a44..70c79fa 100644 (file)
@@ -787,14 +787,27 @@ Tracer Warnings
 Builtin Functions
 ~~~~~~~~~~~~~~~~~
 
-TorchScript supports a subset of the builtin tensor and neural network functions that
-PyTorch provides. Most methods on Tensor as well as functions in the ``torch``
-namespace are available. Many functions in ``torch.nn.functional`` are also availiable.
+Torch Script supports a subset of the builtin tensor and neural network
+functions that PyTorch provides. Most methods on Tensor as well as functions in
+the ``torch`` namespace, all functions in ``torch.nn.functional`` and all
+modules from ``torch.nn`` are supported in Torch Script, excluding those in the
+table below. For unsupported modules, we suggest using :meth:`torch.jit.trace`.
+
+=====
+Unsupported ``torch.nn`` Modules
+=====
+``torch.nn.modules.adaptive.AdaptiveLogSoftmaxWithLoss``
+``torch.nn.modules.normalization.CrossMapLRN2d``
+``torch.nn.modules.fold.Fold``
+``torch.nn.modules.fold.Unfold``
+``torch.nn.modules.rnn.GRU``
+``torch.nn.modules.rnn.LSTM``
+``torch.nn.modules.rnn.RNN``
+``torch.nn.modules.rnn.GRUCell``
+``torch.nn.modules.rnn.LSTMCell``
+``torch.nn.modules.rnn.RNNCell``
+=====
 
 
-We currently do not provide any builtin ScriptModules e.g. a ``Linear`` or
-``Conv`` module. This functionality is something that will be developed in the future.
-For now we suggest using ``torch.jit.trace`` to transform standard ``torch.nn``
-modules into ScriptModules on construction.
 
 .. automodule:: torch.jit.supported_ops
index 99a76de..560d944 100644 (file)
@@ -287,6 +287,8 @@ inline py::object toPyObject(IValue&& ivalue) {
       t[i] = toPyObject(IValue{elements[i]});
     }
     return t;
+  } else if (ivalue.isDevice()) {
+    return py::cast<py::object>(THPDevice_New(ivalue.toDevice()));
   } else {
     AT_ERROR("Missing cases in 'toPyObject'! File a bug report.");
   }
index cf0188d..1780c1f 100644 (file)
@@ -709,6 +709,7 @@ void initJitScriptBindings(PyObject* module) {
       return self.graph_for(createStackForSchema(self.getSchema(), tuple_slice(std::move(args), 1), std::move(kwargs)));
     })
     .def("debug_disable_autodiff_subgraph_inlining", &Method::debugDisableAutodiffSubgraphInlining)
+    .def("schema", &Method::getSchema)
     .def("pretty_print_schema", &Method::pretty_print_schema)
     .def("python_print", [](Method &m) {
       std::ostringstream oss;
index 77906e4..c9c129e 100644 (file)
@@ -1379,22 +1379,7 @@ class _ConstSequential(_ConstModuleList):
 
 _builtin_table = None
 
-_modules_containing_builtins = (torch, torch.nn.functional, torch._C._nn)
-
-# These functions have been converted to weak script, so don't add them as
-# builtin aten ops. Instead, they will be compiled from the code in
-# torch.nn.functional when used.
-
-
-# TODO: delete _should_skip() and remove torch.nn.functional from builtins list
-# once everything in it has been converted to weak script
-def _should_skip(mod, name):
-    if mod is not torch.nn.functional:
-        return False
-    func = getattr(torch.nn.functional, name)
-    if func is None:
-        return False
-    return func in _compiled_weak_fns or func in _boolean_dispatched
+_modules_containing_builtins = (torch, torch._C._nn)
 
 
 def _unwrap_optional(x):
@@ -1412,7 +1397,7 @@ def _get_builtin_table():
     def register_all(mod):
         for name in dir(mod):
             v = getattr(mod, name)
-            if callable(v) and not _should_skip(mod, name):
+            if callable(v):
                 _builtin_table[id(v)] = "aten::" + name
     for mod in _modules_containing_builtins:
         register_all(mod)
@@ -1433,6 +1418,8 @@ def _get_builtin_table():
     _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
     _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"
     _builtin_table[id(torch.nn.functional.upsample_bilinear)] = "aten::__upsample_bilinear"
+    _builtin_table[id(torch.nn.functional.fold)] = "aten::fold"
+    _builtin_table[id(torch.nn.functional.unfold)] = "aten::unfold"
 
     return _builtin_table
 
index fbb5bf5..6b629dd 100644 (file)
@@ -51,6 +51,22 @@ def _list_supported_ops():
                     if not hidden(elem):
                         functions.append(emit_schema(name, elem, schema))
 
+    mod = torch.nn.functional
+    name = mod.__name__
+    for elem in dir(torch.nn.functional):
+        # weak script functions
+        attr = getattr(mod, elem)
+        if not callable(attr) or elem[0] == '_':
+            # ignore non-functions and internal methods
+            continue
+
+        # compile weak script fn, get schema
+        scripted = torch.jit._try_compile_weak_script(attr)
+        if scripted is None:
+            continue
+        schema = scripted._get_method('forward').schema()
+        functions.append(emit_schema(name, elem, schema))
+
     def is_tensor_method(schema):
         if len(schema.arguments) == 0:
             return False