Add a common autograd TLS state (#63114)
authoralbanD <desmaison.alban@gmail.com>
Tue, 24 Aug 2021 13:52:38 +0000 (06:52 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 24 Aug 2021 13:54:02 +0000 (06:54 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63114

This PR collapses the GradMode and InferenceMode thread local booleans into a single thread local uint8.
This helps reducing the number of thread local variable accesses done when we propagate ThreadLocalStates.

Note that this is even more beneficial as we will add a forward mode AD TLS (similar to GradMode) higher in this stack and this new structure should reduce the perf impact of adding this new TLS.

Here is the full benchmark result between master and the top of this stack: https://gist.github.com/albanD/e421101e9ed344e94999bef3a54bf0f3
tl;dr: give a benefit in most cases. It is never detrimental.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30388099

Pulled By: albanD

fbshipit-source-id: 8e03f940150ff063c2edd792733663413ae2f486

aten/src/ATen/ThreadLocalState.cpp
aten/src/ATen/ThreadLocalState.h
c10/core/AutogradState.cpp [new file with mode: 0644]
c10/core/AutogradState.h [new file with mode: 0644]
c10/core/GradMode.cpp
c10/core/InferenceMode.cpp
c10/core/InferenceMode.h

index ba7be1a..fc4b8fa 100644 (file)
@@ -12,15 +12,12 @@ namespace at {
 ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
     : dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
       debug_info_(c10::ThreadLocalDebugInfo::current()),
-      inference_mode_enabled_(c10::InferenceMode::is_enabled()) {
+      autograd_tls_(c10::AutogradState::get_tls_state()) {
   rf_tls_ = at::get_record_function_tls_();
   saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
 
 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
   keep_grad_mode_ = keep_grad_mode;
-  if (keep_grad_mode_) {
-    grad_mode_enabled_ = GradMode::is_enabled();
-  }
 #endif
   bumped_record_all_functions_ = at::checkRecordAllFunctions();
 }
@@ -28,10 +25,23 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
 /* static */
 void ThreadLocalState::setThreadLocalState(
     const ThreadLocalState& state) {
+  // Note that setting the InferenceMode TLS in this function is ONLY ok because we always
+  // restore the dispatch key set TLS at the same time.
 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
   if (state.keep_grad_mode_) {
-    GradMode::set_enabled(state.grad_mode_enabled_);
+    c10::AutogradState::set_tls_state(state.autograd_tls_);
+  } else {
+    auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
+                                        /* inference_mode */ state.autograd_tls_.get_inference_mode());
+    c10::AutogradState::set_tls_state(new_state);
   }
+#else
+  // The mobile build explicitly ignore grad_mode but fails if we propagate
+  // its value across threads or set it to a fixed value.
+  // So we have to make sure the grad_mode value is not changed here.
+  auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
+                                      /* inference_mode */ state.autograd_tls_.get_inference_mode());
+  c10::AutogradState::set_tls_state(new_state);
 #endif
 
   at::set_record_function_tls_(state.rf_tls_);
@@ -43,8 +53,6 @@ void ThreadLocalState::setThreadLocalState(
   c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
 
   c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
-
-  c10::InferenceMode::_set_enabled(state.inference_mode_enabled_);
 }
 
 } // namespace at
index f30f5e3..4942399 100644 (file)
@@ -35,14 +35,11 @@ class TORCH_API ThreadLocalState {
   // RecordFunction TLS
   RecordFunctionTLS rf_tls_;
 
+  AutogradState autograd_tls_;
 #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
   bool keep_grad_mode_ = true;
-  bool grad_mode_enabled_;
 #endif
 
-  // TLS for InferenceMode
-  bool inference_mode_enabled_;
-
   // TLS for saved tensors default hooks
   std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
 
diff --git a/c10/core/AutogradState.cpp b/c10/core/AutogradState.cpp
new file mode 100644 (file)
index 0000000..9684a76
--- /dev/null
@@ -0,0 +1,19 @@
+#include <c10/core/AutogradState.h>
+
+namespace c10 {
+
+namespace {
+// By default, grad mode is enabled and inference mode is disabled
+thread_local AutogradState autograd_state_tls =
+    AutogradState(/* grad_mode */ true, /* inference_mode */ false);
+} // namespace
+
+AutogradState& AutogradState::get_tls_state() {
+  return autograd_state_tls;
+}
+
+void AutogradState::set_tls_state(AutogradState state) {
+  autograd_state_tls = state;
+}
+
+} // namespace c10
diff --git a/c10/core/AutogradState.h b/c10/core/AutogradState.h
new file mode 100644 (file)
index 0000000..83ea360
--- /dev/null
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+
+#include <cstdint>
+
+namespace c10 {
+
+// Structure used to pack all the thread local boolean
+// flags used by autograd
+struct TORCH_API AutogradState {
+  static AutogradState& get_tls_state();
+  static void set_tls_state(AutogradState state);
+
+  AutogradState(bool grad_mode, bool inference_mode)
+      : grad_mode_(grad_mode), inference_mode_(inference_mode) {}
+
+  void set_grad_mode(bool enabled) {
+    grad_mode_ = enabled;
+  }
+
+  void set_inference_mode(bool enabled) {
+    inference_mode_ = enabled;
+  }
+
+  bool get_grad_mode() const {
+    return grad_mode_;
+  }
+
+  bool get_inference_mode() const {
+    return inference_mode_;
+  }
+
+ private:
+  bool grad_mode_ : 1;
+  bool inference_mode_ : 1;
+};
+
+} // namespace c10
index 32747a6..a5db198 100644 (file)
@@ -1,16 +1,15 @@
+#include <c10/core/AutogradState.h>
 #include <c10/core/GradMode.h>
 
 #include <stdexcept>
 
 namespace c10 {
 
-thread_local bool GradMode_enabled = true;
-
 bool GradMode::is_enabled() {
-  return GradMode_enabled;
+  return AutogradState::get_tls_state().get_grad_mode();
 }
 
 void GradMode::set_enabled(bool enabled) {
-  GradMode_enabled = enabled;
+  AutogradState::get_tls_state().set_grad_mode(enabled);
 }
 } // namespace c10
index b588ab4..59eca76 100644 (file)
@@ -2,18 +2,12 @@
 #include <stdexcept>
 
 namespace c10 {
-thread_local bool InferenceMode_enabled = false;
-
 // Invariant:
 //   is_enabled() ==
 //   !c10::impl::tls_is_dispatch_key_included(DispatchKey::ADInplaceOrView);
 // InferenceMode::is_enabled() is in perf critical path (TensorImpl constructor)
 // so it worths a separate TLS to skip the DispatchKeySet check.
 bool InferenceMode::is_enabled() {
-  return InferenceMode_enabled;
-}
-
-void InferenceMode::_set_enabled(bool enabled) {
-  InferenceMode_enabled = enabled;
+  return AutogradState::get_tls_state().get_inference_mode();
 }
 } // namespace c10
index 7a9c2c5..9748d6e 100644 (file)
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <c10/core/AutogradState.h>
 #include <c10/core/GradMode.h>
 #include <c10/core/impl/LocalDispatchKeySet.h>
 #include <c10/macros/Macros.h>
@@ -50,10 +51,12 @@ struct TORCH_API InferenceMode {
   //    are applicable to InferenceMode as well, e.g.
   //    `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
   InferenceMode(bool enabled = true)
-      : prev_mode(InferenceMode::is_enabled()),
-        prev_keyset(c10::impl::tls_local_dispatch_key_set()),
-        grad_mode(at::AutoGradMode(!enabled)) {
-    _set_enabled(enabled);
+      : prev_mode(AutogradState::get_tls_state()),
+        prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
+    // Enabling inference mode means disabling grad mode
+    // And disabling inference mode means enabling grad mode
+    AutogradState::set_tls_state(
+        AutogradState(/* grad_mode */ !enabled, /* inference_mode */ enabled));
     DispatchKeySet included = enabled
         ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
         : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
@@ -67,17 +70,13 @@ struct TORCH_API InferenceMode {
   }
 
   ~InferenceMode() {
-    _set_enabled(prev_mode);
+    AutogradState::set_tls_state(prev_mode);
     c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
   }
   static bool is_enabled();
-  // _set_enabled() is not user facing and should be only used in
-  // ThreadLocalState.cpp.
-  static void _set_enabled(bool enabled);
 
  private:
-  bool prev_mode;
+  AutogradState prev_mode;
   c10::impl::LocalDispatchKeySet prev_keyset;
-  at::AutoGradMode grad_mode;
 };
 } // namespace c10