Rename private push/pop API and use from `stop_recording` method.
authorTom Hennigan <tomhennigan@google.com>
Thu, 17 May 2018 16:57:16 +0000 (09:57 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 17:00:07 +0000 (10:00 -0700)
PiperOrigin-RevId: 197007561

tensorflow/python/eager/backprop.py

index 773c981..c107d12 100644 (file)
@@ -723,21 +723,21 @@ class GradientTape(object):
 
   def __enter__(self):
     """Enters a context inside which operations are recorded on this tape."""
-    self._start_recording()
+    self._push_tape()
     return self
 
   def __exit__(self, typ, value, traceback):
     """Exits the recording context, no further operations are traced."""
     if self._recording:
-      self._stop_recording()
+      self._pop_tape()
 
-  def _start_recording(self):
+  def _push_tape(self):
     if self._recording:
       raise ValueError("Tape is already recording.")
     self._tape = tape.push_new_tape(persistent=self._persistent)
     self._recording = True
 
-  def _stop_recording(self):
+  def _pop_tape(self):
     if not self._recording:
       raise ValueError("Tape is not recording.")
     tape.pop_tape(self._tape)
@@ -778,11 +778,11 @@ class GradientTape(object):
     if self._tape is None:
       raise RuntimeError(
           "Trying to stop recording a tape which is not recording.")
-    tape.pop_tape(self._tape)
+    self._pop_tape()
     try:
       yield
     finally:
-      tape.push_tape(self._tape)
+      self._push_tape()
 
   def reset(self):
     """Clears all information stored in this tape.
@@ -815,8 +815,8 @@ class GradientTape(object):
         t.reset()
     ```
     """
-    self.__exit__(None, None, None)
-    self.__enter__()
+    self._pop_tape()
+    self._push_tape()
 
   def watched_variables(self):
     # Sorting variables by id, which is monotonically increasing in construction
@@ -849,7 +849,7 @@ class GradientTape(object):
                          "non-persistent tapes.")
     if self._recording:
       if not self._persistent:
-        self._stop_recording()
+        self._pop_tape()
       else:
         logging.log_first_n(logging.WARN,
                             "Calling GradientTape.gradient on a persistent "