Methods to stop and reset tf.GradientTape()
authorAlexandre Passos <apassos@google.com>
Thu, 17 May 2018 15:23:10 +0000 (08:23 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 15:25:40 +0000 (08:25 -0700)
PiperOrigin-RevId: 196995160

tensorflow/python/eager/backprop.py
tensorflow/python/eager/backprop_test.py
tensorflow/python/eager/pywrap_tfe.h
tensorflow/python/eager/pywrap_tfe_src.cc
tensorflow/python/eager/tape.py
tensorflow/python/pywrap_tfe.i
tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt

index 4cdf0a4..773c981 100644 (file)
@@ -39,6 +39,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
@@ -751,6 +752,72 @@ class GradientTape(object):
     for t in nest.flatten(tensor):
       tape.watch(_handle_or_self(t))
 
+  @tf_contextlib.contextmanager
+  def stop_recording(self):
+    """Temporarily stops recording operations on this tape.
+
+    Operations executed while this context manager is active will not be
+    recorded on the tape. This is useful for reducing the memory used by tracing
+    all computations.
+
+    For example:
+
+    ```
+      with tf.GradientTape(persistent=True) as t:
+        loss = compute_loss(model)
+        with t.stop_recording():
+          # The gradient computation below is not traced, saving memory.
+          grads = t.gradient(loss, model.variables)
+    ```
+
+    Yields:
+      None
+    Raises:
+      RuntimeError: if the tape is not currently recording.
+    """
+    if self._tape is None:
+      raise RuntimeError(
+          "Trying to stop recording a tape which is not recording.")
+    tape.pop_tape(self._tape)
+    try:
+      yield
+    finally:
+      tape.push_tape(self._tape)
+
+  def reset(self):
+    """Clears all information stored in this tape.
+
+    Equivalent to exiting and reentering the tape context manager with a new
+    tape. For example, the two following code blocks are equivalent:
+    ```
+    with tf.GradientTape() as t:
+      loss = loss_fn()
+    with tf.GradientTape() as t:
+      loss += other_loss_fn()
+    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
+
+
+    # The following is equivalent to the above
+    with tf.GradientTape() as t:
+      loss = loss_fn()
+      t.reset()
+      loss += other_loss_fn()
+    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
+    ```
+
+    This is useful if you don't want to exit the context manager for the tape,
+    or can't because the desired reset point is inside a control flow construct:
+
+    ```
+    with tf.GradientTape() as t:
+      loss = ...
+      if loss > k:
+        t.reset()
+    ```
+    """
+    self.__exit__(None, None, None)
+    self.__enter__()
+
   def watched_variables(self):
     # Sorting variables by id, which is monotonically increasing in construction
     # order. This ensures unique order across executions.
index d4b3c8b..9aaa2e3 100644 (file)
@@ -221,6 +221,21 @@ class BackpropTest(test.TestCase):
     self.assertTrue(ordered_variables[0] is v0)
     self.assertTrue(ordered_variables[1] is v1)
 
+  def testTapeStopRecording(self):
+    with backprop.GradientTape() as t:
+      x = constant_op.constant(1.0)
+      with t.stop_recording():
+        y = x * x
+    self.assertEqual(t.gradient(y, x), None)
+
+  def testTapeReset(self):
+    with backprop.GradientTape() as t:
+      v = resource_variable_ops.ResourceVariable(1.0)
+      loss = v * v
+      t.reset()
+      loss += v * v
+    self.assertAllEqual(t.gradient(loss, v), 2.0)
+
   @test_util.assert_no_new_tensors
   def testGradientNone(self):
 
index 691b613..9bc8b9b 100644 (file)
@@ -120,6 +120,9 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
 // Removes the passed tape from the set of active tapes.
 void TFE_Py_TapeSetRemove(PyObject* tape);
 
+// Adds the passed tape to the set of active tapes.
+void TFE_Py_TapeSetAdd(PyObject* tape);
+
 // Returns true if the tape stack is empty.
 PyObject* TFE_Py_TapeSetIsEmpty();
 
index 48a5b21..0f21a91 100644 (file)
@@ -1009,6 +1009,14 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
   return reinterpret_cast<PyObject*>(tape);
 }
 
+void TFE_Py_TapeSetAdd(PyObject* tape) {
+  Py_INCREF(tape);
+  if (!GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape)).second) {
+    // Already exists in the tape set.
+    Py_DECREF(tape);
+  }
+}
+
 PyObject* TFE_Py_TapeSetIsEmpty() {
   if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) {
     Py_RETURN_TRUE;
index ad82266..caa217b 100644 (file)
@@ -39,6 +39,11 @@ def push_new_tape(persistent=False):
   return Tape(tape)
 
 
+def push_tape(tape):
+  """Pushes an existing tape onto the tape stack."""
+  pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape)  # pylint: disable=protected-access
+
+
 def watch(tensor):
   """Marks this tensor to be watched by all tapes in the stack.
 
index 5ee5530..fde3223 100644 (file)
@@ -42,6 +42,7 @@ limitations under the License.
 %rename("%s") TFE_Py_RecordGradient;
 %rename("%s") TFE_Py_UID;
 %rename("%s") TFE_Py_TapeSetNew;
+%rename("%s") TFE_Py_TapeSetAdd;
 %rename("%s") TFE_Py_TapeSetRemove;
 %rename("%s") TFE_Py_TapeSetStopOnThread;
 %rename("%s") TFE_Py_TapeSetRestartOnThread;
index 7405202..cbf6554 100644 (file)
@@ -11,6 +11,14 @@ tf_class {
     argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "reset"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "stop_recording"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "watch"
     argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
   }