New TLS to disable forward mode AD (#63117)
authoralbanD <desmaison.alban@gmail.com>
Fri, 27 Aug 2021 18:53:27 +0000 (11:53 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 18:59:24 +0000 (11:59 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63117

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30388097

Pulled By: albanD

fbshipit-source-id: f1bc777064645db1ff848bdd64af95bffb530984

c10/core/AutogradState.cpp
c10/core/AutogradState.h
c10/core/GradMode.cpp
c10/core/GradMode.h
c10/core/InferenceMode.h
torch/csrc/autograd/autograd_meta.cpp

index 9684a76..4667acb 100644 (file)
@@ -4,8 +4,10 @@ namespace c10 {
 
 namespace {
 // By default, grad mode is enabled and inference mode is disabled
-thread_local AutogradState autograd_state_tls =
-    AutogradState(/* grad_mode */ true, /* inference_mode */ false);
+thread_local AutogradState autograd_state_tls = AutogradState(
+    /* grad_mode */ true,
+    /* inference_mode */ false,
+    /* fw_grad_mode */ true);
 } // namespace
 
 AutogradState& AutogradState::get_tls_state() {
index 1447594..a1d13a4 100644 (file)
@@ -12,13 +12,19 @@ struct C10_API AutogradState {
   static AutogradState& get_tls_state();
   static void set_tls_state(AutogradState state);
 
-  AutogradState(bool grad_mode, bool inference_mode)
-      : grad_mode_(grad_mode), inference_mode_(inference_mode) {}
+  AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode)
+      : grad_mode_(grad_mode),
+        inference_mode_(inference_mode),
+        fw_grad_mode_(fw_grad_mode) {}
 
   void set_grad_mode(bool enabled) {
     grad_mode_ = enabled;
   }
 
+  void set_fw_grad_mode(bool enabled) {
+    fw_grad_mode_ = enabled;
+  }
+
   void set_inference_mode(bool enabled) {
     inference_mode_ = enabled;
   }
@@ -27,6 +33,10 @@ struct C10_API AutogradState {
     return grad_mode_;
   }
 
+  bool get_fw_grad_mode() const {
+    return fw_grad_mode_;
+  }
+
   bool get_inference_mode() const {
     return inference_mode_;
   }
@@ -34,6 +44,7 @@ struct C10_API AutogradState {
  private:
   bool grad_mode_ : 1;
   bool inference_mode_ : 1;
+  bool fw_grad_mode_ : 1;
 };
 
 } // namespace c10
index a5db198..c2ea869 100644 (file)
@@ -1,4 +1,3 @@
-#include <c10/core/AutogradState.h>
 #include <c10/core/GradMode.h>
 
 #include <stdexcept>
index 1168bb1..d83ff6d 100644 (file)
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <c10/core/AutogradState.h>
 #include <c10/macros/Macros.h>
 
 namespace c10 {
@@ -27,4 +28,17 @@ struct TORCH_API NoGradGuard : public AutoGradMode {
   NoGradGuard() : AutoGradMode(/*enabled=*/false) {}
 };
 
+// A RAII, thread local (!) guard that enables or disables forward grad mode
+// upon construction, and sets it back to the original value upon destruction.
+struct TORCH_API AutoFwGradMode {
+  AutoFwGradMode(bool enabled)
+      : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) {
+    AutogradState::get_tls_state().set_fw_grad_mode(enabled);
+  }
+  ~AutoFwGradMode() {
+    AutogradState::get_tls_state().set_fw_grad_mode(prev_mode);
+  }
+  bool prev_mode;
+};
+
 } // namespace c10
index 9748d6e..704c43b 100644 (file)
@@ -53,10 +53,12 @@ struct TORCH_API InferenceMode {
   InferenceMode(bool enabled = true)
       : prev_mode(AutogradState::get_tls_state()),
         prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
-    // Enabling inference mode means disabling grad mode
-    // And disabling inference mode means enabling grad mode
-    AutogradState::set_tls_state(
-        AutogradState(/* grad_mode */ !enabled, /* inference_mode */ enabled));
+    // Enabling inference mode means disabling grad modes
+    // And disabling inference mode means enabling grad modes
+    AutogradState::set_tls_state(AutogradState(
+        /* grad_mode */ !enabled,
+        /* inference_mode */ enabled,
+        /* fw_grad_mode */ !enabled));
     DispatchKeySet included = enabled
         ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
         : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
index 248847f..f35c122 100644 (file)
@@ -183,6 +183,12 @@ void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self,
 }
 
 const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const {
+  // TLS that disables forward AD
+  // This is only used for custom Function implementation
+  if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
+    return ForwardGrad::undef_grad();
+  }
+
   // Ensure that concurent fw_grad() "reads" are thread safe
   std::lock_guard<std::mutex> lock(mutex_);