[Checkpointable] Make EagerIterator checkpointable.
author Shivani Agrawal <shivaniagrawal@google.com>
Thu, 15 Mar 2018 21:38:25 +0000 (14:38 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 21:45:49 +0000 (14:45 -0700)
Use object-based save/restore to make dataset/iterator checkpointable in eager mode. This could potentially be extended to graph mode as well.

PiperOrigin-RevId: 189247720

tensorflow/contrib/eager/python/BUILD
tensorflow/contrib/eager/python/datasets.py
tensorflow/contrib/eager/python/datasets_test.py

index 384ef7f..eb810e0 100644 (file)
@@ -70,6 +70,7 @@ cuda_py_test(
     srcs = ["datasets_test.py"],
     additional_deps = [
         ":datasets",
+        ":checkpointable_utils",
         "//tensorflow/contrib/data/python/ops:transformation_ops",
         "//tensorflow/contrib/lookup:lookup_py",
         "//tensorflow/python:dtypes",
index 332bada..a4c3283 100644 (file)
@@ -31,6 +31,8 @@ from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import checkpointable
+from tensorflow.python.training.saver import BaseSaverBuilder
 
 _uid_counter = 0
 _uid_lock = threading.Lock()
@@ -44,7 +46,7 @@ def _generate_shared_name(prefix):
   return "{}{}".format(prefix, uid)
 
 
-class Iterator(iterator_ops.EagerIterator):
+class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
   """An iterator producing tf.Tensor objects from a tf.data.Dataset.
 
   NOTE: Unlike the iterator created by the
@@ -116,3 +118,30 @@ class Iterator(iterator_ops.EagerIterator):
           self._output_shapes, self._output_classes)
     else:
       return super(Iterator, self)._next_internal()
+
+  # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset
+  # attributes(potential).
+
+  class _Saveable(BaseSaverBuilder.SaveableObject):
+    """SaveableObject for saving/restoring iterator state."""
+
+    def __init__(self, iterator_resource, name):
+      serialized_iterator = gen_dataset_ops.serialize_iterator(
+          iterator_resource)
+      specs = [
+          BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
+      ]
+      # pylint: disable=protected-access
+      super(Iterator._Saveable, self).__init__(iterator_resource, specs, name)
+
+    def restore(self, restored_tensors, restored_shapes):
+      with ops.colocate_with(self.op):
+        return gen_dataset_ops.deserialize_iterator(self.op,
+                                                    restored_tensors[0])
+
+  def _gather_saveables_for_checkpoint(self):
+
+    def _saveable_factory(name):
+      return self._Saveable(self._resource, name)
+
+    return {"ITERATOR": _saveable_factory}
index 4afadd8..c658505 100644 (file)
@@ -16,6 +16,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 import threading
 import time
 
@@ -24,6 +26,7 @@ import numpy as np
 from tensorflow.contrib import lookup
 from tensorflow.contrib.data.python.ops import threadpool
 from tensorflow.contrib.data.python.ops import unique
+from tensorflow.contrib.eager.python import checkpointable_utils
 from tensorflow.contrib.eager.python import datasets
 from tensorflow.python.data import Dataset
 from tensorflow.python.eager import test
@@ -221,6 +224,61 @@ class IteratorTest(test.TestCase):
       # perform work.
       self.assertLessEqual(len(thread_ids), num_threads)
 
+  def testSaveRestore(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+    dataset = dataset.map(math_ops.square).batch(2)
+    iterator = datasets.Iterator(dataset)
+    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+    self.assertAllEqual([1, 4], iterator.get_next().numpy())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual([9, 16], iterator.get_next().numpy())
+    self.assertAllEqual([25, 36], iterator.get_next().numpy())
+    checkpoint.restore(save_path)
+    self.assertAllEqual([9, 16], iterator.get_next().numpy())
+    self.assertAllEqual([25, 36], iterator.get_next().numpy())
+
+  def testSaveRestoreMultipleIterator(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+    dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+    dataset = dataset.map(math_ops.square).batch(2)
+    iterator_1 = datasets.Iterator(dataset)
+    iterator_2 = datasets.Iterator(dataset)
+    dataset_2 = Dataset.range(10)
+    iterator_3 = datasets.Iterator(dataset_2)
+
+    checkpoint = checkpointable_utils.Checkpoint(
+        iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
+    self.assertAllEqual([1, 4], iterator_1.get_next().numpy())
+    self.assertEqual(0, iterator_3.get_next().numpy())
+    self.assertEqual(1, iterator_3.get_next().numpy())
+    self.assertEqual(2, iterator_3.get_next().numpy())
+
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
+    self.assertAllEqual([9, 16], iterator_2.get_next().numpy())
+    self.assertEqual(3, iterator_3.get_next().numpy())
+    checkpoint.restore(save_path)
+    self.assertAllEqual([9, 16], iterator_1.get_next().numpy())
+    self.assertAllEqual([1, 4], iterator_2.get_next().numpy())
+    self.assertEqual(3, iterator_3.get_next().numpy())
+
+  def testRestoreExhaustedIterator(self):
+    checkpoint_directory = self.get_temp_dir()
+    checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
+    dataset = Dataset.range(3)
+    iterator = datasets.Iterator(dataset)
+
+    checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+    self.assertEqual(0, iterator.get_next().numpy())
+    self.assertEqual(1, iterator.get_next().numpy())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertEqual(2, iterator.get_next().numpy())
+    checkpoint.restore(save_path)
+    self.assertEqual(2, iterator.get_next().numpy())
+
 
 class DatasetConstructorBenchmark(test.Benchmark):