from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
    for t in nest.flatten(tensor):
      tape.watch(_handle_or_self(t))
+  @tf_contextlib.contextmanager
+  def stop_recording(self):
+    """Temporarily stops recording operations on this tape.
+
+    Operations executed while this context manager is active will not be
+    recorded on the tape. This is useful for reducing the memory used by the
+    tape, since computations performed inside the context are not traced.
+
+    For example:
+
+    ```
+    with tf.GradientTape(persistent=True) as t:
+      loss = compute_loss(model)
+      with t.stop_recording():
+        # The gradient computation below is not traced, saving memory.
+        grads = t.gradient(loss, model.variables)
+    ```
+
+    Yields:
+      None
+    Raises:
+      RuntimeError: if the tape is not currently recording.
+    """
+    if self._tape is None:
+      raise RuntimeError(
+          "Trying to stop recording a tape which is not recording.")
+    tape.pop_tape(self._tape)
+    try:
+      yield
+    finally:
+      tape.push_tape(self._tape)
+
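As a quick illustration of the semantics above, here is a minimal sketch of `stop_recording` in action (assuming a build that includes this patch and has eager execution enabled; the variable and values are illustrative):

```python
import tensorflow as tf  # assumes eager execution is enabled

v = tf.Variable(3.0)
with tf.GradientTape() as t:
  with t.stop_recording():
    y = v * v  # runs eagerly, but is not recorded on the tape
print(t.gradient(y, v))  # None: the tape never saw y's computation
```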
+  def reset(self):
+    """Clears all information stored in this tape.
+
+    Equivalent to exiting and reentering the tape context manager with a new
+    tape. For example, the following two code blocks are equivalent:
+    ```
+    with tf.GradientTape() as t:
+      loss = loss_fn()
+    with tf.GradientTape() as t:
+      loss += other_loss_fn()
+    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
+
+
+    # The following is equivalent to the above
+    with tf.GradientTape() as t:
+      loss = loss_fn()
+      t.reset()
+      loss += other_loss_fn()
+    t.gradient(loss, ...)  # Only differentiates other_loss_fn, not loss_fn
+    ```
+
+    This is useful if you don't want to exit the context manager for the tape,
+    or can't because the desired reset point is inside a control flow construct:
+
+    ```
+    with tf.GradientTape() as t:
+      loss = ...
+      if loss > k:
+        t.reset()
+    ```
+    """
+    self.__exit__(None, None, None)
+    self.__enter__()
+
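To make the control-flow case above concrete, a runnable sketch (the variable, threshold, and losses are illustrative assumptions):

```python
import tensorflow as tf  # assumes eager execution is enabled

v = tf.Variable(1.0)
with tf.GradientTape() as t:
  loss = v * v             # traced, then discarded by reset() below
  if float(loss) > 0.5:    # the reset point sits inside control flow
    t.reset()              # drops everything traced so far
  loss = 3.0 * v           # only this computation remains on the tape
print(t.gradient(loss, v))  # 3.0: the pre-reset trace is gone
```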
  def watched_variables(self):
    # Sorting variables by id, which is monotonically increasing in construction
    # order. This ensures unique order across executions.
    self.assertTrue(ordered_variables[0] is v0)
    self.assertTrue(ordered_variables[1] is v1)
+  def testTapeStopRecording(self):
+    with backprop.GradientTape() as t:
+      # Use a variable, which the tape watches automatically; a plain
+      # constant is never watched, which would make this test vacuous.
+      x = resource_variable_ops.ResourceVariable(1.0)
+      with t.stop_recording():
+        y = x * x
+    self.assertEqual(t.gradient(y, x), None)
+
+  def testTapeReset(self):
+    with backprop.GradientTape() as t:
+      v = resource_variable_ops.ResourceVariable(1.0)
+      loss = v * v
+      t.reset()
+      loss += v * v
+    self.assertAllEqual(t.gradient(loss, v), 2.0)
+
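One behavior the tests above leave implicit, sketched here as a hypothetical companion test (not part of this patch): recording should resume once the `stop_recording` context exits.

```python
  def testRecordingResumesAfterStopRecording(self):
    # Hypothetical: ops after the stop_recording block are traced again.
    with backprop.GradientTape() as t:
      v = resource_variable_ops.ResourceVariable(2.0)
      with t.stop_recording():
        _ = v * v  # not traced
      y = v * v    # traced again after the context exits
    self.assertAllEqual(t.gradient(y, v), 4.0)
```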
  @test_util.assert_no_new_tensors
  def testGradientNone(self):
argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "reset"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "stop_recording"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "watch"
argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
}