From: Victor Quach
Date: Fri, 13 Aug 2021 14:47:12 +0000 (-0700)
Subject: Make saved tensors default hooks thread local (#62909)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~1048
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5abeac3ef70a5ff3035f5cd11a217b5d15ebcd49;p=platform%2Fupstream%2Fpytorch.git

Make saved tensors default hooks thread local (#62909)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62909

This PR makes saved tensors default hooks thread local.
This allows using default hooks in a multithreaded context.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30165416

Pulled By: Varal7

fbshipit-source-id: 10a7d580661d3d94bdaf398c4e076b7bea11c16b
---
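For context, a minimal sketch of what registering default saved-tensor hooks looks like from Python, using the `torch.autograd.graph.saved_tensors_hooks` context manager exercised by the new tests in this patch; the `pack`/`unpack` helpers and the printed message are illustrative only and are not part of the change.

```python
import torch

def pack(tensor):
    # Runs when autograd saves a tensor for backward; whatever is returned
    # here is what gets stored in the graph.
    print("packing a tensor of shape", tuple(tensor.shape))
    return tensor

def unpack(saved):
    # Runs when backward needs the saved tensor back.
    return saved

x = torch.ones(5, 5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x * x            # x is saved through the default hooks
    y.sum().backward()   # unpack runs when the saved tensors are used
```

With this patch, the hooks installed by the context manager are kept in thread-local storage, so a registration like the one above only affects the thread that entered the context manager.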
diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp
new file mode 100644
index 0000000..62d762b
--- /dev/null
+++ b/aten/src/ATen/SavedTensorHooks.cpp
@@ -0,0 +1,39 @@
+#include <ATen/SavedTensorHooks.h>
+#include <c10/util/Exception.h>
+
+namespace at {
+
+namespace {
+  // PyObject is defined in c10/util/python_stub.h
+  // Reference counting is handled by the caller of `set_hooks`.
+  thread_local PyObject* pack_hook_(nullptr);
+  thread_local PyObject* unpack_hook_(nullptr);
+
+  // This flag is set to true the first time default hooks are registered
+  // and left at true for the rest of the execution.
+  // It's an optimization so that users who never use default hooks don't need to
+  // read the thread_local variables pack_hook_ and unpack_hook_.
+  static bool is_enabled(false);
+}
+
+void SavedTensorDefaultHooks::enable() {
+  is_enabled = true;
+}
+
+void SavedTensorDefaultHooks::set_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
+  if (!is_enabled) {
+    TORCH_INTERNAL_ASSERT(pack_hook == nullptr && unpack_hook == nullptr);
+    return;
+  }
+  pack_hook_ = pack_hook;
+  unpack_hook_ = unpack_hook;
+}
+
+std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
+  if (!is_enabled) {
+    return std::make_pair(nullptr, nullptr);
+  }
+  return std::make_pair(pack_hook_, unpack_hook_);
+}
+
+}
diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h
new file mode 100644
index 0000000..0f3be92
--- /dev/null
+++ b/aten/src/ATen/SavedTensorHooks.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+#include <c10/util/python_stub.h>
+
+#include <utility>
+
+namespace at {
+
+struct TORCH_API SavedTensorDefaultHooks {
+  static void set_hooks(PyObject* pack_hook, PyObject* unpack_hook);
+  static std::pair<PyObject*, PyObject*> get_hooks();
+  static void enable();
+};
+
+} // namespace at
diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp
index f5319e3..ba7be1a 100644
--- a/aten/src/ATen/ThreadLocalState.cpp
+++ b/aten/src/ATen/ThreadLocalState.cpp
@@ -5,6 +5,7 @@
 #endif
 
 #include <ATen/record_function.h>
+#include <ATen/SavedTensorHooks.h>
 
 namespace at {
 
@@ -13,6 +14,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
       debug_info_(c10::ThreadLocalDebugInfo::current()),
       inference_mode_enabled_(c10::InferenceMode::is_enabled()) {
   rf_tls_ = at::get_record_function_tls_();
+  saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
 
 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
   keep_grad_mode_ = keep_grad_mode;
@@ -34,6 +36,10 @@ void ThreadLocalState::setThreadLocalState(
 
   at::set_record_function_tls_(state.rf_tls_);
 
+  SavedTensorDefaultHooks::set_hooks(
+      state.saved_tensors_default_hooks_.first,
+      state.saved_tensors_default_hooks_.second);
+
   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
 
   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h
index 89afa3e..f30f5e3 100644
--- a/aten/src/ATen/ThreadLocalState.h
+++ b/aten/src/ATen/ThreadLocalState.h
@@ -43,6 +43,9 @@ class TORCH_API ThreadLocalState {
   // TLS for InferenceMode
   bool inference_mode_enabled_;
 
+  // TLS for saved tensors default hooks
+  std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
+
   // Whether pre-sampling RecordFunction optimization was enabled
   bool bumped_record_all_functions_ = false;
 
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 6cf7277..2b7db29 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -9341,6 +9341,50 @@ class TestMultithreadAutograd(TestCase):
         # be accumulate to the same place and should be the same
         self._run_py_multithread_fn(train_fn_grad, (x,))
 
+    def test_multithread_saved_tensors_hooks(self):
+        def pack(x):
+            warnings.warn("pack")
+            return x
+
+        def registers_hooks_for_each_thread():
+            with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
+                x = torch.ones(5, 5, requires_grad=True)
+                with warnings.catch_warnings(record=True) as w:
+                    y = x * x
+                    # should raise two warnings from x being saved twice
+                    self.assertEqual(len(w), 2)
+                y.sum().backward()
+
+    def test_dataparallel_saved_tensors_hooks(self):
+        def pack(x):
+            warnings.warn("pack")
+            return x
+
+        _self = self
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                with warnings.catch_warnings(record=True) as w:
+                    y = x * x
+                    if torch.cuda.device_count() >= 2:
+                        # DataParallel is calling the forward in different threads
+                        # without propagating TLS, so hooks should not be called here
+                        _self.assertEqual(len(w), 0)
+                    else:
+                        # DataParallel only uses one thread
+                        # so hooks should be called here
+                        _self.assertEqual(len(w), 2)
+
+        x = torch.ones(5, 5, requires_grad=True)
+        model = torch.nn.DataParallel(Model())
+
+        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
+            model(x)
+            with warnings.catch_warnings(record=True) as w:
+                y = x * x
+                # hooks should be called here
+                _self.assertEqual(len(w), 2)
+
     def test_python_thread_in_middle(self):
         # User might write a network that starts on one CPU thread, then runs its second half
         # concurrently with other threads (either via python threading or fork/join calls),
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 6f0cf61..0538f6d 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -848,6 +848,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/native/mkldnn/Utils.cpp",
     "aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp",
     "aten/src/ATen/record_function.cpp",
+    "aten/src/ATen/SavedTensorHooks.cpp",
     "aten/src/ATen/vulkan/Context.cpp",
 ]
 
diff --git a/torch/csrc/autograd/python_saved_variable_hooks.cpp b/torch/csrc/autograd/python_saved_variable_hooks.cpp
index 6ec684d..c81a07d 100644
--- a/torch/csrc/autograd/python_saved_variable_hooks.cpp
+++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp
@@ -1,4 +1,5 @@
 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
+#include <ATen/SavedTensorHooks.h>
 
 #include <torch/csrc/THP.h>
 
@@ -45,39 +46,37 @@ namespace torch { namespace autograd {
     }
   }
 
-  std::mutex PyDefaultSavedVariableHooks::mutex_;
-  PyObject* PyDefaultSavedVariableHooks::pack_hook_(nullptr);
-  PyObject* PyDefaultSavedVariableHooks::unpack_hook_(nullptr);
-
   void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
-    std::lock_guard<std::mutex> lock(mutex_);
+    PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
+    std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
     TORCH_CHECK(!pack_hook_ && !unpack_hook_,
         "Setting default hooks but they have already been set. "
         "Hint: only one pair of hooks is allowed at a time.");
-    pack_hook_ = pack_hook.release().ptr();
-    unpack_hook_ = unpack_hook.release().ptr();
+    at::SavedTensorDefaultHooks::enable();
+    at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
   }
 
   void PyDefaultSavedVariableHooks::reset_hooks() {
-    std::lock_guard<std::mutex> lock(mutex_);
+    PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
+    std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
     if (Py_IsInitialized()) {
       py::gil_scoped_acquire gil;
-      Py_XDECREF(pack_hook_);
-      Py_XDECREF(unpack_hook_);
+      Py_XDECREF(pack_hook);
+      Py_XDECREF(unpack_hook);
     }
-    pack_hook_ = nullptr;
-    unpack_hook_ = nullptr;
+    at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
   }
 
   std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
-    if (!pack_hook_ || !unpack_hook_) {
+    PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
+    std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
+    if (!pack_hook || !unpack_hook) {
       return nullptr;
     }
-    std::lock_guard<std::mutex> lock(mutex_);
     py::gil_scoped_acquire gil;
-    py::function pack_hook = py::reinterpret_borrow<py::function>(pack_hook_);
-    py::function unpack_hook = py::reinterpret_borrow<py::function>(unpack_hook_);
-    return std::make_unique<PySavedVariableHooks>(pack_hook, unpack_hook);
+    py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
+    py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
+    return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
   }
 
 }}
diff --git a/torch/csrc/autograd/python_saved_variable_hooks.h b/torch/csrc/autograd/python_saved_variable_hooks.h
index ca5bac2..145301f 100644
--- a/torch/csrc/autograd/python_saved_variable_hooks.h
+++ b/torch/csrc/autograd/python_saved_variable_hooks.h
@@ -27,14 +27,6 @@ struct PyDefaultSavedVariableHooks {
   static void set_hooks(py::function &pack_hook, py::function &unpack_hook);
   static void reset_hooks();
   static std::unique_ptr<SavedVariableHooks> get_hooks();
-
-private:
-  static PyObject* pack_hook_;
-  static PyObject* unpack_hook_;
-
-  // Mutex to ensure that concurrent operations that modify default pack_hook_ and
-  // unpack_hook_ are thread-safe.
-  static std::mutex mutex_;
 };
 
 }}
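To illustrate the behavior this change enables, here is a minimal sketch mirroring test_multithread_saved_tensors_hooks above rather than the PyTorch internals: each Python thread can now install its own pair of default hooks, because the "only one pair of hooks is allowed" check in PyDefaultSavedVariableHooks::set_hooks only sees the hooks stored in that thread's thread-local slots. The worker/pack_calls names are illustrative only.

```python
import threading
import torch

pack_calls = {}  # thread name -> number of times its pack hook ran

def worker(name):
    def pack(tensor):
        pack_calls[name] = pack_calls.get(name, 0) + 1
        return tensor

    # Each thread registers its own default hooks; with the thread_local
    # storage introduced by this patch they do not clobber each other.
    with torch.autograd.graph.saved_tensors_hooks(pack, lambda saved: saved):
        x = torch.ones(5, 5, requires_grad=True)
        y = x * x          # x is saved twice, so pack runs twice in this thread
        y.sum().backward()

threads = [threading.Thread(target=worker, args=(f"thread-{i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(pack_calls)  # expected: every thread reports 2 pack calls
```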