From 2d83e131b930581b15a50538a020bda30af08ad4 Mon Sep 17 00:00:00 2001 From: Tom Hennigan <tomhennigan@google.com> Date: Mon, 28 May 2018 06:32:04 -0700 Subject: [PATCH] Sort variables in C++ instead of Python. PiperOrigin-RevId: 198298103 --- tensorflow/python/eager/backprop.py | 21 ++++++++------------- tensorflow/python/eager/pywrap_tfe.h | 3 ++- tensorflow/python/eager/pywrap_tfe_src.cc | 28 +++++++++++++++++++++++----- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 2d859dc..b2e6c60 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -207,16 +207,14 @@ def implicit_val_and_grad(f): f.__name__)) finally: tape.pop_tape(this_tape) - # Sorting variables by id, which is monotonically increasing in construction - # order. This ensures unique order across executions. - # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc. - variables = list(sorted(this_tape.watched_variables(), - key=lambda v: v.handle._id)) # pylint: disable=protected-access - sources = [x.handle for x in variables] - - if not sources: + # Note: variables are returned in construction order. This ensures unique + # order across executions. + variables = this_tape.watched_variables() + if not variables: raise ValueError("No trainable variables were accessed while the " "function was being computed.") + + sources = [v.handle for v in variables] grad = imperative_grad.imperative_grad(_default_vspace, this_tape, nest.flatten(end_node), @@ -801,11 +799,8 @@ class GradientTape(object): self._push_tape() def watched_variables(self): - # Sorting variables by id, which is monotonically increasing in construction - # order. This ensures unique order across executions. - # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc. 
- return list(sorted(self._tape.watched_variables(), - key=lambda v: v.handle._id)) # pylint: disable=protected-access + """Returns variables watched by this tape in order of construction.""" + return self._tape.watched_variables() def gradient(self, target, sources, output_gradients=None): """Computes the gradient using operations recorded in context of this tape. diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index c502fe9..a916a75 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -197,7 +197,8 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, PyObject* results, PyObject* name); -// Returns the set of variables watched by the given tape. +// Returns all variables watched by the given tape in the order those variables +// were created. PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); // Returns an EagerTensor of dimension [len(`tensors`)] containing diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 9bbb6f5..52b9050 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -873,6 +873,22 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) { return static_cast<tensorflow::DataType>(id); } +static tensorflow::int64 FastHandleId(PyObject* variable) { + PyObject* handle = PyObject_GetAttrString(variable, "handle"); + if (handle == nullptr) { + return -1; + } + tensorflow::int64 id = FastTensorId(handle); + Py_DECREF(handle); + return id; +} + +struct CompareByHandleId { + bool operator()(PyObject* lhs, PyObject* rhs) { + return FastHandleId(lhs) < FastHandleId(rhs); + } +}; + class GradientTape : public tensorflow::eager::GradientTape<PyObject*, PyObject*> { public: @@ -904,12 +920,12 @@ class GradientTape } } - const std::unordered_set<PyObject*> WatchedVariables() { + const std::set<PyObject*, CompareByHandleId> WatchedVariables() { return watched_variables_; } private: - std::unordered_set<PyObject*> 
watched_variables_; + std::set<PyObject*, CompareByHandleId> watched_variables_; }; typedef struct { @@ -1201,11 +1217,13 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) { } PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { - const std::unordered_set<PyObject*>& watched_variables = + const auto& watched_variables = reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables(); - PyObject* result = PySet_New(nullptr); + PyObject* result = PyTuple_New(watched_variables.size()); + Py_ssize_t pos = 0; for (PyObject* variable : watched_variables) { - PySet_Add(result, variable); + PyTuple_SET_ITEM(result, pos++, variable); + Py_INCREF(variable); } return result; } -- 2.7.4