From 9a75743f7a4190c788a33ec7bd4b384e12292cb1 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Thu, 17 May 2018 08:23:10 -0700 Subject: [PATCH] Methods to stop and reset tf.GradientTape() PiperOrigin-RevId: 196995160 --- tensorflow/python/eager/backprop.py | 67 ++++++++++++++++++++++ tensorflow/python/eager/backprop_test.py | 15 +++++ tensorflow/python/eager/pywrap_tfe.h | 3 + tensorflow/python/eager/pywrap_tfe_src.cc | 8 +++ tensorflow/python/eager/tape.py | 5 ++ tensorflow/python/pywrap_tfe.i | 1 + .../api/golden/tensorflow.-gradient-tape.pbtxt | 8 +++ 7 files changed, 107 insertions(+) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 4cdf0a4..773c981 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -39,6 +39,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export @@ -751,6 +752,72 @@ class GradientTape(object): for t in nest.flatten(tensor): tape.watch(_handle_or_self(t)) + @tf_contextlib.contextmanager + def stop_recording(self): + """Temporarily stops recording operations on this tape. + + Operations executed while this context manager is active will not be + recorded on the tape. This is useful for reducing the memory used by tracing + all computations. + + For example: + + ``` + with tf.GradientTape(persistent=True) as t: + loss = compute_loss(model) + with t.stop_recording(): + # The gradient computation below is not traced, saving memory. + grads = t.gradient(loss, model.variables) + ``` + + Yields: + None + Raises: + RuntimeError: if the tape is not currently recording. + """ + if self._tape is None: + raise RuntimeError( + "Trying to stop recording a tape which is not recording.") + tape.pop_tape(self._tape) + try: + yield + finally: + tape.push_tape(self._tape) + + def reset(self): + """Clears all information stored in this tape. + + Equivalent to exiting and reentering the tape context manager with a new + tape. For example, the two following code blocks are equivalent: + ``` + with tf.GradientTape() as t: + loss = loss_fn() + with tf.GradientTape() as t: + loss += other_loss_fn() + t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn + + + # The following is equivalent to the above + with tf.GradientTape() as t: + loss = loss_fn() + t.reset() + loss += other_loss_fn() + t.gradient(loss, ...) # Only differentiates other_loss_fn, not loss_fn + ``` + + This is useful if you don't want to exit the context manager for the tape, + or can't because the desired reset point is inside a control flow construct: + + ``` + with tf.GradientTape() as t: + loss = ... + if loss > k: + t.reset() + ``` + """ + self.__exit__(None, None, None) + self.__enter__() + def watched_variables(self): # Sorting variables by id, which is monotonically increasing in construction # order. This ensures unique order across executions. diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index d4b3c8b..9aaa2e3 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -221,6 +221,21 @@ class BackpropTest(test.TestCase): self.assertTrue(ordered_variables[0] is v0) self.assertTrue(ordered_variables[1] is v1) + def testTapeStopRecording(self): + with backprop.GradientTape() as t: + x = constant_op.constant(1.0) + with t.stop_recording(): + y = x * x + self.assertEqual(t.gradient(y, x), None) + + def testTapeReset(self): + with backprop.GradientTape() as t: + v = resource_variable_ops.ResourceVariable(1.0) + loss = v * v + t.reset() + loss += v * v + self.assertAllEqual(t.gradient(loss, v), 2.0) + @test_util.assert_no_new_tensors def testGradientNone(self): diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 691b613..9bc8b9b 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -120,6 +120,9 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent); // Removes the passed tape from the set of active tapes. void TFE_Py_TapeSetRemove(PyObject* tape); +// Adds the passed tape to the set of active tapes. +void TFE_Py_TapeSetAdd(PyObject* tape); + // Returns true if the tape stack is empty. PyObject* TFE_Py_TapeSetIsEmpty(); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 48a5b21..0f21a91 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1009,6 +1009,14 @@ PyObject* TFE_Py_TapeSetNew(PyObject* persistent) { return reinterpret_cast(tape); } +void TFE_Py_TapeSetAdd(PyObject* tape) { + Py_INCREF(tape); + if (!GetTapeSet()->insert(reinterpret_cast(tape)).second) { + // Already exists in the tape set. + Py_DECREF(tape); + } +} + PyObject* TFE_Py_TapeSetIsEmpty() { if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) { Py_RETURN_TRUE; diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index ad82266..caa217b 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -39,6 +39,11 @@ def push_new_tape(persistent=False): return Tape(tape) +def push_tape(tape): + """Pushes an existing tape onto the tape stack.""" + pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access + + def watch(tensor): """Marks this tensor to be watched by all tapes in the stack. diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 5ee5530..fde3223 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -42,6 +42,7 @@ limitations under the License. %rename("%s") TFE_Py_RecordGradient; %rename("%s") TFE_Py_UID; %rename("%s") TFE_Py_TapeSetNew; +%rename("%s") TFE_Py_TapeSetAdd; %rename("%s") TFE_Py_TapeSetRemove; %rename("%s") TFE_Py_TapeSetStopOnThread; %rename("%s") TFE_Py_TapeSetRestartOnThread; diff --git a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt index 7405202..cbf6554 100644 --- a/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-gradient-tape.pbtxt @@ -11,6 +11,14 @@ tf_class { argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { + name: "reset" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "stop_recording" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { name: "watch" argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None" } -- 2.7.4