Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17640
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17311
I've extended our model metadata framework in this diff to support
traced modules as well. Re-used a lot of components from the previous
implementation of ScriptModule metadata.
Tracing is a little different from Scripting since you can't just create a
subclass of TopLevelTraceModule (type returned by torch.jit.trace) and attach
metadata the way we did for ScriptModule. As a result, I've introduced a
separate API torch.fb.jit_trace which returns an instance of
TracedModuleWithMetadata which is a subclass of TopLevelTracedModule. As a
result, we can now attach metadata to this instance.
Reviewed By: dzhulgakov
Differential Revision:
D14117966
fbshipit-source-id:
3eee5eef733cb8d6a219c02e2f41d08698eca326
check_trace=True,
check_inputs=None,
check_tolerance=1e-5,
- _force_outplace=False):
+ _force_outplace=False,
+ _module_class=None):
"""
Trace a function and return an executable trace that will be optimized
using just-in-time compilation.
# done primarily so that weird iterables fail here and not pybind11 code
elif not isinstance(example_inputs, tuple):
example_inputs = tuple(example_inputs)
- module = TopLevelTracedModule(func, **executor_options)
+ if _module_class:
+ module = _module_class(func, **executor_options)
+ else:
+ module = TopLevelTracedModule(func, **executor_options)
var_lookup_fn = _create_interpreter_name_lookup_fn(0)
module._create_method_from_trace('forward', func, example_inputs,
var_lookup_fn, _force_outplace)