return static_cast<tensorflow::DataType>(id);
}
// NOTE(review): deletion hunk. These helpers looked up a variable's id by
// fetching its Python "handle" attribute and calling FastTensorId on it, and
// CompareByHandleId did that on EVERY std::set comparison — i.e. the set's
// comparator called back into Python. The replacement (see IdAndVariable in
// GradientTape below) computes the id once at insert time instead.
-static tensorflow::int64 FastHandleId(PyObject* variable) {
- PyObject* handle = PyObject_GetAttrString(variable, "handle");
- if (handle == nullptr) {
- return -1;
- }
- tensorflow::int64 id = FastTensorId(handle);
- Py_DECREF(handle);
- return id;
-}
-
-struct CompareByHandleId {
- bool operator()(PyObject* lhs, PyObject* rhs) {
- return FastHandleId(lhs) < FastHandleId(rhs);
- }
-};
-
// GradientTape whose watched-variable set is keyed by tensor id (computed once
// in WatchVariable) instead of ordered by a comparator that re-enters Python,
// and is guarded by a mutex so tuple construction can happen under the lock.
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
// NOTE(review): the constructor's opening line(s) are elided from this diff
// hunk's context — only the tail of its initializer list is visible here.
persistent) {}
virtual ~GradientTape() {
- for (PyObject* v : watched_variables_) {
- Py_DECREF(v);
// Release the reference taken in WatchVariable() for each watched variable.
+ for (const IdAndVariable& v : watched_variables_) {
+ Py_DECREF(v.variable);
}
}
void WatchVariable(PyObject* v) {
- auto insert_result = watched_variables_.insert(v);
- if (insert_result.second) {
- // Only increment the reference count if we aren't already watching this
- // variable.
- Py_INCREF(v);
- }
- PyObject* handle = PyObject_GetAttrString(v, "handle");
// Fetch v.handle; Safe_PyObjectPtr DECREFs it on every exit path (the old
// code had a manual Py_DECREF that was skipped on the early return above).
+ tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
}
- tensorflow::int64 id = FastTensorId(handle);
- Py_DECREF(handle);
+ tensorflow::int64 id = FastTensorId(handle.get());
+
if (!PyErr_Occurred()) {
this->Watch(id);
}
+
// All Python API calls happen before the lock is taken; only the set
// insert and the compensating INCREF run under the mutex (see the
// IdAndVariable comment below for why).
// NOTE(review): if FastTensorId failed above, the (bogus) id is still
// inserted here while a Python error is pending — confirm that is intended.
+ tensorflow::mutex_lock l(watched_variables_mu_);
+ auto insert_result = watched_variables_.emplace(id, v);
+
+ if (insert_result.second) {
+ // Only increment the reference count if we aren't already watching this
+ // variable.
+ Py_INCREF(v);
+ }
}
- const std::set<PyObject*, CompareByHandleId> WatchedVariables() {
- return watched_variables_;
// Returns a new tuple holding a new reference to each watched variable.
// PyTuple_SET_ITEM steals a reference, so the Py_INCREF below compensates
// for the borrowed pointer stored in the set.
// NOTE(review): PyTuple_New can return nullptr on allocation failure and
// that result is not checked before PyTuple_SET_ITEM — confirm acceptable.
+ PyObject* GetVariablesAsPyTuple() {
+ tensorflow::mutex_lock l(watched_variables_mu_);
+ PyObject* result = PyTuple_New(watched_variables_.size());
+ Py_ssize_t pos = 0;
+ for (const IdAndVariable& id_and_variable : watched_variables_) {
+ PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
+ Py_INCREF(id_and_variable.variable);
+ }
+ return result;
}
private:
- std::set<PyObject*, CompareByHandleId> watched_variables_;
+ // We store an IdAndVariable in the map since the map needs to be locked
+ // during insert, but should not call back into python during insert to avoid
+ // deadlocking with the GIL.
+ struct IdAndVariable {
+ tensorflow::int64 id;
+ PyObject* variable;
+
+ IdAndVariable(tensorflow::int64 id, PyObject* variable)
+ : id(id), variable(variable) {}
+ };
// Orders set entries by the cached tensor id — no Python calls involved.
+ struct CompareById {
+ bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) {
+ return lhs.id < rhs.id;
+ }
+ };
+
+ tensorflow::mutex watched_variables_mu_;
+ std::set<IdAndVariable, CompareById> watched_variables_
+ GUARDED_BY(watched_variables_mu_);
};
typedef struct {
}
// Returns a new tuple of the tape's watched variables (new reference).
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
// Tuple construction moved into GradientTape::GetVariablesAsPyTuple so it
// runs under the tape's watched_variables_ mutex instead of copying the set.
- const auto& watched_variables =
- reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
- PyObject* result = PyTuple_New(watched_variables.size());
- Py_ssize_t pos = 0;
- for (PyObject* variable : watched_variables) {
- PyTuple_SET_ITEM(result, pos++, variable);
- Py_INCREF(variable);
- }
- return result;
+ return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
namespace {