From 50ba6dd3a182c9578bc10cb2a21d7914a1e7bac1 Mon Sep 17 00:00:00 2001
From: Akshay Modi
Date: Mon, 11 Jun 2018 10:42:15 -0700
Subject: [PATCH] Don't call back into Python during insert (which can leave
 the set in a broken condition if the runtime decides to let another thread
 run).

Thank you for finding the bug.

The watched_variables_ set should not really require a lock, since all of
our functions hold the GIL (verified by looking at the generated SWIG code).
The reason there was concurrent access to the set is that the insert was
calling back into Python, which might release the GIL and let another thread
run; that thread would then also attempt to insert a variable and break the
set.

I included the lock to be safe, though, since it's non-trivial to verify
that the GIL is held without looking at the generated SWIG wrappers.

PiperOrigin-RevId: 200074843
---
 tensorflow/python/eager/pywrap_tfe_src.cc | 82 ++++++++++++++++---------------
 1 file changed, 43 insertions(+), 39 deletions(-)

diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index e3ce0ef..52b3268 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -873,22 +873,6 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
   return static_cast<tensorflow::DataType>(id);
 }
 
-static tensorflow::int64 FastHandleId(PyObject* variable) {
-  PyObject* handle = PyObject_GetAttrString(variable, "handle");
-  if (handle == nullptr) {
-    return -1;
-  }
-  tensorflow::int64 id = FastTensorId(handle);
-  Py_DECREF(handle);
-  return id;
-}
-
-struct CompareByHandleId {
-  bool operator()(PyObject* lhs, PyObject* rhs) {
-    return FastHandleId(lhs) < FastHandleId(rhs);
-  }
-};
-
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject*, PyObject*> {
  public:
@@ -897,35 +881,63 @@ class GradientTape
             persistent) {}
 
   virtual ~GradientTape() {
-    for (PyObject* v : watched_variables_) {
-      Py_DECREF(v);
+    for (const IdAndVariable& v : watched_variables_) {
+      Py_DECREF(v.variable);
     }
   }
 
   void WatchVariable(PyObject* v) {
-    auto insert_result = watched_variables_.insert(v);
-    if (insert_result.second) {
-      // Only increment the reference count if we aren't already watching this
-      // variable.
-      Py_INCREF(v);
-    }
-    PyObject* handle = PyObject_GetAttrString(v, "handle");
+    tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
     if (handle == nullptr) {
       return;
     }
-    tensorflow::int64 id = FastTensorId(handle);
-    Py_DECREF(handle);
+    tensorflow::int64 id = FastTensorId(handle.get());
+
     if (!PyErr_Occurred()) {
       this->Watch(id);
     }
+
+    tensorflow::mutex_lock l(watched_variables_mu_);
+    auto insert_result = watched_variables_.emplace(id, v);
+
+    if (insert_result.second) {
+      // Only increment the reference count if we aren't already watching this
+      // variable.
+      Py_INCREF(v);
+    }
   }
 
-  const std::set<PyObject*, CompareByHandleId> WatchedVariables() {
-    return watched_variables_;
+  PyObject* GetVariablesAsPyTuple() {
+    tensorflow::mutex_lock l(watched_variables_mu_);
+    PyObject* result = PyTuple_New(watched_variables_.size());
+    Py_ssize_t pos = 0;
+    for (const IdAndVariable& id_and_variable : watched_variables_) {
+      PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
+      Py_INCREF(id_and_variable.variable);
+    }
+    return result;
   }
 
  private:
-  std::set<PyObject*, CompareByHandleId> watched_variables_;
+  // We store an IdAndVariable in the set since the set needs to be locked
+  // during insert, but should not call back into Python during insert to
+  // avoid deadlocking with the GIL.
+  struct IdAndVariable {
+    tensorflow::int64 id;
+    PyObject* variable;
+
+    IdAndVariable(tensorflow::int64 id, PyObject* variable)
+        : id(id), variable(variable) {}
+  };
+  struct CompareById {
+    bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) {
+      return lhs.id < rhs.id;
+    }
+  };
+
+  tensorflow::mutex watched_variables_mu_;
+  std::set<IdAndVariable, CompareById> watched_variables_
+      GUARDED_BY(watched_variables_mu_);
 };
 
 typedef struct {
@@ -1217,15 +1229,7 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
 }
 
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
-  const auto& watched_variables =
-      reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
-  PyObject* result = PyTuple_New(watched_variables.size());
-  Py_ssize_t pos = 0;
-  for (PyObject* variable : watched_variables) {
-    PyTuple_SET_ITEM(result, pos++, variable);
-    Py_INCREF(variable);
-  }
-  return result;
+  return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
 }
 
 namespace {
-- 
2.7.4
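
For readers outside the TensorFlow tree, here is a minimal standalone sketch
of the pattern the patch adopts. It is illustrative only: the names
(WatchedSet, IdAndValue) are invented, std::mutex stands in for
tensorflow::mutex, and a plain int64_t key stands in for the tensor id that
the real code computes via FastTensorId. The point is the ordering: the key
is computed before the lock is taken, and the comparator reads only plain
data, so nothing inside the critical section can re-enter user code (the
analogue of calling back into Python and releasing the GIL) while the tree
is mid-rebalance.

// Hypothetical sketch, not TensorFlow code.
// Compile with: g++ -std=c++11 -pthread sketch.cc
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <set>
#include <thread>
#include <vector>

// Key/payload pair: the key is precomputed so that comparisons never call
// back into user code.
struct IdAndValue {
  int64_t id;
  const char* value;
};

// Comparator touches only the precomputed id.
struct CompareById {
  bool operator()(const IdAndValue& lhs, const IdAndValue& rhs) const {
    return lhs.id < rhs.id;
  }
};

class WatchedSet {
 public:
  void Watch(int64_t id, const char* value) {
    // Any potentially re-entrant work (computing `id`) already happened at
    // the call site; the critical section only mutates the set.
    std::lock_guard<std::mutex> lock(mu_);
    watched_.insert(IdAndValue{id, value});
  }

  std::size_t size() {
    std::lock_guard<std::mutex> lock(mu_);
    return watched_.size();
  }

 private:
  std::mutex mu_;
  std::set<IdAndValue, CompareById> watched_;  // Guarded by mu_.
};

int main() {
  WatchedSet set;
  std::vector<std::thread> threads;
  for (int t = 0; t < 4; ++t) {
    threads.emplace_back([&set] {
      for (int64_t i = 0; i < 1000; ++i) {
        set.Watch(i, "var");  // Every thread inserts the same ids.
      }
    });
  }
  for (auto& th : threads) th.join();
  std::cout << set.size() << "\n";  // 1000: duplicates were deduped safely.
  return 0;
}

By contrast, the deleted CompareByHandleId called PyObject_GetAttrString on
every comparison, so a single insert could execute arbitrary Python (a
`handle` property), release the GIL mid-rebalance, and let a second thread
start inserting into the same half-updated tree.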