change with_callable_args to return a fresh _PartialWrapper (#63374)
authorTaylor Robie <taylorrobie@fb.com>
Tue, 17 Aug 2021 16:09:59 +0000 (09:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 16:11:38 +0000 (09:11 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63326

Currently `get_callable_args` has the side effect of mutating the input _PartialWrapper. When that input is one of the global defaults, there are all sorts of lifetime issues that crop up. (Details in the linked issue.) So far as I can tell, we only need to make a constructor which is module (and by extension device) aware, so making a fresh one should have the same effect without leaking the last call's module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63374

Test Plan: the repro in https://github.com/pytorch/pytorch/issues/63326 now reports no leaked Tensors, and all quantization tests pass locally.

Reviewed By: HDCharles

Differential Revision: D30359360

Pulled By: robieta

fbshipit-source-id: aef33261ac49952d8d90da868a57ab063dfc456e

torch/quantization/observer.py

index 8946c7d..fc6af82 100644 (file)
@@ -30,8 +30,9 @@ class _PartialWrapper(object):
         return _with_args(self, **kwargs)
 
     def with_callable_args(self, **kwargs):
-        self.callable_args = {**self.callable_args, **kwargs}
-        return self
+        result = _PartialWrapper(p=self.p)
+        result.callable_args = {**self.callable_args, **kwargs}
+        return result
 
 
 def _with_args(cls_or_self, **kwargs):