change with_callable_args to return a fresh _PartialWrapper (#63374)
authorTaylor Robie <taylorrobie@fb.com>
Tue, 17 Aug 2021 16:09:59 +0000 (09:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 16:11:38 +0000 (09:11 -0700)
commit336aa9cd85567893f581d3597bb3d471e9dd04da
tree73b4477e8d7c8a08675f6e18fe9336fdbffbea99
parent7bad9ac78ae430b60745543cffa6002afb85e415
change with_callable_args to return a fresh _PartialWrapper (#63374)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/63326

Currently `get_callable_args` has the side effect of mutating the input _PartialWrapper. When that input is one of the global defaults, there are all sorts of lifetime issues that crop up. (Details in the linked issue.) So far as I can tell, we only need to make a constructor which is module (and by extension device) aware, so making a fresh one should have the same effect without leaking the last call's module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63374

Test Plan: the repro in https://github.com/pytorch/pytorch/issues/63326 now reports no leaked Tensors, and all quantization tests pass locally.

Reviewed By: HDCharles

Differential Revision: D30359360

Pulled By: robieta

fbshipit-source-id: aef33261ac49952d8d90da868a57ab063dfc456e
torch/quantization/observer.py