From 09a5f58fdc108e084b3d4a3c569a694fa5a96812 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Thu, 17 May 2018 09:57:16 -0700 Subject: [PATCH] Rename private push/pop API and use from `stop_recording` method. PiperOrigin-RevId: 197007561 --- tensorflow/python/eager/backprop.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 773c981..c107d12 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -723,21 +723,21 @@ class GradientTape(object): def __enter__(self): """Enters a context inside which operations are recorded on this tape.""" - self._start_recording() + self._push_tape() return self def __exit__(self, typ, value, traceback): """Exits the recording context, no further operations are traced.""" if self._recording: - self._stop_recording() + self._pop_tape() - def _start_recording(self): + def _push_tape(self): if self._recording: raise ValueError("Tape is already recording.") self._tape = tape.push_new_tape(persistent=self._persistent) self._recording = True - def _stop_recording(self): + def _pop_tape(self): if not self._recording: raise ValueError("Tape is not recording.") tape.pop_tape(self._tape) @@ -778,11 +778,11 @@ class GradientTape(object): if self._tape is None: raise RuntimeError( "Trying to stop recording a tape which is not recording.") - tape.pop_tape(self._tape) + self._pop_tape() try: yield finally: - tape.push_tape(self._tape) + self._push_tape() def reset(self): """Clears all information stored in this tape. @@ -815,8 +815,8 @@ class GradientTape(object): t.reset() ``` """ - self.__exit__(None, None, None) - self.__enter__() + self._pop_tape() + self._push_tape() def watched_variables(self): # Sorting variables by id, which is monotonically increasing in construction @@ -849,7 +849,7 @@ class GradientTape(object): "non-persistent tapes.") if self._recording: if not self._persistent: - self._stop_recording() + self._pop_tape() else: logging.log_first_n(logging.WARN, "Calling GradientTape.gradient on a persistent " -- 2.7.4