class SaveSpec(object):
"""Class used to describe tensor slices that need to be saved."""
- def __init__(self, tensor, slice_spec, name):
+ def __init__(self, tensor, slice_spec, name, dtype=None):
"""Creates a `SaveSpec` object.
Args:
tensor: the tensor to save or callable that produces a tensor to save.
slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
name: the name to save the tensor under.
+ dtype: The data type of the Tensor. Required if `tensor` is callable.
+ Used for error checking in the restore op.
"""
self._tensor = tensor
self.slice_spec = slice_spec
self.name = name
+ if callable(self._tensor):
+ if dtype is None:
+ raise AssertionError(
+ "When passing a callable `tensor` to a SaveSpec, an explicit "
+ "dtype must be provided.")
+ self.dtype = dtype
+ else:
+ self.dtype = tensor.dtype
@property
def tensor(self):
op: the "producer" object that this class wraps; it produces a list of
tensors to save. E.g., a "Variable" object saving its backing tensor.
specs: a list of SaveSpec, each element of which describes one tensor to
- save under this object.
+ save under this object. All Tensors must be on the same device.
name: the name to save the object under.
"""
self.op = op
self.specs = specs
self.name = name
- # The device of this saveable. All tensors must be on the same device.
- self.device = specs[0].tensor.device
+ self._device = None
+
+ @property
+ def device(self):
+ """The device for SaveSpec Tensors."""
+ # Note that SaveSpec.tensor runs Tensor-gathering ops when executing
+ # eagerly, making this call potentially very expensive.
+ #
+ # TODO(allenl): Consider another way to gather device information. Lower
+ # priority since this property isn't part of the normal save()/restore()
+ # workflow, but does come up when some alternative builders are passed to
+ # the Saver.
+ if self._device is None:
+ self._device = self.specs[0].tensor.device
+ return self._device
def restore(self, restored_tensors, restored_shapes):
"""Restores this object from 'restored_tensors'.
"""SaveableObject implementation that handles Variables."""
def __init__(self, var, slice_spec, name):
- spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name)
+ spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name, dtype=var.dtype)
super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name)
def restore(self, restored_tensors, restored_shapes):
raise ValueError(
"Saveable is neither a resource variable nor a read operation."
" Got: %s" % repr(var))
- spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)
+ spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name,
+ dtype=var.dtype)
super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
var, [spec], name)
filename_tensor,
[spec.name],
[spec.slice_spec],
- [spec.tensor.dtype])[0])
+ [spec.dtype])[0])
return tensors
# pylint: enable=unused-argument
restore_specs = []
for saveable in saveables:
for spec in saveable.specs:
- restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))
+ restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
names, slices, dtypes = zip(*restore_specs)
# Load all tensors onto CPU 0 for compatibility with existing code.
self.assertEqual(42., self.evaluate(v.non_dep_variable))
self.assertEqual(42., self.evaluate(v.mirrored))
+ def testSingleTensorEvaluation(self):
+
+ class _CountingSaveable(saver_module.BaseSaverBuilder.SaveableObject):
+
+ def __init__(self, name):
+ self.eval_count = 0
+ def _tensor():
+ self.eval_count += 1
+ return constant_op.constant([1.])
+ dummy_op = constant_op.constant([2.])
+ super(_CountingSaveable, self).__init__(
+ dummy_op,
+ [saver_module.BaseSaverBuilder.SaveSpec(
+ _tensor, "", name, dtype=dummy_op.dtype)],
+ name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ pass
+
+ with context.eager_mode():
+ v = _CountingSaveable("foo")
+ saver = saver_module.Saver(var_list=[v])
+ test_dir = self.get_temp_dir()
+ prefix = os.path.join(test_dir, "ckpt")
+ with self.test_session() as sess:
+ save_path = saver.save(sess, prefix)
+ self.assertEqual(1, v.eval_count)
+ saver.restore(sess, save_path)
+ self.assertEqual(1, v.eval_count)
+
if __name__ == "__main__":
test.main()