From ab954cb0d176a5632f123ac19d9469e6f863d39a Mon Sep 17 00:00:00 2001 From: albanD Date: Wed, 25 Aug 2021 11:07:24 -0700 Subject: [PATCH] clean up engine.cpp thread state (#63115) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63115 This actually changes: - callbacks now run with proper grad mode even in worker threads - graphtask's Future callbacks now run with proper TLS when erroring out from a worker thread Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D30388100 Pulled By: albanD fbshipit-source-id: 7ae9c461c2f0040548dd9e1e314f25e8da0c2e67 --- torch/csrc/autograd/engine.cpp | 11 ++++++----- .../csrc/distributed/autograd/engine/dist_engine.cpp | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 252a74b4c0..de2078d2d6 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -407,7 +407,12 @@ auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void { } if (task.fn_ && !local_graph_task->has_error_.load()) { + // Set the ThreadLocalState before calling the function. + // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask + // always saves ThreadLocalState without grad_mode. + at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); AutoGradMode grad_mode(local_graph_task->grad_mode_); + try { // The guard sets the thread_local current_graph_task on construction // and restores it on exit. The current_graph_task variable helps @@ -575,6 +580,7 @@ void GraphTask::exec_post_processing() { // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask // always saves ThreadLocalState without grad_mode. 
at::ThreadLocalStateGuard tls_guard(this->thread_locals_); + AutoGradMode grad_mode(this->grad_mode_); // WARNING: Don't use a range-for loop here because more callbacks may be // added in between callback calls, so iterators may become invalidated. @@ -764,11 +770,6 @@ void Engine::evaluate_function( Node* func, InputBuffer& inputs, const std::shared_ptr<ReadyQueue>& cpu_ready_queue) { - // Set the ThreadLocalState before calling the function. - // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask - // always saves ThreadLocalState without grad_mode. - at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_); - // The InputBuffer::adds that supplied incoming grads took pains to // ensure they're safe to consume in the context of the present // func's stream (if applicable). So we guard onto that stream diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 76f2eaebe5..4a3b3fff2e 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -359,6 +359,7 @@ void DistEngine::execute_graph_task_until_ready_queue_empty( continue; } if (task.fn_ && !local_graph_task->has_error_.load()) { + at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_); AutoGradMode grad_mode(local_graph_task->grad_mode_); try { GraphTaskGuard guard(local_graph_task); -- 2.34.1