[Deploy] Avoid use-after-free during autograd shutdown (#64620)
authorDon Jang <djang@fb.com>
Mon, 13 Sep 2021 19:41:50 +0000 (12:41 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 19:43:10 +0000 (12:43 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64620

The `autograd` extension module's shutdown logic destroys its `PyThreadState` via the destructor of a `pybind11::gil_scoped_acquire` (the RAII pattern).

The problem is that torch.deploy also destroys the `PyThreadState` as part of its own shutdown process (https://www.internalfb.com/phabricator/paste/view/P456363738), so the thread state is destroyed twice, resulting in a use-after-free.

This change extends the existing special treatment for `IS_PYTHON_3_9_PLUS` with a `defined(USE_DEPLOY)` case, so that the guard releases ownership and the `PyThreadState` is not destroyed a second time under torch.deploy.

Test Plan: Added a `TorchpyTest.Autograd` unit test to ensure that torch.deploy can create multiple interpreter instances that use autograd without crashing.

Reviewed By: albanD

Differential Revision: D30779080

fbshipit-source-id: 4de3283cc2d394acc9b8141c17cacbfab5eea052

torch/csrc/autograd/python_engine.cpp
torch/csrc/deploy/test_deploy.cpp

index 20078bc..6d6c54b 100644 (file)
@@ -67,7 +67,7 @@ void PythonEngine::thread_init(int device, const std::shared_ptr<ReadyQueue>& re
   // Create a PyThreadState, but release the GIL. This lets pybind11::gil_scoped_acquire calls
   // inside thread_main acquire the GIL without having to create a new
   // PyThreadState each time.
-#ifdef IS_PYTHON_3_9_PLUS
+#if defined(IS_PYTHON_3_9_PLUS) || defined(USE_DEPLOY)
   auto gil = std::make_unique<pybind11::gil_scoped_acquire>();
 #else
   pybind11::gil_scoped_acquire gil;
@@ -80,11 +80,12 @@ void PythonEngine::thread_init(int device, const std::shared_ptr<ReadyQueue>& re
     decrement_non_reentrant_thread_count();
   }
 
-#ifdef IS_PYTHON_3_9_PLUS
+#if defined(IS_PYTHON_3_9_PLUS) || defined(USE_DEPLOY)
   // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if runtime is finalizing
   if (!Py_IsInitialized()) {
     no_gil.disarm();
     // TODO: call disarm rather than leak gil_scoped_acquired once PyThreadState_Clear can safely be called from finalize
+    // NOTE: deploy.cpp calls `PyInterpreterState_Delete` to destruct PyThreadState, so avoid use-after-free here.
     gil.release();
   }
 #endif
index 53456ca..34e3e38 100644 (file)
@@ -378,3 +378,30 @@ TEST(TorchpyTest, UsesDistributed) {
     I.self.attr("import_module")({"uses_distributed"});
   }
 }
+
+TEST(TorchpyTest, Autograd) {
+  torch::deploy::InterpreterManager m(2);
+  m.register_module_source("autograd_test", R"PYTHON(
+import torch
+
+x = torch.ones(5)  # input tensor
+y = torch.zeros(3)  # expected output
+w = torch.randn(5, 3, requires_grad=True)
+b = torch.randn(3, requires_grad=True)
+z = torch.matmul(x, w)+b
+loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
+loss.backward()
+# result = w.grad
+result = torch.Tensor([1,2,3])
+)PYTHON");
+  at::Tensor w_grad0, w_grad1;
+  {
+    auto I = m.all_instances()[0].acquire_session();
+    w_grad0 = I.global("autograd_test", "result").toIValue().toTensor();
+  }
+  {
+    auto I = m.all_instances()[1].acquire_session();
+    w_grad1 = I.global("autograd_test", "result").toIValue().toTensor();
+  }
+  EXPECT_TRUE(w_grad0.equal(w_grad1));
+}