From 733755f72ca15feef8deeb512925639ef15f92d7 Mon Sep 17 00:00:00 2001 From: albanD Date: Thu, 26 Aug 2021 07:48:20 -0700 Subject: [PATCH] remove special grad_mode tls handling (#63116) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63116 This PR removes the special flag to disable grad mode tracking on the ThreadLocalState and replaces it with an explicit setter that users can use. This allows to reduce complexity of ThreadLocalState. Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D30388098 Pulled By: albanD fbshipit-source-id: 85641b3d711179fb78ff6a41ed077548dc821a2f --- aten/src/ATen/ThreadLocalState.cpp | 26 +++++----------------- aten/src/ATen/ThreadLocalState.h | 14 ++++++------ torch/csrc/autograd/engine.cpp | 2 -- torch/csrc/autograd/engine.h | 13 ++++++----- .../distributed/autograd/engine/dist_engine.cpp | 1 - 5 files changed, 20 insertions(+), 36 deletions(-) diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index fc4b8fa..98c2519 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -9,40 +9,26 @@ namespace at { -ThreadLocalState::ThreadLocalState(bool keep_grad_mode) +ThreadLocalState::ThreadLocalState() : dispatch_key_(c10::impl::tls_local_dispatch_key_set()), debug_info_(c10::ThreadLocalDebugInfo::current()), autograd_tls_(c10::AutogradState::get_tls_state()) { rf_tls_ = at::get_record_function_tls_(); saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks(); -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - keep_grad_mode_ = keep_grad_mode; -#endif bumped_record_all_functions_ = at::checkRecordAllFunctions(); } +void ThreadLocalState::set_grad_mode(bool enabled) { + autograd_tls_.set_grad_mode(enabled); +} + /* static */ void ThreadLocalState::setThreadLocalState( const ThreadLocalState& state) { // Note that setting the InferenceMode TLS in this function is ONLY ok because we always // restore the dispatch key set TLS at the same time. -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - if (state.keep_grad_mode_) { - c10::AutogradState::set_tls_state(state.autograd_tls_); - } else { - auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(), - /* inference_mode */ state.autograd_tls_.get_inference_mode()); - c10::AutogradState::set_tls_state(new_state); - } -#else - // The mobile build explicitly ignore grad_mode but fails if we propagate - // its value across threads or set it to a fixed value. - // So we have to make sure the grad_mode value is not changed here. - auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(), - /* inference_mode */ state.autograd_tls_.get_inference_mode()); - c10::AutogradState::set_tls_state(new_state); -#endif + c10::AutogradState::set_tls_state(state.autograd_tls_); at::set_record_function_tls_(state.rf_tls_); diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index 4942399..4114691 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -16,10 +16,12 @@ class TORCH_API ThreadLocalState { public: // Saves the thread local variables' values and // returns them as a ThreadLocalState - // keep_grad_mode - whether grad mode has to be preserved - // (e.g. not preserved when passing from forward pass into - // the autograd engine, autograd engine takes care of grad mode) - ThreadLocalState(bool keep_grad_mode = true); + ThreadLocalState(); + + // set_grad_mode - force the value of the grad mode TLS in + // the current state object. This is used for example in the + // autograd engine. + void set_grad_mode(bool enabled); // Sets thread local variables in the current thread, // according to the thread boundary specified @@ -35,10 +37,8 @@ class TORCH_API ThreadLocalState { // RecordFunction TLS RecordFunctionTLS rf_tls_; + // TLS for AutogradModes AutogradState autograd_tls_; -#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) - bool keep_grad_mode_ = true; -#endif // TLS for saved tensors default hooks std::pair saved_tensors_default_hooks_; diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index de2078d..acd7971 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -411,7 +411,6 @@ auto Engine::thread_main(const std::shared_ptr& graph_task) -> void { // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask // always saves ThreadLocalState without grad_mode. at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); - AutoGradMode grad_mode(local_graph_task->grad_mode_); try { // The guard sets the thread_local current_graph_task on construction @@ -580,7 +579,6 @@ void GraphTask::exec_post_processing() { // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask // always saves ThreadLocalState without grad_mode. at::ThreadLocalStateGuard tls_guard(this->thread_locals_); - AutoGradMode grad_mode(this->grad_mode_); // WARNING: Don't use a range-for loop here because more callbacks may be // added in between callback calls, so iterators may become invalidated. diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 1731847..dd465f9 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -53,9 +53,8 @@ struct GraphTask: std::enable_shared_from_this { // true, it signals all threads to stop executing. std::atomic_bool has_error_{false}; std::atomic_bool future_completed_{false}; - // It is safe to read grad_mode_ and keep_graph_ without synchronization + // It is safe to read keep_graph_ without synchronization bool keep_graph_; - bool grad_mode_; // To protect reads/writes to not_ready_, dependencies_, captured_vars_, // has_error_, future_result_, cpu_ready_queue_, and leaf_streams. @@ -110,8 +109,9 @@ struct GraphTask: std::enable_shared_from_this { // out of the GraphTask and are no longer valid. std::vector captured_vars_; - at::ThreadLocalState thread_locals_ = - at::ThreadLocalState(/* keep_grad_mode */ false); + // Note: this field is not ready to be used until the proper `thread_locals_.set_grad_mode()` + // call in the constructor. + at::ThreadLocalState thread_locals_ = at::ThreadLocalState(); std::unordered_set leaf_streams; @@ -180,12 +180,13 @@ struct GraphTask: std::enable_shared_from_this { std::shared_ptr cpu_ready_queue, bool exit_on_error = false) : keep_graph_(keep_graph), - grad_mode_(grad_mode), owner_(NO_DEVICE), reentrant_depth_(reentrant_depth), exit_on_error_(exit_on_error), cpu_ready_queue_(std::move(cpu_ready_queue)), - future_result_(c10::make_intrusive(c10::ListType::create(c10::TensorType::get()))) {} + future_result_(c10::make_intrusive(c10::ListType::create(c10::TensorType::get()))) { + thread_locals_.set_grad_mode(grad_mode); + } private: // run GraphTask post processing void exec_post_processing(); diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 4a3b3ff..e6522c3 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -360,7 +360,6 @@ void DistEngine::execute_graph_task_until_ready_queue_empty( } if (task.fn_ && !local_graph_task->has_error_.load()) { at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); - AutoGradMode grad_mode(local_graph_task->grad_mode_); try { GraphTaskGuard guard(local_graph_task); engine_.evaluate_function( -- 2.7.4