Really delete old checkpoints this time.
author: Allen Lavoie <allenl@google.com>
Thu, 15 Mar 2018 17:54:42 +0000 (10:54 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 17:59:13 +0000 (10:59 -0700)
Follows up on cl/188187349, which fixed checkpoint management for tf.train.Saver
when executing eagerly. However, I was recreating the tf.train.Saver objects on
each save, so tfe.Checkpoint and friends did not benefit from that change.

Keeps the same tf.train.Saver around when executing eagerly. This limits object
graph mutations just like when graph building; if there are complaints I can
assign to Saver._var_list instead, since eager tf.train.Saver is not specialized
to its var_list argument.

PiperOrigin-RevId: 189211552

tensorflow/contrib/eager/python/checkpointable_utils.py
tensorflow/contrib/eager/python/checkpointable_utils_test.py

index 677b56b..389d4a0 100644 (file)
@@ -602,8 +602,7 @@ class CheckpointableSaver(object):
     """
     named_variables, graph_proto = _serialize_object_graph(
         self._root_checkpointable)
-    in_graph_mode = not context.executing_eagerly()
-    if in_graph_mode:
+    if not context.executing_eagerly():
       if session is None:
         session = ops.get_default_session()
       if self._object_graph_feed_tensor is None:
@@ -622,17 +621,17 @@ class CheckpointableSaver(object):
     named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
         tensor=object_graph_tensor,
         name=_OBJECT_GRAPH_PROTO_KEY)
-    if not in_graph_mode or self._last_save_object_graph != graph_proto:
-      if self._last_save_object_graph is not None and in_graph_mode:
+    if self._last_save_object_graph != graph_proto:
+      if self._last_save_object_graph is not None:
         raise NotImplementedError(
             "Using a single Saver to save a mutated object graph is not "
             "currently supported when graph building. Use a different Saver "
-            "when the object graph changes (save ops will be duplicated), or "
-            "file a feature request if this limitation bothers you.")
+            "when the object graph changes (save ops will be duplicated when "
+            "graph building), or file a feature request if this limitation "
+            "bothers you.")
       saver = saver_lib.Saver(var_list=named_variables)
-      if in_graph_mode:
-        self._last_save_saver = saver
-        self._last_save_object_graph = graph_proto
+      self._last_save_saver = saver
+      self._last_save_object_graph = graph_proto
     else:
       saver = self._last_save_saver
     with ops.device("/cpu:0"):
index 31f661e..1ab94b8 100644 (file)
@@ -849,6 +849,26 @@ class CheckpointingTests(test.TestCase):
         saver.save(checkpoint_prefix)
         self.assertEqual(before_ops, graph.get_operations())
 
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def testCheckpointCleanup(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+    obj = checkpointable.Checkpointable()
+    obj.var = variable_scope.get_variable(name="v", initializer=0.)
+    self.evaluate(checkpointable_utils.gather_initializers(obj))
+    saver = checkpointable_utils.Checkpoint(obj=obj)
+    for _ in range(10):
+      saver.save(checkpoint_prefix)
+    expected_filenames = ["checkpoint"]
+    for checkpoint_number in range(6, 11):
+      expected_filenames.append("ckpt-%d.index" % (checkpoint_number,))
+      expected_filenames.append(
+          "ckpt-%d.data-00000-of-00001" % (checkpoint_number,))
+    six.assertCountEqual(
+        self,
+        expected_filenames,
+        os.listdir(checkpoint_directory))
+
   def testManyRestoresGraph(self):
     """Restores after the first should not modify the graph."""
     with context.graph_mode():