"""
named_variables, graph_proto = _serialize_object_graph(
self._root_checkpointable)
- in_graph_mode = not context.executing_eagerly()
- if in_graph_mode:
+ if not context.executing_eagerly():
if session is None:
session = ops.get_default_session()
if self._object_graph_feed_tensor is None:
named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
tensor=object_graph_tensor,
name=_OBJECT_GRAPH_PROTO_KEY)
- if not in_graph_mode or self._last_save_object_graph != graph_proto:
- if self._last_save_object_graph is not None and in_graph_mode:
+ if self._last_save_object_graph != graph_proto:
+ if self._last_save_object_graph is not None:
raise NotImplementedError(
"Using a single Saver to save a mutated object graph is not "
"currently supported when graph building. Use a different Saver "
- "when the object graph changes (save ops will be duplicated), or "
- "file a feature request if this limitation bothers you.")
+ "when the object graph changes (save ops will be duplicated when "
+ "graph building), or file a feature request if this limitation "
+ "bothers you.")
saver = saver_lib.Saver(var_list=named_variables)
- if in_graph_mode:
- self._last_save_saver = saver
- self._last_save_object_graph = graph_proto
+ self._last_save_saver = saver
+ self._last_save_object_graph = graph_proto
else:
saver = self._last_save_saver
with ops.device("/cpu:0"):
saver.save(checkpoint_prefix)
self.assertEqual(before_ops, graph.get_operations())
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testCheckpointCleanup(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ obj = checkpointable.Checkpointable()
+ obj.var = variable_scope.get_variable(name="v", initializer=0.)
+ self.evaluate(checkpointable_utils.gather_initializers(obj))
+ saver = checkpointable_utils.Checkpoint(obj=obj)
+ for _ in range(10):
+ saver.save(checkpoint_prefix)
+ expected_filenames = ["checkpoint"]
+ for checkpoint_number in range(6, 11):
+ expected_filenames.append("ckpt-%d.index" % (checkpoint_number,))
+ expected_filenames.append(
+ "ckpt-%d.data-00000-of-00001" % (checkpoint_number,))
+ six.assertCountEqual(
+ self,
+ expected_filenames,
+ os.listdir(checkpoint_directory))
+
def testManyRestoresGraph(self):
"""Restores after the first should not modify the graph."""
with context.graph_mode():