remove special grad_mode tls handling (#63116)
authoralbanD <desmaison.alban@gmail.com>
Thu, 26 Aug 2021 14:48:20 +0000 (07:48 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 14:51:30 +0000 (07:51 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63116

This PR removes the special flag to disable grad mode tracking on the ThreadLocalState and replaces it with an explicit setter that users can use.
This allows to reduce complexity of ThreadLocalState.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30388098

Pulled By: albanD

fbshipit-source-id: 85641b3d711179fb78ff6a41ed077548dc821a2f

aten/src/ATen/ThreadLocalState.cpp
aten/src/ATen/ThreadLocalState.h
torch/csrc/autograd/engine.cpp
torch/csrc/autograd/engine.h
torch/csrc/distributed/autograd/engine/dist_engine.cpp

index fc4b8fa..98c2519 100644 (file)
@@ -9,40 +9,26 @@
 
 namespace at {
 
-ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
+ThreadLocalState::ThreadLocalState()
     : dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
       debug_info_(c10::ThreadLocalDebugInfo::current()),
       autograd_tls_(c10::AutogradState::get_tls_state()) {
   rf_tls_ = at::get_record_function_tls_();
   saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
 
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-  keep_grad_mode_ = keep_grad_mode;
-#endif
   bumped_record_all_functions_ = at::checkRecordAllFunctions();
 }
 
+void ThreadLocalState::set_grad_mode(bool enabled) {
+  autograd_tls_.set_grad_mode(enabled);
+}
+
 /* static */
 void ThreadLocalState::setThreadLocalState(
     const ThreadLocalState& state) {
   // Note that setting the InferenceMode TLS in this function is ONLY ok because we always
   // restore the dispatch key set TLS at the same time.
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-  if (state.keep_grad_mode_) {
-    c10::AutogradState::set_tls_state(state.autograd_tls_);
-  } else {
-    auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
-                                        /* inference_mode */ state.autograd_tls_.get_inference_mode());
-    c10::AutogradState::set_tls_state(new_state);
-  }
-#else
-  // The mobile build explicitly ignore grad_mode but fails if we propagate
-  // its value across threads or set it to a fixed value.
-  // So we have to make sure the grad_mode value is not changed here.
-  auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
-                                      /* inference_mode */ state.autograd_tls_.get_inference_mode());
-  c10::AutogradState::set_tls_state(new_state);
-#endif
+  c10::AutogradState::set_tls_state(state.autograd_tls_);
 
   at::set_record_function_tls_(state.rf_tls_);
 
index 4942399..4114691 100644 (file)
@@ -16,10 +16,12 @@ class TORCH_API ThreadLocalState {
  public:
   // Saves the thread local variables' values and
   // returns them as a ThreadLocalState
-  // keep_grad_mode - whether grad mode has to be preserved
-  //  (e.g. not preserved when passing from forward pass into
-  //   the autograd engine, autograd engine takes care of grad mode)
-  ThreadLocalState(bool keep_grad_mode = true);
+  ThreadLocalState();
+
+  // set_grad_mode - force the value of the grad mode TLS in
+  //  the current state object. This is used for example in the
+  //  autograd engine.
+  void set_grad_mode(bool enabled);
 
   // Sets thread local variables in the current thread,
   // according to the thread boundary specified
@@ -35,10 +37,8 @@ class TORCH_API ThreadLocalState {
   // RecordFunction TLS
   RecordFunctionTLS rf_tls_;
 
+  // TLS for AutogradModes
   AutogradState autograd_tls_;
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-  bool keep_grad_mode_ = true;
-#endif
 
   // TLS for saved tensors default hooks
   std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
index de2078d..acd7971 100644 (file)
@@ -411,7 +411,6 @@ auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
         // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
         // always saves ThreadLocalState without grad_mode.
         at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
-        AutoGradMode grad_mode(local_graph_task->grad_mode_);
 
         try {
           // The guard sets the thread_local current_graph_task on construction
@@ -580,7 +579,6 @@ void GraphTask::exec_post_processing() {
     // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
     // always saves ThreadLocalState without grad_mode.
     at::ThreadLocalStateGuard tls_guard(this->thread_locals_);
-    AutoGradMode grad_mode(this->grad_mode_);
 
     // WARNING: Don't use a range-for loop here because more callbacks may be
     // added in between callback calls, so iterators may become invalidated.
index 1731847..dd465f9 100644 (file)
@@ -53,9 +53,8 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {
   // true, it signals all threads to stop executing.
   std::atomic_bool has_error_{false};
   std::atomic_bool future_completed_{false};
-  // It is safe to read grad_mode_ and keep_graph_ without synchronization
+  // It is safe to read keep_graph_ without synchronization
   bool keep_graph_;
-  bool grad_mode_;
 
   // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
   // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
@@ -110,8 +109,9 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {
   // out of the GraphTask and are no longer valid.
   std::vector<Variable> captured_vars_;
 
-  at::ThreadLocalState thread_locals_ =
-      at::ThreadLocalState(/* keep_grad_mode */ false);
+  // Note: this field is not ready to be used until the proper `thread_locals_.set_grad_mode()`
+  // call in the constructor.
+  at::ThreadLocalState thread_locals_ = at::ThreadLocalState();
 
   std::unordered_set<c10::Stream> leaf_streams;
 
@@ -180,12 +180,13 @@ struct GraphTask: std::enable_shared_from_this<GraphTask> {
       std::shared_ptr<ReadyQueue> cpu_ready_queue,
       bool exit_on_error = false)
       : keep_graph_(keep_graph),
-        grad_mode_(grad_mode),
         owner_(NO_DEVICE),
         reentrant_depth_(reentrant_depth),
         exit_on_error_(exit_on_error),
         cpu_ready_queue_(std::move(cpu_ready_queue)),
-        future_result_(c10::make_intrusive<at::ivalue::Future>(c10::ListType::create(c10::TensorType::get()))) {}
+        future_result_(c10::make_intrusive<at::ivalue::Future>(c10::ListType::create(c10::TensorType::get()))) {
+    thread_locals_.set_grad_mode(grad_mode);
+        }
  private:
   // run GraphTask post processing
   void exec_post_processing();
index 4a3b3ff..e6522c3 100644 (file)
@@ -360,7 +360,6 @@ void DistEngine::execute_graph_task_until_ready_queue_empty(
       }
       if (task.fn_ && !local_graph_task->has_error_.load()) {
         at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
-        AutoGradMode grad_mode(local_graph_task->grad_mode_);
         try {
           GraphTaskGuard guard(local_graph_task);
           engine_.evaluate_function(