ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
: dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
debug_info_(c10::ThreadLocalDebugInfo::current()),
- inference_mode_enabled_(c10::InferenceMode::is_enabled()) {
+ // Snapshot grad mode and inference mode together in one packed struct.
+ autograd_tls_(c10::AutogradState::get_tls_state()) {
rf_tls_ = at::get_record_function_tls_();
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
keep_grad_mode_ = keep_grad_mode;
+ // grad mode is now always captured as part of autograd_tls_ above;
+ // keep_grad_mode_ only controls whether it gets restored later.
- if (keep_grad_mode_) {
- grad_mode_enabled_ = GradMode::is_enabled();
- }
#endif
bumped_record_all_functions_ = at::checkRecordAllFunctions();
}
/* static */
void ThreadLocalState::setThreadLocalState(
const ThreadLocalState& state) {
+ // Note that restoring the inference_mode TLS in this function is ONLY ok
+ // because we always restore the dispatch key set TLS at the same time
+ // (the two must stay in sync — see the invariant in InferenceMode.cpp).
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
if (state.keep_grad_mode_) {
- GradMode::set_enabled(state.grad_mode_enabled_);
+ // Restore both grad mode and inference mode from the snapshot.
+ c10::AutogradState::set_tls_state(state.autograd_tls_);
+ } else {
+ // Restore only inference mode; keep this thread's current grad mode.
+ auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
+ /* inference_mode */ state.autograd_tls_.get_inference_mode());
+ c10::AutogradState::set_tls_state(new_state);
}
+#else
+ // The mobile build explicitly ignores grad_mode but fails if we propagate
+ // its value across threads or set it to a fixed value.
+ // So we have to make sure the grad_mode value is not changed here.
+ auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
+ /* inference_mode */ state.autograd_tls_.get_inference_mode());
+ c10::AutogradState::set_tls_state(new_state);
#endif
at::set_record_function_tls_(state.rf_tls_);
c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);
c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
-
- c10::InferenceMode::_set_enabled(state.inference_mode_enabled_);
}
} // namespace at
// RecordFunction TLS
RecordFunctionTLS rf_tls_;
+ // Snapshot of the autograd TLS flags (grad mode + inference mode),
+ // captured in the ThreadLocalState constructor.
+ AutogradState autograd_tls_;
#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
+ // Whether setThreadLocalState() also restores grad mode from autograd_tls_.
bool keep_grad_mode_ = true;
- bool grad_mode_enabled_;
#endif
- // TLS for InferenceMode
- bool inference_mode_enabled_;
-
// TLS for saved tensors default hooks
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
--- /dev/null
+#include <c10/core/AutogradState.h>
+
+namespace c10 {
+
+namespace {
+// By default, grad mode is enabled and inference mode is disabled
+thread_local AutogradState autograd_state_tls =
+ AutogradState(/* grad_mode */ true, /* inference_mode */ false);
+} // namespace
+
+// Returns a mutable reference so callers (e.g. GradMode::set_enabled)
+// can read or update individual flags in place without copying.
+AutogradState& AutogradState::get_tls_state() {
+ return autograd_state_tls;
+}
+
+// Takes the state by value: AutogradState packs two bool bitfields,
+// so a copy is trivially cheap.
+void AutogradState::set_tls_state(AutogradState state) {
+ autograd_state_tls = state;
+}
+
+} // namespace c10
--- /dev/null
+#pragma once
+
+#include <c10/macros/Macros.h>
+
+#include <cstdint>
+
+namespace c10 {
+
+// Structure used to pack all the thread local boolean
+// flags used by autograd
+struct TORCH_API AutogradState {
+ // Accessor for this thread's instance (defined in AutogradState.cpp);
+ // returns a mutable reference so individual flags can be set in place.
+ static AutogradState& get_tls_state();
+ // Overwrites this thread's instance wholesale with `state`.
+ static void set_tls_state(AutogradState state);
+
+ AutogradState(bool grad_mode, bool inference_mode)
+ : grad_mode_(grad_mode), inference_mode_(inference_mode) {}
+
+ void set_grad_mode(bool enabled) {
+ grad_mode_ = enabled;
+ }
+
+ void set_inference_mode(bool enabled) {
+ inference_mode_ = enabled;
+ }
+
+ bool get_grad_mode() const {
+ return grad_mode_;
+ }
+
+ bool get_inference_mode() const {
+ return inference_mode_;
+ }
+
+ private:
+ // Single-bit bitfields keep the packed state as small as possible.
+ bool grad_mode_ : 1;
+ bool inference_mode_ : 1;
+};
+
+} // namespace c10
+#include <c10/core/AutogradState.h>
#include <c10/core/GradMode.h>
#include <stdexcept>
namespace c10 {
-thread_local bool GradMode_enabled = true;
-
bool GradMode::is_enabled() {
- return GradMode_enabled;
+ // Grad mode now lives in the packed AutogradState TLS.
+ return AutogradState::get_tls_state().get_grad_mode();
}
void GradMode::set_enabled(bool enabled) {
- GradMode_enabled = enabled;
+ // Mutate just the grad-mode bit in place on the thread-local state.
+ AutogradState::get_tls_state().set_grad_mode(enabled);
}
} // namespace c10
#include <stdexcept>
namespace c10 {
-thread_local bool InferenceMode_enabled = false;
-
// Invariant:
// is_enabled() ==
// !c10::impl::tls_is_dispatch_key_included(DispatchKey::ADInplaceOrView);
// InferenceMode::is_enabled() is in perf critical path (TensorImpl constructor)
// so it worths a separate TLS to skip the DispatchKeySet check.
bool InferenceMode::is_enabled() {
- return InferenceMode_enabled;
-}
-
-void InferenceMode::_set_enabled(bool enabled) {
- InferenceMode_enabled = enabled;
+ // The flag now lives in the packed AutogradState TLS. The _set_enabled
+ // backdoor is removed; all writes go through AutogradState::set_tls_state.
+ return AutogradState::get_tls_state().get_inference_mode();
}
} // namespace c10
#pragma once
+#include <c10/core/AutogradState.h>
#include <c10/core/GradMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/macros/Macros.h>
// are applicable to InferenceMode as well, e.g.
// `tensorTypeInCurrentExecutionContext` in interpreter.cpp.
InferenceMode(bool enabled = true)
- : prev_mode(InferenceMode::is_enabled()),
- prev_keyset(c10::impl::tls_local_dispatch_key_set()),
- grad_mode(at::AutoGradMode(!enabled)) {
- _set_enabled(enabled);
+ // Save the full autograd state (both flags) so the destructor can
+ // restore grad mode as well as inference mode; this replaces the
+ // removed at::AutoGradMode member.
+ : prev_mode(AutogradState::get_tls_state()),
+ prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
+ // Enabling inference mode means disabling grad mode
+ // And disabling inference mode means enabling grad mode
+ AutogradState::set_tls_state(
+ AutogradState(/* grad_mode */ !enabled, /* inference_mode */ enabled));
DispatchKeySet included = enabled
? prev_keyset.included_.remove(c10::DispatchKey::ADInplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::ADInplaceOrView);
}
~InferenceMode() {
- _set_enabled(prev_mode);
+ // Restoring prev_mode restores grad mode and inference mode together.
+ AutogradState::set_tls_state(prev_mode);
c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
}
static bool is_enabled();
- // _set_enabled() is not user facing and should be only used in
- // ThreadLocalState.cpp.
- static void _set_enabled(bool enabled);
private:
- bool prev_mode;
+ // Snapshot of the full autograd TLS taken at construction time.
+ AutogradState prev_mode;
c10::impl::LocalDispatchKeySet prev_keyset;
- at::AutoGradMode grad_mode;
};
} // namespace c10