Don't call back into python during insert (which will leave the set in a broken condi...
authorAkshay Modi <nareshmodi@google.com>
Mon, 11 Jun 2018 17:42:15 +0000 (10:42 -0700)
committerAkshay Modi <nareshmodi@google.com>
Wed, 13 Jun 2018 19:47:02 +0000 (12:47 -0700)
Thank you for finding the bug. The watched_variables_ set should not really require a lock since all our functions hold the GIL (verified by looking at the generated SWIG). The reason that there was a concurrent access to the set is that the insert was calling back into python (which might release the GIL and let another thread run, which will also attempt to insert a variable and break the set).

I included the lock to be safe though, since its non-trivial to verify without looking at the generated swig wrappers that the GIL is held.

PiperOrigin-RevId: 200074843

tensorflow/python/eager/pywrap_tfe_src.cc

index e3ce0ef..52b3268 100644 (file)
@@ -873,22 +873,6 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
   return static_cast<tensorflow::DataType>(id);
 }
 
-static tensorflow::int64 FastHandleId(PyObject* variable) {
-  PyObject* handle = PyObject_GetAttrString(variable, "handle");
-  if (handle == nullptr) {
-    return -1;
-  }
-  tensorflow::int64 id = FastTensorId(handle);
-  Py_DECREF(handle);
-  return id;
-}
-
-struct CompareByHandleId {
-  bool operator()(PyObject* lhs, PyObject* rhs) {
-    return FastHandleId(lhs) < FastHandleId(rhs);
-  }
-};
-
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
  public:
@@ -897,35 +881,63 @@ class GradientTape
             persistent) {}
 
   virtual ~GradientTape() {
-    for (PyObject* v : watched_variables_) {
-      Py_DECREF(v);
+    for (const IdAndVariable& v : watched_variables_) {
+      Py_DECREF(v.variable);
     }
   }
 
   void WatchVariable(PyObject* v) {
-    auto insert_result = watched_variables_.insert(v);
-    if (insert_result.second) {
-      // Only increment the reference count if we aren't already watching this
-      // variable.
-      Py_INCREF(v);
-    }
-    PyObject* handle = PyObject_GetAttrString(v, "handle");
+    tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
     if (handle == nullptr) {
       return;
     }
-    tensorflow::int64 id = FastTensorId(handle);
-    Py_DECREF(handle);
+    tensorflow::int64 id = FastTensorId(handle.get());
+
     if (!PyErr_Occurred()) {
       this->Watch(id);
     }
+
+    tensorflow::mutex_lock l(watched_variables_mu_);
+    auto insert_result = watched_variables_.emplace(id, v);
+
+    if (insert_result.second) {
+      // Only increment the reference count if we aren't already watching this
+      // variable.
+      Py_INCREF(v);
+    }
   }
 
-  const std::set<PyObject*, CompareByHandleId> WatchedVariables() {
-    return watched_variables_;
+  PyObject* GetVariablesAsPyTuple() {
+    tensorflow::mutex_lock l(watched_variables_mu_);
+    PyObject* result = PyTuple_New(watched_variables_.size());
+    Py_ssize_t pos = 0;
+    for (const IdAndVariable& id_and_variable : watched_variables_) {
+      PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
+      Py_INCREF(id_and_variable.variable);
+    }
+    return result;
   }
 
  private:
-  std::set<PyObject*, CompareByHandleId> watched_variables_;
+  // We store an IdAndVariable in the map since the map needs to be locked
+  // during insert, but should not call back into python during insert to avoid
+  // deadlocking with the GIL.
+  struct IdAndVariable {
+    tensorflow::int64 id;
+    PyObject* variable;
+
+    IdAndVariable(tensorflow::int64 id, PyObject* variable)
+        : id(id), variable(variable) {}
+  };
+  struct CompareById {
+    bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) {
+      return lhs.id < rhs.id;
+    }
+  };
+
+  tensorflow::mutex watched_variables_mu_;
+  std::set<IdAndVariable, CompareById> watched_variables_
+      GUARDED_BY(watched_variables_mu_);
 };
 
 typedef struct {
@@ -1217,15 +1229,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
 }
 
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
-  const auto& watched_variables =
-      reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
-  PyObject* result = PyTuple_New(watched_variables.size());
-  Py_ssize_t pos = 0;
-  for (PyObject* variable : watched_variables) {
-    PyTuple_SET_ITEM(result, pos++, variable);
-    Py_INCREF(variable);
-  }
-  return result;
+  return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
 }
 
 namespace {