Add metadata for torch jit TracedModules. (#17640)
authorPritam Damania <pritam.damania@fb.com>
Sun, 10 Mar 2019 05:31:42 +0000 (21:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 10 Mar 2019 05:37:15 +0000 (21:37 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17640

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17311

I've extended our model metadata framework in this diff to support
traced modules as well. Re-used a lot of components from the previous
implementation of ScriptModule metadata.

Tracing is a little different from Scripting since you can't just create a
subclass of TopLevelTraceModule (type returned by torch.jit.trace) and attach
metadata the way we did for ScriptModule. As a result, I've introduced a
separate API torch.fb.jit_trace which returns an instance of
TracedModuleWithMetadata which is a subclass of TopLevelTracedModule. As a
result, we can now attach metadata to this instance.

Reviewed By: dzhulgakov

Differential Revision: D14117966

fbshipit-source-id: 3eee5eef733cb8d6a219c02e2f41d08698eca326

torch/jit/__init__.py

index fd1cfec..62c3bc5 100644 (file)
@@ -591,7 +591,8 @@ def trace(func,
           check_trace=True,
           check_inputs=None,
           check_tolerance=1e-5,
-          _force_outplace=False):
+          _force_outplace=False,
+          _module_class=None):
     """
     Trace a function and return an executable trace that will be optimized
     using just-in-time compilation.
@@ -657,7 +658,10 @@ def trace(func,
     # done primarily so that weird iterables fail here and not pybind11 code
     elif not isinstance(example_inputs, tuple):
         example_inputs = tuple(example_inputs)
-    module = TopLevelTracedModule(func, **executor_options)
+    if _module_class:
+        module = _module_class(func, **executor_options)
+    else:
+        module = TopLevelTracedModule(func, **executor_options)
     var_lookup_fn = _create_interpreter_name_lookup_fn(0)
     module._create_method_from_trace('forward', func, example_inputs,
                                      var_lookup_fn, _force_outplace)