From ed7ece389daa980a03a2f09848a6439a90af9782 Mon Sep 17 00:00:00 2001
From: Victor Quach
Date: Thu, 12 Aug 2021 12:36:38 -0700
Subject: [PATCH] Forbid inplace modification of a saved tensor's pack_hook
 input (#62717)

Summary:
When using saved tensors hooks (especially default hooks), a `pack_hook`
that modifies its input in place can cause surprising behavior. The goal of
this PR is to prevent such headaches by catching in-place modifications of
the input of `pack_hook` and raising an error when they are detected.
(An illustrative repro is sketched after the patch.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62717

Reviewed By: albanD

Differential Revision: D30255243

Pulled By: Varal7

fbshipit-source-id: 8d73f1e1b50b697a59a2849b5e21cf0aa7493b76
---
 test/test_autograd.py                  | 22 ++++++++++++++++------
 torch/_C/_autograd.pyi                 |  4 ++--
 torch/autograd/__init__.py             |  2 +-
 torch/autograd/graph.py                |  8 ++++----
 torch/csrc/autograd/init.cpp           |  4 ++--
 torch/csrc/autograd/saved_variable.cpp |  9 ++++++++-
 6 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index cacc156..6cf7277 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5753,12 +5753,8 @@ for shape in [(1,), ()]:
             a = get_input()
             t = a * a
 
-            t.grad_fn._raw_saved_self.register_hooks(inplace_double, lambda x: x / 2)
-            y = t * 2
-            with self.assertRaisesRegex(
-                    RuntimeError,
-                    "one of the variables needed for gradient computation has been modified by an inplace operation"):
-                y.sum().backward()
+            with self.assertRaisesRegex(RuntimeError, "A saved tensor pack hook is modifying its input in place."):
+                t.grad_fn._raw_saved_self.register_hooks(inplace_double, lambda x: x / 2)
 
         # leaf
         test(lambda: torch.randn(5, requires_grad=True), True)
@@ -5847,6 +5843,20 @@ for shape in [(1,), ()]:
         with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
             pass
 
+    def test_pack_hook_with_inplace_modification_should_fail(self):
+        a = torch.randn(5, requires_grad=True)
+
+        def inc(x):
+            x += 1
+            return x
+        with torch.autograd.graph.saved_tensors_hooks(inc, lambda x: x):
+            with self.assertRaisesRegex(RuntimeError, "A saved tensor pack hook is modifying its input in place."):
+                y = torch.exp(a)
+
+        y = torch.exp(a)
+        with self.assertRaisesRegex(RuntimeError, "A saved tensor pack hook is modifying its input in place."):
+            y.grad_fn._raw_saved_result.register_hooks(inc, lambda x: x)
+
     def test_saving_variable_to_disk(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             def pack(x):
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index ef5c733..cd9b0da 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -88,8 +88,8 @@ def kineto_available() -> bool: ...
 def _supported_activities() -> Set[ProfilerActivity]: ...
 def _enable_record_function(enable: bool) -> None: ...
 def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
-def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
-def _reset_default_hooks() -> None: ...
+def _register_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
+def _reset_saved_tensors_default_hooks() -> None: ...
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
 
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 6eb5d53..0d4f153 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -264,7 +264,7 @@ from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, Pro
                                 _enable_profiler_legacy, _disable_profiler_legacy, _profiler_enabled,
                                 _enable_record_function, _set_empty_test_observer, kineto_available,
                                 _supported_activities, _add_metadata_json, SavedTensor,
-                                _register_default_hooks, _reset_default_hooks)
+                                _register_saved_tensors_default_hooks, _reset_saved_tensors_default_hooks)
 
 from torch._C._autograd import (_ProfilerResult, _KinetoEvent,
                                 _prepare_profiler, _enable_profiler, _disable_profiler)
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index a0a3d9f..b0accf0 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -66,10 +66,10 @@ class saved_tensors_hooks():
         self.unpack_hook = unpack_hook
 
     def __enter__(self):
-        torch._C._autograd._register_default_hooks(self.pack_hook, self.unpack_hook)
+        torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
 
     def __exit__(self, *args: Any):
-        torch._C._autograd._reset_default_hooks()
+        torch._C._autograd._reset_saved_tensors_default_hooks()
 
 
 class save_on_cpu():
@@ -133,7 +133,7 @@ class save_on_cpu():
         self.unpack_hook = unpack_from_cpu
 
     def __enter__(self):
-        torch._C._autograd._register_default_hooks(self.pack_hook, self.unpack_hook)
+        torch._C._autograd._register_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
 
     def __exit__(self, *args: Any):
-        torch._C._autograd._reset_default_hooks()
+        torch._C._autograd._reset_saved_tensors_default_hooks()
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 8386fc6..ffe7e83 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -272,10 +272,10 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
   m.def("_clear_callbacks", []() {
     at::clearCallbacks();
   });
-  m.def("_register_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
+  m.def("_register_saved_tensors_default_hooks", [](py::function &pack_hook, py::function &unpack_hook) {
     torch::autograd::PyDefaultSavedVariableHooks::set_hooks(pack_hook, unpack_hook);
   });
-  m.def("_reset_default_hooks", []() {
+  m.def("_reset_saved_tensors_default_hooks", []() {
     torch::autograd::PyDefaultSavedVariableHooks::reset_hooks();
   });
 
diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp
index e9bab12..b16fa87 100644
--- a/torch/csrc/autograd/saved_variable.cpp
+++ b/torch/csrc/autograd/saved_variable.cpp
@@ -214,7 +214,14 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
 void SavedVariable::set_hooks_and_pack_data(std::unique_ptr<SavedVariableHooks>&& hooks, const Variable& data) {
   hooks_ = std::move(hooks);
   at::NoGradGuard guard;
-  hooks_->call_pack_hook(saved_original_ ? data.tensor_data() : data);
+  const auto version = impl::version_counter(data).current_version();
+  hooks_->call_pack_hook(saved_original_ ? data.detach() : data);
+  TORCH_CHECK(version == impl::version_counter(data).current_version(),
+      "A saved tensor pack hook is modifying its input in place. "
+      "Tensors provided as input to pack hook can not be modified by "
+      "in-place operations as this can lead to unexpected side-effects. "
+      "Please open an issue if you need to perform in-place operations on "
+      "the input to a pack hook.");
 }
 
 void SavedVariable::register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks) {
-- 
2.7.4