"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
if session is None and not context.executing_eagerly():
session = ops.get_default_session()
- saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
- sess=session, save_path=self._save_path)
+ with ops.device("/cpu:0"):
+ saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
+ sess=session, save_path=self._save_path)
def initialize_or_restore(self, session=None):
"""Alias for `run_restore_ops`."""
@test_util.run_in_graph_and_eager_modes()
def testLoadFromNameBasedSaver(self):
"""Save a name-based checkpoint, load it using the object-based API."""
- save_path = self._write_name_based_checkpoint()
- root = self._initialized_model()
- self._set_sentinels(root)
- with self.assertRaises(AssertionError):
+ with test_util.device(use_gpu=True):
+ save_path = self._write_name_based_checkpoint()
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ with self.assertRaises(AssertionError):
+ self._check_sentinels(root)
+ object_saver = checkpointable_utils.CheckpointableSaver(root)
+ status = object_saver.restore(save_path)
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
+ status.run_restore_ops()
+ self._check_sentinels(root)
+ self._set_sentinels(root)
+ status.initialize_or_restore()
self._check_sentinels(root)
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- status = object_saver.restore(save_path)
- with self.assertRaises(AssertionError):
- status.assert_consumed()
- status.run_restore_ops()
- self._check_sentinels(root)
- self._set_sentinels(root)
- status.initialize_or_restore()
- self._check_sentinels(root)
# TODO(allenl): Test for the core name-based saver loading object-based
# checkpoints once object-based checkpointing is in core.