Allow cooperative structured objects to be passed to modules in tracing (#13961)
Summary:
Before this patch, the JIT did not allow a Module's forward to take
structured objects.
This patch allows cooperative objects to do so.
Cooperative means:
- It has a method self._jit_unwrap() that returns (a list/tuple of)
tensors. These are then used in _iter_tensors.
- It has a method self._jit_wrap(flattened_input) that takes the
flattened input (presumably a list/tuple, potentially holding more
entries than it needs) and returns itself (updated) together with the
unconsumed rest of flattened_input.
This is then used in the _unflatten mechanism.
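The protocol above can be illustrated with a minimal, self-contained sketch. It is not the patch's code: `iter_leaves` and `unflatten` are hypothetical helpers that only mirror the roles of the JIT's `_iter_tensors` and `_unflatten`, `BoxListLike` is a hypothetical class loosely modeled on a BoxList, and plain strings stand in for tensors so the sketch runs without PyTorch.

```python
class BoxListLike:
    """Hypothetical cooperative type: a box "tensor", named extra fields,
    and non-tensor metadata (image_size) that never enters the flat list."""

    def __init__(self, bbox, image_size, fields=None):
        self.bbox = bbox
        self.image_size = image_size
        self.fields = dict(fields or {})

    def _jit_unwrap(self):
        # Tensors in a fixed order: the boxes first, then fields sorted by name.
        return [self.bbox] + [self.fields[k] for k in sorted(self.fields)]

    def _jit_wrap(self, flattened_input):
        # Consume exactly as many entries as _jit_unwrap produced, update self
        # in place, and hand back whatever was not consumed.
        keys = sorted(self.fields)
        n = 1 + len(keys)
        consumed, rest = flattened_input[:n], flattened_input[n:]
        self.bbox = consumed[0]
        self.fields = dict(zip(keys, consumed[1:]))
        return self, rest


def iter_leaves(obj):
    # Hypothetical stand-in for _iter_tensors: cooperative objects are
    # unwrapped, lists/tuples are recursed into, anything else is a leaf.
    if hasattr(obj, "_jit_unwrap"):
        yield from iter_leaves(obj._jit_unwrap())
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from iter_leaves(item)
    else:
        yield obj


def unflatten(flat, proto):
    # Hypothetical stand-in for _unflatten: rebuild the shape of `proto`
    # from the flat sequence and return (rebuilt, unconsumed_rest).
    if hasattr(proto, "_jit_wrap"):
        return proto._jit_wrap(flat)
    if isinstance(proto, (list, tuple)):
        items = []
        for p in proto:
            item, flat = unflatten(flat, p)
            items.append(item)
        return (tuple(items) if isinstance(proto, tuple) else items), flat
    return flat[0], flat[1:]


# Usage: flatten a structured input, swap in new leaves, then rebuild it.
boxes = BoxListLike("b0", (800, 600), {"scores": "s0", "labels": "l0"})
inputs = (boxes, "img")
flat = list(iter_leaves(inputs))            # ["b0", "l0", "s0", "img"]
new_flat = [leaf.upper() for leaf in flat]  # e.g. values produced during tracing
rebuilt, rest = unflatten(new_flat, inputs)
print(rebuilt[0].bbox, rebuilt[0].fields, rebuilt[1], rest)
```

Because _jit_wrap returns the unconsumed remainder, the caller never needs to know how many tensors a structured object holds; each object takes its share and passes the rest along.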
This is all it takes to permit maskrcnn-benchmark to use
its structured BoxList/ImageList types and trace the model without
calling .forward directly.
I'll push a model working with this patch in
https://github.com/facebookresearch/maskrcnn-benchmark/pull/138
I must admit I haven't fully checked whether ONNX export needs changes before it, too, can profit from this, but I expect that anything currently usable remains so.
fmassa zdevito
The main downside I'm aware of is that people will later want more elaborate mechanisms, but I think these could be supported by just amending what _jit_wrap/_jit_unwrap return / consume.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13961
Differential Revision: D13103927
Pulled By: soumith
fbshipit-source-id: 2cbc724cc4b53197388b662f75d9e601a495c087