Avoid evaluating SaveSpec Tensors multiple times when executing eagerly
author Allen Lavoie <allenl@google.com>
Thu, 29 Mar 2018 19:58:43 +0000 (12:58 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 29 Mar 2018 20:00:56 +0000 (13:00 -0700)
The Saver now calls a SaveSpec callable exactly once when saving and not at all when restoring. Previously, saving evaluated the callable twice and restoring evaluated it once (copying a variable's value each time).

This requires that a dtype be passed to a SaveSpec whenever its `tensor` argument is a callable.

PiperOrigin-RevId: 190972754

tensorflow/python/training/saver.py
tensorflow/python/training/saver_test.py

index cec581d..e40b8d2 100644 (file)
@@ -91,17 +91,27 @@ class BaseSaverBuilder(object):
   class SaveSpec(object):
     """Class used to describe tensor slices that need to be saved."""
 
-    def __init__(self, tensor, slice_spec, name):
+    def __init__(self, tensor, slice_spec, name, dtype=None):
       """Creates a `SaveSpec` object.
 
       Args:
         tensor: the tensor to save or callable that produces a tensor to save.
         slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
         name: the name to save the tensor under.
+        dtype: The data type of the Tensor. Required if `tensor` is callable.
+          Used for error checking in the restore op.
       """
       self._tensor = tensor
       self.slice_spec = slice_spec
       self.name = name
+      if callable(self._tensor):
+        if dtype is None:
+          raise AssertionError(
+              "When passing a callable `tensor` to a SaveSpec, an explicit "
+              "dtype must be provided.")
+        self.dtype = dtype
+      else:
+        self.dtype = tensor.dtype
 
     @property
     def tensor(self):
@@ -117,14 +127,27 @@ class BaseSaverBuilder(object):
         op: the "producer" object that this class wraps; it produces a list of
           tensors to save.  E.g., a "Variable" object saving its backing tensor.
         specs: a list of SaveSpec, each element of which describes one tensor to
-          save under this object.
+          save under this object. All Tensors must be on the same device.
         name: the name to save the object under.
       """
       self.op = op
       self.specs = specs
       self.name = name
-      # The device of this saveable. All tensors must be on the same device.
-      self.device = specs[0].tensor.device
+      self._device = None
+
+    @property
+    def device(self):
+      """The device for SaveSpec Tensors."""
+      # Note that SaveSpec.tensor runs Tensor-gathering ops when executing
+      # eagerly, making this call potentially very expensive.
+      #
+      # TODO(allenl): Consider another way to gather device information. Lower
+      # priority since this property isn't part of the normal save()/restore()
+      # workflow, but does come up when some alternative builders are passed to
+      # the Saver.
+      if self._device is None:
+        self._device = self.specs[0].tensor.device
+      return self._device
 
     def restore(self, restored_tensors, restored_shapes):
       """Restores this object from 'restored_tensors'.
@@ -148,7 +171,7 @@ class BaseSaverBuilder(object):
     """SaveableObject implementation that handles Variables."""
 
     def __init__(self, var, slice_spec, name):
-      spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name)
+      spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name, dtype=var.dtype)
       super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name)
 
     def restore(self, restored_tensors, restored_shapes):
@@ -186,7 +209,8 @@ class BaseSaverBuilder(object):
         raise ValueError(
             "Saveable is neither a resource variable nor a read operation."
             " Got: %s" % repr(var))
-      spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)
+      spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name,
+                                       dtype=var.dtype)
       super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
           var, [spec], name)
 
@@ -295,7 +319,7 @@ class BaseSaverBuilder(object):
               filename_tensor,
               [spec.name],
               [spec.slice_spec],
-              [spec.tensor.dtype])[0])
+              [spec.dtype])[0])
 
     return tensors
   # pylint: enable=unused-argument
@@ -854,7 +878,7 @@ class BulkSaverBuilder(BaseSaverBuilder):
     restore_specs = []
     for saveable in saveables:
       for spec in saveable.specs:
-        restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))
+        restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
 
     names, slices, dtypes = zip(*restore_specs)
     # Load all tensors onto CPU 0 for compatibility with existing code.
index d1c24b3..14dda79 100644 (file)
@@ -2980,6 +2980,37 @@ class CheckpointableCompatibilityTests(test.TestCase):
       self.assertEqual(42., self.evaluate(v.non_dep_variable))
       self.assertEqual(42., self.evaluate(v.mirrored))
 
+  def testSingleTensorEvaluation(self):
+
+    class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject):
+
+      def __init__(self, name):
+        self.eval_count = 0
+        def _tensor():
+          self.eval_count += 1
+          return constant_op.constant([1.])
+        dummy_op = constant_op.constant([2.])
+        super(_CountingSaveable, self).__init__(
+            dummy_op,
+            [saver_module.BaseSaverBuilder.SaveSpec(
+                _tensor, "", name, dtype=dummy_op.dtype)],
+            name)
+
+      def restore(self, restored_tensors, restored_shapes):
+        """Restore the same value into both variables."""
+        pass
+
+    with context.eager_mode():
+      v = _CountingSaveable("foo")
+      saver = saver_module.Saver(var_list=[v])
+      test_dir = self.get_temp_dir()
+      prefix = os.path.join(test_dir, "ckpt")
+      with self.test_session() as sess:
+        save_path = saver.save(sess, prefix)
+        self.assertEqual(1, v.eval_count)
+        saver.restore(sess, save_path)
+        self.assertEqual(1, v.eval_count)
+
 
 if __name__ == "__main__":
   test.main()