Checkpointable: Fix device placement when restoring name-based checkpoints.
authorAllen Lavoie <allenl@google.com>
Wed, 7 Mar 2018 23:52:25 +0000 (15:52 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Mar 2018 23:56:33 +0000 (15:56 -0800)
Just need to put the restore ops on a CPU.

PiperOrigin-RevId: 188248198

tensorflow/contrib/eager/python/checkpointable_utils.py
tensorflow/contrib/eager/python/checkpointable_utils_test.py

index 1fa150f..d07121d 100644 (file)
@@ -493,8 +493,9 @@ class NameBasedSaverStatus(_LoadStatus):
     """Load the name-based training checkpoint using a new `tf.train.Saver`."""
     if session is None and not context.executing_eagerly():
       session = ops.get_default_session()
-    saver_lib.Saver(self._object_saver._global_variable_names()).restore(  # pylint: disable=protected-access
-        sess=session, save_path=self._save_path)
+    with ops.device("/cpu:0"):
+      saver_lib.Saver(self._object_saver._global_variable_names()).restore(  # pylint: disable=protected-access
+          sess=session, save_path=self._save_path)
 
   def initialize_or_restore(self, session=None):
     """Alias for `run_restore_ops`."""
index fd9fc09..2054878 100644 (file)
@@ -993,20 +993,21 @@ class CheckpointCompatibilityTests(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testLoadFromNameBasedSaver(self):
     """Save a name-based checkpoint, load it using the object-based API."""
-    save_path = self._write_name_based_checkpoint()
-    root = self._initialized_model()
-    self._set_sentinels(root)
-    with self.assertRaises(AssertionError):
+    with test_util.device(use_gpu=True):
+      save_path = self._write_name_based_checkpoint()
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      with self.assertRaises(AssertionError):
+        self._check_sentinels(root)
+      object_saver = checkpointable_utils.CheckpointableSaver(root)
+      status = object_saver.restore(save_path)
+      with self.assertRaises(AssertionError):
+        status.assert_consumed()
+      status.run_restore_ops()
+      self._check_sentinels(root)
+      self._set_sentinels(root)
+      status.initialize_or_restore()
       self._check_sentinels(root)
-    object_saver = checkpointable_utils.CheckpointableSaver(root)
-    status = object_saver.restore(save_path)
-    with self.assertRaises(AssertionError):
-      status.assert_consumed()
-    status.run_restore_ops()
-    self._check_sentinels(root)
-    self._set_sentinels(root)
-    status.initialize_or_restore()
-    self._check_sentinels(root)
 
   # TODO(allenl): Test for the core name-based saver loading object-based
   # checkpoints once object-based checkpointing is in core.