--- /dev/null
+++ b/aten/src/ATen/SavedTensorHooks.cpp
+#include <ATen/SavedTensorHooks.h>
+#include <c10/util/Exception.h>
+
+namespace at {
+
+namespace {
+  // PyObject is defined in c10/util/python_stub.h
+  // Reference counting is handled by the caller of `set_hooks`.
+  thread_local PyObject* pack_hook_(nullptr);
+  thread_local PyObject* unpack_hook_(nullptr);
+
+  // This flag is set to true the first time default hooks are registered and
+  // stays true for the rest of the process lifetime. It is an optimization so
+  // that users who never use default hooks do not pay for reading the
+  // thread_local variables pack_hook_ and unpack_hook_. Unlike the hooks
+  // themselves, this flag is shared across threads.
+  static bool is_enabled(false);
+} // namespace
+
+void SavedTensorDefaultHooks::enable() {
+  is_enabled = true;
+}
+
+void SavedTensorDefaultHooks::set_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
+  if (!is_enabled) {
+    TORCH_INTERNAL_ASSERT(pack_hook == nullptr && unpack_hook == nullptr);
+    return;
+  }
+  pack_hook_ = pack_hook;
+  unpack_hook_ = unpack_hook;
+}
+
+std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
+  if (!is_enabled) {
+    return std::make_pair(nullptr, nullptr);
+  }
+  return std::make_pair(pack_hook_, unpack_hook_);
+}
+
+} // namespace at
--- /dev/null
+++ b/aten/src/ATen/SavedTensorHooks.h
+#pragma once
+
+#include <c10/macros/Export.h>
+#include <c10/util/python_stub.h>
+
+#include <utility>
+
+namespace at {
+
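+// Thread-local storage for the default saved-tensor hooks
+// (torch.autograd.graph.saved_tensors_hooks). The two PyObject pointers are
+// opaque to ATen; reference counting is handled by the caller (see
+// PyDefaultSavedVariableHooks in torch/csrc/autograd).
+//
+// Rough usage sketch, mirroring how this patch drives the API:
+//
+//   at::SavedTensorDefaultHooks::enable();                    // one-time opt-in
+//   at::SavedTensorDefaultHooks::set_hooks(pack, unpack);     // register on this thread
+//   auto hooks = at::SavedTensorDefaultHooks::get_hooks();    // snapshot, e.g. for ThreadLocalState
+//   at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr); // reset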
+struct TORCH_API SavedTensorDefaultHooks {
+  static void set_hooks(PyObject* pack_hook, PyObject* unpack_hook);
+  static std::pair<PyObject*, PyObject*> get_hooks();
+  static void enable();
+};
+
+} // namespace at
#endif
#include <ATen/record_function.h>
+#include <ATen/SavedTensorHooks.h>
namespace at {
debug_info_(c10::ThreadLocalDebugInfo::current()),
inference_mode_enabled_(c10::InferenceMode::is_enabled()) {
rf_tls_ = at::get_record_function_tls_();
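+ // Capture this thread's default saved-tensor hooks so they travel with the
+ // ThreadLocalState snapshot to whichever thread later restores it.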
+ saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
keep_grad_mode_ = keep_grad_mode;
at::set_record_function_tls_(state.rf_tls_);
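+ // Restore the default saved-tensor hooks captured in the snapshot onto the
+ // current thread.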
+ SavedTensorDefaultHooks::set_hooks(
+ state.saved_tensors_default_hooks_.first,
+ state.saved_tensors_default_hooks_.second);
+
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
// TLS for InferenceMode
bool inference_mode_enabled_;
+ // TLS for saved tensors default hooks
+ std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
+
// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;
# be accumulate to the same place and should be the same
self._run_py_multithread_fn(train_fn_grad, (x,))
+    def test_multithread_saved_tensors_hooks(self):
+        def pack(x):
+            warnings.warn("pack")
+            return x
+
+        def registers_hooks_for_each_thread():
+            with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
+                x = torch.ones(5, 5, requires_grad=True)
+                with warnings.catch_warnings(record=True) as w:
+                    y = x * x
+                    # should raise two warnings from x being saved twice
+                    self.assertEqual(len(w), 2)
+                y.sum().backward()
+
+        # Default hooks are thread-local, so each thread registers its own pair.
+        self._run_py_multithread_fn(registers_hooks_for_each_thread)
+
+    def test_dataparallel_saved_tensors_hooks(self):
+        def pack(x):
+            warnings.warn("pack")
+            return x
+
+        _self = self
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                with warnings.catch_warnings(record=True) as w:
+                    y = x * x
+                    if torch.cuda.device_count() >= 2:
+                        # DataParallel calls forward in different threads
+                        # without propagating TLS, so hooks should not be called here
+                        _self.assertEqual(len(w), 0)
+                    else:
+                        # DataParallel only uses one thread,
+                        # so hooks should be called here
+                        _self.assertEqual(len(w), 2)
+
+        x = torch.ones(5, 5, requires_grad=True)
+        model = torch.nn.DataParallel(Model())
+
+        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
+            model(x)
+            with warnings.catch_warnings(record=True) as w:
+                y = x * x
+                # hooks should be called here
+                _self.assertEqual(len(w), 2)
+
def test_python_thread_in_middle(self):
# User might write a network that starts on one CPU thread, then runs its second half
# concurrently with other threads (either via python threading or fork/join calls),
"aten/src/ATen/native/mkldnn/Utils.cpp",
"aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp",
"aten/src/ATen/record_function.cpp",
+ "aten/src/ATen/SavedTensorHooks.cpp",
"aten/src/ATen/vulkan/Context.cpp",
]
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
+#include <ATen/SavedTensorHooks.h>
#include <torch/csrc/THP.h>
}
}
- std::mutex PyDefaultSavedVariableHooks::mutex_;
- PyObject* PyDefaultSavedVariableHooks::pack_hook_(nullptr);
- PyObject* PyDefaultSavedVariableHooks::unpack_hook_(nullptr);
-
void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
- std::lock_guard<std::mutex> lock(mutex_);
+ PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
+ std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
TORCH_CHECK(!pack_hook_ && !unpack_hook_,
"Setting default hooks but they have already been set. "
"Hint: only one pair of hooks is allowed at a time.");
- pack_hook_ = pack_hook.release().ptr();
- unpack_hook_ = unpack_hook.release().ptr();
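+ // enable() flips the process-wide flag so get_hooks()/set_hooks() start
+ // reading and writing the thread-local slots; ownership of the two callables
+ // is transferred here and released again in reset_hooks() under the GIL.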
+ at::SavedTensorDefaultHooks::enable();
+ at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
}
void PyDefaultSavedVariableHooks::reset_hooks() {
- std::lock_guard<std::mutex> lock(mutex_);
+ PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
+ std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
if (Py_IsInitialized()) {
py::gil_scoped_acquire gil;
- Py_XDECREF(pack_hook_);
- Py_XDECREF(unpack_hook_);
+ Py_XDECREF(pack_hook);
+ Py_XDECREF(unpack_hook);
}
- pack_hook_ = nullptr;
- unpack_hook_ = nullptr;
+ at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
}
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
- if (!pack_hook_ || !unpack_hook_) {
+ PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
+ std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
+ if (!pack_hook || !unpack_hook) {
return nullptr;
}
- std::lock_guard<std::mutex> lock(mutex_);
py::gil_scoped_acquire gil;
- py::function pack_hook = py::reinterpret_borrow<py::function>(pack_hook_);
- py::function unpack_hook = py::reinterpret_borrow<py::function>(unpack_hook_);
- return std::make_unique<PySavedVariableHooks>(pack_hook, unpack_hook);
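+ // reinterpret_borrow increments the refcounts, so the PySavedVariableHooks
+ // created here owns its own references, independent of the thread-local slots.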
+ py::function pack_hook_ = py::reinterpret_borrow<py::function>(pack_hook);
+ py::function unpack_hook_ = py::reinterpret_borrow<py::function>(unpack_hook);
+ return std::make_unique<PySavedVariableHooks>(pack_hook_, unpack_hook_);
}
}}
static void set_hooks(py::function &pack_hook, py::function &unpack_hook);
static void reset_hooks();
static std::unique_ptr<SavedVariableHooks> get_hooks();
-
-private:
- static PyObject* pack_hook_;
- static PyObject* unpack_hook_;
-
- // Mutex to ensure that concurrent operations that modify default pack_hook_ and
- // unpack_hook_ are thread-safe.
- static std::mutex mutex_;
};
}}