From 04108592a362848a5d3af4332f7628a14e312174 Mon Sep 17 00:00:00 2001 From: albanD Date: Fri, 27 Aug 2021 11:53:27 -0700 Subject: [PATCH] New TLS to disable forward mode AD (#63117) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63117 Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D30388097 Pulled By: albanD fbshipit-source-id: f1bc777064645db1ff848bdd64af95bffb530984 --- c10/core/AutogradState.cpp | 6 ++++-- c10/core/AutogradState.h | 15 +++++++++++++-- c10/core/GradMode.cpp | 1 - c10/core/GradMode.h | 14 ++++++++++++++ c10/core/InferenceMode.h | 10 ++++++---- torch/csrc/autograd/autograd_meta.cpp | 6 ++++++ 6 files changed, 43 insertions(+), 9 deletions(-) diff --git a/c10/core/AutogradState.cpp b/c10/core/AutogradState.cpp index 9684a76..4667acb 100644 --- a/c10/core/AutogradState.cpp +++ b/c10/core/AutogradState.cpp @@ -4,8 +4,10 @@ namespace c10 { namespace { // By default, grad mode is enabled and inference mode is disabled -thread_local AutogradState autograd_state_tls = - AutogradState(/* grad_mode */ true, /* inference_mode */ false); +thread_local AutogradState autograd_state_tls = AutogradState( + /* grad_mode */ true, + /* inference_mode */ false, + /* fw_grad_mode */ true); } // namespace AutogradState& AutogradState::get_tls_state() { diff --git a/c10/core/AutogradState.h b/c10/core/AutogradState.h index 1447594..a1d13a4 100644 --- a/c10/core/AutogradState.h +++ b/c10/core/AutogradState.h @@ -12,13 +12,19 @@ struct C10_API AutogradState { static AutogradState& get_tls_state(); static void set_tls_state(AutogradState state); - AutogradState(bool grad_mode, bool inference_mode) - : grad_mode_(grad_mode), inference_mode_(inference_mode) {} + AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode) + : grad_mode_(grad_mode), + inference_mode_(inference_mode), + fw_grad_mode_(fw_grad_mode) {} void set_grad_mode(bool enabled) { grad_mode_ = enabled; } + void set_fw_grad_mode(bool enabled) { + fw_grad_mode_ = enabled; + } + void set_inference_mode(bool enabled) { inference_mode_ = enabled; } @@ -27,6 +33,10 @@ struct C10_API AutogradState { return grad_mode_; } + bool get_fw_grad_mode() const { + return fw_grad_mode_; + } + bool get_inference_mode() const { return inference_mode_; } @@ -34,6 +44,7 @@ struct C10_API AutogradState { private: bool grad_mode_ : 1; bool inference_mode_ : 1; + bool fw_grad_mode_ : 1; }; } // namespace c10 diff --git a/c10/core/GradMode.cpp b/c10/core/GradMode.cpp index a5db198..c2ea869 100644 --- a/c10/core/GradMode.cpp +++ b/c10/core/GradMode.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/c10/core/GradMode.h b/c10/core/GradMode.h index 1168bb1..d83ff6d 100644 --- a/c10/core/GradMode.h +++ b/c10/core/GradMode.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace c10 { @@ -27,4 +28,17 @@ struct TORCH_API NoGradGuard : public AutoGradMode { NoGradGuard() : AutoGradMode(/*enabled=*/false) {} }; +// A RAII, thread local (!) guard that enables or disables forward grad mode +// upon construction, and sets it back to the original value upon destruction. +struct TORCH_API AutoFwGradMode { + AutoFwGradMode(bool enabled) + : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { + AutogradState::get_tls_state().set_fw_grad_mode(enabled); + } + ~AutoFwGradMode() { + AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); + } + bool prev_mode; +}; + } // namespace c10 diff --git a/c10/core/InferenceMode.h b/c10/core/InferenceMode.h index 9748d6e..704c43b 100644 --- a/c10/core/InferenceMode.h +++ b/c10/core/InferenceMode.h @@ -53,10 +53,12 @@ struct TORCH_API InferenceMode { InferenceMode(bool enabled = true) : prev_mode(AutogradState::get_tls_state()), prev_keyset(c10::impl::tls_local_dispatch_key_set()) { - // Enabling inference mode means disabling grad mode - // And disabling inference mode means enabling grad mode - AutogradState::set_tls_state( - AutogradState(/* grad_mode */ !enabled, /* inference_mode */ enabled)); + // Enabling inference mode means disabling grad modes + // And disabling inference mode means enabling grad modes + AutogradState::set_tls_state(AutogradState( + /* grad_mode */ !enabled, + /* inference_mode */ enabled, + /* fw_grad_mode */ !enabled)); DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView) : prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView); diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index 248847f..f35c122 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -183,6 +183,12 @@ void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, } const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const { + // TLS that disables forward AD + // This is only used for custom Function implementation + if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) { + return ForwardGrad::undef_grad(); + } + // Ensure that concurent fw_grad() "reads" are thread safe std::lock_guard lock(mutex_); -- 2.7.4