namespace {
// By default, grad mode is enabled and inference mode is disabled
-thread_local AutogradState autograd_state_tls =
- AutogradState(/* grad_mode */ true, /* inference_mode */ false);
+thread_local AutogradState autograd_state_tls = AutogradState(
+ /* grad_mode */ true,
+ /* inference_mode */ false,
+ /* fw_grad_mode */ true);
} // namespace
AutogradState& AutogradState::get_tls_state() {
static AutogradState& get_tls_state();
static void set_tls_state(AutogradState state);
- AutogradState(bool grad_mode, bool inference_mode)
- : grad_mode_(grad_mode), inference_mode_(inference_mode) {}
+ AutogradState(bool grad_mode, bool inference_mode, bool fw_grad_mode)
+ : grad_mode_(grad_mode),
+ inference_mode_(inference_mode),
+ fw_grad_mode_(fw_grad_mode) {}
void set_grad_mode(bool enabled) {
grad_mode_ = enabled;
}
+ void set_fw_grad_mode(bool enabled) {
+ fw_grad_mode_ = enabled;
+ }
+
void set_inference_mode(bool enabled) {
inference_mode_ = enabled;
}
return grad_mode_;
}
+ bool get_fw_grad_mode() const {
+ return fw_grad_mode_;
+ }
+
bool get_inference_mode() const {
return inference_mode_;
}
private:
bool grad_mode_ : 1;
bool inference_mode_ : 1;
+ bool fw_grad_mode_ : 1;
};
} // namespace c10
-#include <c10/core/AutogradState.h>
#include <c10/core/GradMode.h>
#include <stdexcept>
#pragma once
+#include <c10/core/AutogradState.h>
#include <c10/macros/Macros.h>
namespace c10 {
NoGradGuard() : AutoGradMode(/*enabled=*/false) {}
};
+// A RAII, thread local (!) guard that enables or disables forward grad mode
+// upon construction, and sets it back to the original value upon destruction.
+struct TORCH_API AutoFwGradMode {
+ AutoFwGradMode(bool enabled)
+ : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) {
+ AutogradState::get_tls_state().set_fw_grad_mode(enabled);
+ }
+ ~AutoFwGradMode() {
+ AutogradState::get_tls_state().set_fw_grad_mode(prev_mode);
+ }
+ bool prev_mode;
+};
+
} // namespace c10
InferenceMode(bool enabled = true)
: prev_mode(AutogradState::get_tls_state()),
prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
- // Enabling inference mode means disabling grad mode
- // And disabling inference mode means enabling grad mode
- AutogradState::set_tls_state(
- AutogradState(/* grad_mode */ !enabled, /* inference_mode */ enabled));
+ // Enabling inference mode means disabling grad modes
+ // And disabling inference mode means enabling grad modes
+ AutogradState::set_tls_state(AutogradState(
+ /* grad_mode */ !enabled,
+ /* inference_mode */ enabled,
+ /* fw_grad_mode */ !enabled));
DispatchKeySet included = enabled
? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
}
const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const {
+ // TLS that disables forward AD
+ // This is only used for custom Function implementation
+ if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
+ return ForwardGrad::undef_grad();
+ }
+
// Ensure that concurent fw_grad() "reads" are thread safe
std::lock_guard<std::mutex> lock(mutex_);