Make saved tensors default hooks thread local (#62909)
authorVictor Quach <quach@fb.com>
Fri, 13 Aug 2021 14:47:12 +0000 (07:47 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 14:49:20 +0000 (07:49 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62909

This PR makes saved tensors default hooks thread local.
This allows using default hooks in a multithreaded context.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30165416

Pulled By: Varal7

fbshipit-source-id: 10a7d580661d3d94bdaf398c4e076b7bea11c16b

aten/src/ATen/SavedTensorHooks.cpp [new file with mode: 0644]
aten/src/ATen/SavedTensorHooks.h [new file with mode: 0644]
aten/src/ATen/ThreadLocalState.cpp
aten/src/ATen/ThreadLocalState.h
test/test_autograd.py
tools/build_variables.bzl
torch/csrc/autograd/python_saved_variable_hooks.cpp
torch/csrc/autograd/python_saved_variable_hooks.h

diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp
new file mode 100644 (file)
index 0000000..62d762b
--- /dev/null
@@ -0,0 +1,39 @@
+#include <ATen/SavedTensorHooks.h>
+#include <c10/util/Exception.h>
+
+namespace at {
+
+namespace {
+  // PyObject is defined in c10/util/python_stub.h
+  // Reference counting is handled by the caller of `set_hooks`.
+  thread_local PyObject* pack_hook_(nullptr);
+  thread_local PyObject* unpack_hook_(nullptr);
+
+  // This flag is set to true the first time default hooks are registered
+  // and left at true for the rest of the execution.
+  // It's an optimization so that users who never use default hooks don't need to
+  // read the thread_local variables pack_hook_ and unpack_hook_.
+  static bool is_enabled(false);
+}
+
+void SavedTensorDefaultHooks::enable() {
+  is_enabled = true;
+}
+
+void SavedTensorDefaultHooks::set_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
+  if (!is_enabled) {
+    TORCH_INTERNAL_ASSERT(pack_hook == nullptr && unpack_hook == nullptr);
+    return;
+  }
+  pack_hook_ = pack_hook;
+  unpack_hook_ = unpack_hook;
+}
+
+std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
+  if (!is_enabled) {
+    return std::make_pair(nullptr, nullptr);
+  }
+  return std::make_pair(pack_hook_, unpack_hook_);
+}
+
+}
diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h
new file mode 100644 (file)
index 0000000..0f3be92
--- /dev/null
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+#include <c10/util/python_stub.h>
+
+#include <utility>
+
+namespace at {
+
+struct TORCH_API SavedTensorDefaultHooks {
+  static void set_hooks(PyObject* pack_hook, PyObject* unpack_hook);
+  static std::pair<PyObject*, PyObject*> get_hooks();
+  static void enable();
+};
+
+} // namespace at
index f5319e3..ba7be1a 100644 (file)
@@ -5,6 +5,7 @@
 #endif
 
 #include <ATen/record_function.h>
+#include <ATen/SavedTensorHooks.h>
 
 namespace at {
 
@@ -13,6 +14,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
       debug_info_(c10::ThreadLocalDebugInfo::current()),
       inference_mode_enabled_(c10::InferenceMode::is_enabled()) {
   rf_tls_ = at::get_record_function_tls_();
+  saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
 
 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
   keep_grad_mode_ = keep_grad_mode;
@@ -34,6 +36,10 @@ void ThreadLocalState::setThreadLocalState(
 
   at::set_record_function_tls_(state.rf_tls_);
 
+  SavedTensorDefaultHooks::set_hooks(
+      state.saved_tensors_default_hooks_.first,
+      state.saved_tensors_default_hooks_.second);
+
   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
 
   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
index 89afa3e..f30f5e3 100644 (file)
@@ -43,6 +43,9 @@ class TORCH_API ThreadLocalState {
   // TLS for InferenceMode
   bool inference_mode_enabled_;
 
+  // TLS for saved tensors default hooks
+  std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
+
   // Whether pre-sampling RecordFunction optimization was enabled
   bool bumped_record_all_functions_ = false;
 
index 6cf7277..2b7db29 100644 (file)
@@ -9341,6 +9341,50 @@ class TestMultithreadAutograd(TestCase):
         # be accumulate to the same place and should be the same
         self._run_py_multithread_fn(train_fn_grad, (x,))
 
+    def test_multithread_saved_tensors_hooks(self):
+        def pack(x):
+            warnings.warn("pack")
+            return x
+
+        def registers_hooks_for_each_thread():
+            with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
+                x = torch.ones(5, 5, requires_grad=True)
+                with warnings.catch_warnings(record=True) as w:
+                    y = x * x
+                    # should raise two warnings from x being saved twice
+                    self.assertEqual(len(w), 2)
+            y.sum().backward()
+
+    def test_dataparallel_saved_tensors_hooks(self):
+        def pack(x):
+            warnings.warn("pack")
+            return x
+
+        _self = self
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                with warnings.catch_warnings(record=True) as w:
+                    y = x * x
+                    if torch.cuda.device_count() >= 2:
+                        # DataParallel is calling the forward in different threads
+                        # without progating TLS, so hooks should not be called here
+                        _self.assertEqual(len(w), 0)
+                    else:
+                        # DataParallel only uses one thread
+                        # so hooks should be called here
+                        _self.assertEqual(len(w), 2)
+
+        x = torch.ones(5, 5, requires_grad=True)
+        model = torch.nn.DataParallel(Model())
+
+        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
+            model(x)
+            with warnings.catch_warnings(record=True) as w:
+                y = x * x
+                # hooks should be called here
+                _self.assertEqual(len(w), 2)
+
     def test_python_thread_in_middle(self):
         # User might write a network that starts on one CPU thread, then runs its second half
         # concurrently with other threads (either via python threading or fork/join calls),
index 6f0cf61..0538f6d 100644 (file)
@@ -848,6 +848,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/native/mkldnn/Utils.cpp",
     "aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp",
     "aten/src/ATen/record_function.cpp",
+    "aten/src/ATen/SavedTensorHooks.cpp",
     "aten/src/ATen/vulkan/Context.cpp",
 ]
 
index 6ec684d..c81a07d 100644 (file)
@@ -1,4 +1,5 @@
 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
+#include <ATen/SavedTensorHooks.h>
 
 #include <torch/csrc/THP.h>
 
@@ -45,39 +46,37 @@ namespace torch { namespace autograd {
     }
   }
 
-  std::mutex PyDefaultSavedVariableHooks::mutex_;
-  PyObject* PyDefaultSavedVariableHooks::pack_hook_(nullptr);
-  PyObject* PyDefaultSavedVariableHooks::unpack_hook_(nullptr);
-
   void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
-    std::lock_guard<std::mutex> lock(mutex_);
+    PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
+    std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
     TORCH_CHECK(!pack_hook_ && !unpack_hook_,
         "Setting default hooks but they have already been set. "
         "Hint: only one pair of hooks is allowed at a time.");
-    pack_hook_ = pack_hook.release().ptr();
-    unpack_hook_ = unpack_hook.release().ptr();
+    at::SavedTensorDefaultHooks::enable();
+    at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
   }
 
   void PyDefaultSavedVariableHooks::reset_hooks() {
-    std::lock_guard<std::mutex> lock(mutex_);
+    PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
+    std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
     if (Py_IsInitialized()) {
       py::gil_scoped_acquire gil;
-      Py_XDECREF(pack_hook_);
-      Py_XDECREF(unpack_hook_);
+      Py_XDECREF(pack_hook);
+      Py_XDECREF(unpack_hook);
     }
-    pack_hook_ = nullptr;
-    unpack_hook_ = nullptr;
+    at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
   }
 
   std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
-    if (!pack_hook_ || !unpack_hook_) {
+    PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
+    std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
+    if (!pack_hook || !unpack_hook) {
       return nullptr;
     }
-    std::lock_guard<std::mutex> lock(mutex_);
     py::gil_scoped_acquire gil;
-    py::function pack_hook = py::reinterpret_borrow<py::function>(pack_hook_);
-    py::function unpack_hook = py::reinterpret_borrow<py::function>(unpack_hook_);
-    return std::make_unique<PySavedVariableHooks>(pack_hook, unpack_hook);
+    py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
+    py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
+    return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
   }
 
 }}
index ca5bac2..145301f 100644 (file)
@@ -27,14 +27,6 @@ struct PyDefaultSavedVariableHooks {
   static void set_hooks(py::function &pack_hook, py::function &unpack_hook);
   static void reset_hooks();
   static std::unique_ptr<SavedVariableHooks> get_hooks();
-
-private:
-  static PyObject* pack_hook_;
-  static PyObject* unpack_hook_;
-
-  // Mutex to ensure that concurrent operations that modify default pack_hook_ and
-  // unpack_hook_ are thread-safe.
-  static std::mutex mutex_;
 };
 
 }}