Sort variables in C++ instead of Python.
authorTom Hennigan <tomhennigan@google.com>
Mon, 28 May 2018 13:32:04 +0000 (06:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 28 May 2018 13:34:34 +0000 (06:34 -0700)
PiperOrigin-RevId: 198298103

tensorflow/python/eager/backprop.py
tensorflow/python/eager/pywrap_tfe.h
tensorflow/python/eager/pywrap_tfe_src.cc

index 2d859dc..b2e6c60 100644 (file)
@@ -207,16 +207,14 @@ def implicit_val_and_grad(f):
                              f.__name__))
     finally:
       tape.pop_tape(this_tape)
-    # Sorting variables by id, which is monotonically increasing in construction
-    # order. This ensures unique order across executions.
-    # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc.
-    variables = list(sorted(this_tape.watched_variables(),
-                            key=lambda v: v.handle._id))  # pylint: disable=protected-access
-    sources = [x.handle for x in variables]
-
-    if not sources:
+    # Note: variables are returned in construction order. This ensures unique
+    # order across executions.
+    variables = this_tape.watched_variables()
+    if not variables:
       raise ValueError("No trainable variables were accessed while the "
                        "function was being computed.")
+
+    sources = [v.handle for v in variables]
     grad = imperative_grad.imperative_grad(_default_vspace,
                                            this_tape,
                                            nest.flatten(end_node),
@@ -801,11 +799,8 @@ class GradientTape(object):
     self._push_tape()
 
   def watched_variables(self):
-    # Sorting variables by id, which is monotonically increasing in construction
-    # order. This ensures unique order across executions.
-    # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc.
-    return list(sorted(self._tape.watched_variables(),
-                       key=lambda v: v.handle._id))  # pylint: disable=protected-access
+    """Returns variables watched by this tape in order of construction."""
+    return self._tape.watched_variables()
 
   def gradient(self, target, sources, output_gradients=None):
     """Computes the gradient using operations recorded in context of this tape.
index c502fe9..a916a75 100644 (file)
@@ -197,7 +197,8 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
                                 PyObject* attrs, PyObject* results,
                                 PyObject* name);
 
-// Returns the set of variables watched by the given tape.
+// Returns all variables watched by the given tape in the order those variables
+// were created.
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape);
 
 // Returns an EagerTensor of dimension [len(`tensors`)] containing
index 9bbb6f5..52b9050 100644 (file)
@@ -873,6 +873,22 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
   return static_cast<tensorflow::DataType>(id);
 }
 
+static tensorflow::int64 FastHandleId(PyObject* variable) {
+  PyObject* handle = PyObject_GetAttrString(variable, "handle");
+  if (handle == nullptr) {
+    return -1;
+  }
+  tensorflow::int64 id = FastTensorId(handle);
+  Py_DECREF(handle);
+  return id;
+}
+
+struct CompareByHandleId {
+  bool operator()(PyObject* lhs, PyObject* rhs) {
+    return FastHandleId(lhs) < FastHandleId(rhs);
+  }
+};
+
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
  public:
@@ -904,12 +920,12 @@ class GradientTape
     }
   }
 
-  const std::unordered_set<PyObject*> WatchedVariables() {
+  const std::set<PyObject*, CompareByHandleId> WatchedVariables() {
     return watched_variables_;
   }
 
  private:
-  std::unordered_set<PyObject*> watched_variables_;
+  std::set<PyObject*, CompareByHandleId> watched_variables_;
 };
 
 typedef struct {
@@ -1201,11 +1217,13 @@ void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
 }
 
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
-  const std::unordered_set<PyObject*>& watched_variables =
+  const auto& watched_variables =
       reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchedVariables();
-  PyObject* result = PySet_New(nullptr);
+  PyObject* result = PyTuple_New(watched_variables.size());
+  Py_ssize_t pos = 0;
   for (PyObject* variable : watched_variables) {
-    PySet_Add(result, variable);
+    PyTuple_SET_ITEM(result, pos++, variable);
+    Py_INCREF(variable);
   }
   return result;
 }