from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training.saver import BaseSaverBuilder
_uid_counter = 0
_uid_lock = threading.Lock()
return "{}{}".format(prefix, uid)
-class Iterator(iterator_ops.EagerIterator):
+class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
NOTE: Unlike the iterator created by the
self._output_shapes, self._output_classes)
else:
return super(Iterator, self)._next_internal()
+
+ # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset
+ # attributes(potential).
+
+ class _Saveable(BaseSaverBuilder.SaveableObject):
+ """SaveableObject for saving/restoring iterator state."""
+
+ def __init__(self, iterator_resource, name):
+ serialized_iterator = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ specs = [
+ BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
+ ]
+ # pylint: disable=protected-access
+ super(Iterator._Saveable, self).__init__(iterator_resource, specs, name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ with ops.colocate_with(self.op):
+ return gen_dataset_ops.deserialize_iterator(self.op,
+ restored_tensors[0])
+
+ def _gather_saveables_for_checkpoint(self):
+
+ def _saveable_factory(name):
+ return self._Saveable(self._resource, name)
+
+ return {"ITERATOR": _saveable_factory}
from __future__ import division
from __future__ import print_function
+import os
+
import threading
import time
from tensorflow.contrib import lookup
from tensorflow.contrib.data.python.ops import threadpool
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.contrib.eager.python import datasets
from tensorflow.python.data import Dataset
from tensorflow.python.eager import test
# perform work.
self.assertLessEqual(len(thread_ids), num_threads)
+ def testSaveRestore(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+ dataset = dataset.map(math_ops.square).batch(2)
+ iterator = datasets.Iterator(dataset)
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ self.assertAllEqual([1, 4], iterator.get_next().numpy())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([9, 16], iterator.get_next().numpy())
+ self.assertAllEqual([25, 36], iterator.get_next().numpy())
+ checkpoint.restore(save_path)
+ self.assertAllEqual([9, 16], iterator.get_next().numpy())
+ self.assertAllEqual([25, 36], iterator.get_next().numpy())
+
+ def testSaveRestoreMultipleIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+ dataset = dataset.map(math_ops.square).batch(2)
+ iterator_1 = datasets.Iterator(dataset)
+ iterator_2 = datasets.Iterator(dataset)
+ dataset_2 = Dataset.range(10)
+ iterator_3 = datasets.Iterator(dataset_2)
+
+ checkpoint = checkpointable_utils.Checkpoint(
+ iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
+ self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
+ self.assertEqual(0, iterator_3.get_next().numpy())
+ self.assertEqual(1, iterator_3.get_next().numpy())
+ self.assertEqual(2, iterator_3.get_next().numpy())
+
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
+ self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
+ self.assertEqual(3, iterator_3.get_next().numpy())
+ checkpoint.restore(save_path)
+ self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
+ self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
+ self.assertEqual(3, iterator_3.get_next().numpy())
+
+ def testRestoreExhaustedIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+ dataset = Dataset.range(3)
+ iterator = datasets.Iterator(dataset)
+
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ self.assertEqual(0, iterator.get_next().numpy())
+ self.assertEqual(1, iterator.get_next().numpy())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertEqual(2, iterator.get_next().numpy())
+ checkpoint.restore(save_path)
+ self.assertEqual(2, iterator.get_next().numpy())
+
class DatasetConstructorBenchmark(test.Benchmark):