From 8610ff1072574a9b5772add8d6b63298d9ec9acb Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 16 Nov 2018 13:59:31 -0800 Subject: [PATCH] Allow cooperative structured objects to be passed modules in tracing (#13961) Summary: Before this patch, the JIT does not allow Module's forward to take structured objects. This patch allows cooperative objects to do so. Cooperative means: - It has a method self._jit_unwrap() that returns (a list/tuple of) tensors. These are then used in _iter_tensors. - It has a method self._jit_wrap(flattened_input) that takes (a list/tuple of) the flattened_input (potentially more than it needs) and returns itself (updated) and the unconsumed flattened_inputs. This is then used in the _unflatten mechanism. This is all it takes to permit maskrcnn-benchmark to use its structured BoxList/ImageList types and trace it without calling the .forward directly. I'll push a model working with this patch in https://github.com/facebookresearch/maskrcnn-benchmark/pull/138 I must admit I haven't fully checked whether there are ONNX changes needed before it, too, can profit, but I would be hopeful that anything currently usable remains so. fmassa zdevito So the main downside that I'm aware of is that people will later want to use more elaborate mechanisms, but I think this could be done by just amending what wrap/unwrap are returning / consuming. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13961 Differential Revision: D13103927 Pulled By: soumith fbshipit-source-id: 2cbc724cc4b53197388b662f75d9e601a495c087 --- torch/autograd/function.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 9fc07e5..ea48a4f 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -264,8 +264,17 @@ def _nested_map(condition, fn, condition_msg=None): return _map -def _iter_filter(condition, allow_unknown=False, condition_msg=None): +def _jit_unwrap_structured(obj): + if hasattr(obj, "_jit_unwrap"): + return obj._jit_unwrap() + return obj + + +def _iter_filter(condition, allow_unknown=False, condition_msg=None, + conversion=None): def _iter(obj): + if conversion is not None: + obj = conversion(obj) if condition(obj): yield obj elif obj is None: @@ -291,6 +300,8 @@ def _unflatten(input, proto): # specified by proto def unflatten_helper(input, proto): res = [] + if hasattr(proto, "_jit_wrap"): + return proto._jit_wrap(input) if not isinstance(proto, (list, tuple)): return input[0], input[1:] for e in proto: @@ -306,7 +317,8 @@ def _unflatten(input, proto): _iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value), condition_msg="jit's Values or None") -_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors") +_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors", + conversion=_jit_unwrap_structured) _iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor), allow_unknown=True, condition_msg="Tensors (permissive)") -- 2.7.4