namespace at {
-ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
+ThreadLocalState::ThreadLocalState()
: dispatch_key_(c10::impl::tls_local_dispatch_key_set()),
debug_info_(c10::ThreadLocalDebugInfo::current()),
autograd_tls_(c10::AutogradState::get_tls_state()) {
rf_tls_ = at::get_record_function_tls_();
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
- keep_grad_mode_ = keep_grad_mode;
-#endif
bumped_record_all_functions_ = at::checkRecordAllFunctions();
}
+void ThreadLocalState::set_grad_mode(bool enabled) {
+ autograd_tls_.set_grad_mode(enabled);
+}
+
/* static */
void ThreadLocalState::setThreadLocalState(
const ThreadLocalState& state) {
// Note that setting the InferenceMode TLS in this function is ONLY ok because we always
// restore the dispatch key set TLS at the same time.
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
- if (state.keep_grad_mode_) {
- c10::AutogradState::set_tls_state(state.autograd_tls_);
- } else {
- auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
- /* inference_mode */ state.autograd_tls_.get_inference_mode());
- c10::AutogradState::set_tls_state(new_state);
- }
-#else
- // The mobile build explicitly ignore grad_mode but fails if we propagate
- // its value across threads or set it to a fixed value.
- // So we have to make sure the grad_mode value is not changed here.
- auto new_state = c10::AutogradState(/* grad_mode */ c10::AutogradState::get_tls_state().get_grad_mode(),
- /* inference_mode */ state.autograd_tls_.get_inference_mode());
- c10::AutogradState::set_tls_state(new_state);
-#endif
+ c10::AutogradState::set_tls_state(state.autograd_tls_);
at::set_record_function_tls_(state.rf_tls_);
public:
// Saves the thread local variables' values and
// returns them as a ThreadLocalState
- // keep_grad_mode - whether grad mode has to be preserved
- // (e.g. not preserved when passing from forward pass into
- // the autograd engine, autograd engine takes care of grad mode)
- ThreadLocalState(bool keep_grad_mode = true);
+ ThreadLocalState();
+
+ // set_grad_mode - force the value of the grad mode TLS in
+ // the current state object. This is used for example in the
+ // autograd engine.
+ void set_grad_mode(bool enabled);
// Sets thread local variables in the current thread,
// according to the thread boundary specified
// RecordFunction TLS
RecordFunctionTLS rf_tls_;
+ // TLS for AutogradModes
AutogradState autograd_tls_;
-#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
- bool keep_grad_mode_ = true;
-#endif
// TLS for saved tensors default hooks
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
// NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
// always saves ThreadLocalState without grad_mode.
at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
- AutoGradMode grad_mode(local_graph_task->grad_mode_);
try {
// The guard sets the thread_local current_graph_task on construction
// NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
// always saves ThreadLocalState without grad_mode.
at::ThreadLocalStateGuard tls_guard(this->thread_locals_);
- AutoGradMode grad_mode(this->grad_mode_);
// WARNING: Don't use a range-for loop here because more callbacks may be
// added in between callback calls, so iterators may become invalidated.
// true, it signals all threads to stop executing.
std::atomic_bool has_error_{false};
std::atomic_bool future_completed_{false};
- // It is safe to read grad_mode_ and keep_graph_ without synchronization
+ // It is safe to read keep_graph_ without synchronization
bool keep_graph_;
- bool grad_mode_;
// To protect reads/writes to not_ready_, dependencies_, captured_vars_,
// has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
// out of the GraphTask and are no longer valid.
std::vector<Variable> captured_vars_;
- at::ThreadLocalState thread_locals_ =
- at::ThreadLocalState(/* keep_grad_mode */ false);
+ // Note: this field is not ready to be used until the proper `thread_locals_.set_grad_mode()`
+ // call in the constructor.
+ at::ThreadLocalState thread_locals_ = at::ThreadLocalState();
std::unordered_set<c10::Stream> leaf_streams;
std::shared_ptr<ReadyQueue> cpu_ready_queue,
bool exit_on_error = false)
: keep_graph_(keep_graph),
- grad_mode_(grad_mode),
owner_(NO_DEVICE),
reentrant_depth_(reentrant_depth),
exit_on_error_(exit_on_error),
cpu_ready_queue_(std::move(cpu_ready_queue)),
- future_result_(c10::make_intrusive<at::ivalue::Future>(c10::ListType::create(c10::TensorType::get()))) {}
+ future_result_(c10::make_intrusive<at::ivalue::Future>(c10::ListType::create(c10::TensorType::get()))) {
+ thread_locals_.set_grad_mode(grad_mode);
+ }
private:
// run GraphTask post processing
void exec_post_processing();
}
if (task.fn_ && !local_graph_task->has_error_.load()) {
at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);
- AutoGradMode grad_mode(local_graph_task->grad_mode_);
try {
GraphTaskGuard guard(local_graph_task);
engine_.evaluate_function(