[tf.data] Enable Dataset.make_one_shot_iterator() and Dataset.__iter__() in eager...
authorDerek Murray <mrry@google.com>
Mon, 12 Mar 2018 22:40:47 +0000 (15:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 22:44:44 +0000 (15:44 -0700)
This change partially replicates the code in `tf.contrib.eager.Iterator`.
However, since that class depends on contrib-level functionality (viz.
cross-device prefetching support), we cannot move it wholesale to core.

PiperOrigin-RevId: 188790349

tensorflow/contrib/eager/python/datasets.py
tensorflow/contrib/eager/python/datasets_test.py
tensorflow/python/data/ops/BUILD
tensorflow/python/data/ops/dataset_ops.py
tensorflow/python/data/ops/iterator_ops.py

index 30a7642..332bada 100644 (file)
@@ -27,7 +27,6 @@ from tensorflow.python.data.util import sparse
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
@@ -45,8 +44,13 @@ def _generate_shared_name(prefix):
   return "{}{}".format(prefix, uid)
 
 
-class Iterator(object):
-  """An iterator producing tf.Tensor objects from a tf.data.Dataset."""
+class Iterator(iterator_ops.EagerIterator):
+  """An iterator producing tf.Tensor objects from a tf.data.Dataset.
+
+  NOTE: Unlike the iterator created by the
+  @{tf.data.Dataset.make_one_shot_iterator} method, this class enables
+  additional experimental functionality, such as prefetching to the GPU.
+  """
 
   def __init__(self, dataset):
     """Creates a new iterator over the given dataset.
@@ -67,37 +71,12 @@ class Iterator(object):
     Raises:
       RuntimeError: When invoked without eager execution enabled.
     """
-
-    if not context.executing_eagerly():
-      raise RuntimeError(
-          "{} objects can only be used when eager execution is enabled, use "
-          "tf.data.Dataset.make_initializable_iterator or "
-          "tf.data.Dataset.make_one_shot_iterator for graph construction".
-          format(type(self)))
-    with ops.device("/device:CPU:0"):
-      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
-      self._output_classes = dataset.output_classes
-      self._output_types = dataset.output_types
-      self._output_shapes = dataset.output_shapes
-      self._flat_output_types = nest.flatten(
-          sparse.as_dense_types(self._output_types, self._output_classes))
-      self._flat_output_shapes = nest.flatten(
-          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
-      self._resource = gen_dataset_ops.iterator(
-          shared_name="",
-          container=_generate_shared_name("eageriterator"),
-          output_types=self._flat_output_types,
-          output_shapes=self._flat_output_shapes)
-      gen_dataset_ops.make_iterator(ds_variant, self._resource)
-      # Delete the resource when this object is deleted
-      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
-          handle=self._resource, handle_device="/device:CPU:0")
-    self._device = context.context().device_name
-    self._buffer_resource_handle = None
+    super(Iterator, self).__init__(dataset)
     if not context.context().device_spec.device_type:
       is_remote_device = False
     else:
       is_remote_device = context.context().device_spec.device_type != "CPU"
+    self._buffer_resource_handle = None
     if is_remote_device:
       with ops.device("/device:CPU:0"):
         iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
@@ -106,7 +85,7 @@ class Iterator(object):
         @function.Defun(dtypes.string)
         def remote_fn(h):
           remote_iterator = iterator_ops.Iterator.from_string_handle(
-              h, self._output_types, self._output_shapes)
+              h, self.output_types, self.output_shapes, self.output_classes)
           return remote_iterator.get_next()
 
         remote_fn.add_to_graph(None)
@@ -124,89 +103,16 @@ class Iterator(object):
             handle=self._buffer_resource_handle,
             handle_device=self._device)
 
-  def __iter__(self):
-    return self
-
-  def __next__(self):  # For Python 3 compatibility
-    return self.next()
-
   def _next_internal(self):
     """Returns a nested structure of `tf.Tensor`s containing the next element.
     """
-    with ops.device(self._device):
-      if self._buffer_resource_handle is not None:
+    if self._buffer_resource_handle is not None:
+      with ops.device(self._device):
         ret = prefetching_ops.function_buffering_resource_get_next(
             function_buffer_resource=self._buffer_resource_handle,
             output_types=self._flat_output_types)
-      else:
-        # TODO(ashankar): Consider removing this ops.device() contextmanager
-        # and instead mimic ops placement in graphs: Operations on resource
-        # handles execute on the same device as where the resource is placed.
-        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
-        # because in eager mode this code will run synchronously on the calling
-        # thread. Therefore we do not need to make a defensive context switch
-        # to a background thread, and can achieve a small constant performance
-        # boost by invoking the iterator synchronously.
-        ret = gen_dataset_ops.iterator_get_next_sync(
-            self._resource,
-            output_types=self._flat_output_types,
-            output_shapes=self._flat_output_shapes)
-
-    return sparse.deserialize_sparse_tensors(
-        nest.pack_sequence_as(self._output_types, ret), self._output_types,
-        self._output_shapes, self._output_classes)
-
-  def next(self):
-    """Returns a nested structure of `tf.Tensor`s containing the next element.
-    """
-    try:
-      return self._next_internal()
-    except errors.OutOfRangeError:
-      raise StopIteration
-
-  @property
-  def output_classes(self):
-    """Returns the class of each component of an element of this iterator.
-
-    The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
-    Returns:
-      A nested structure of Python `type` objects corresponding to each
-      component of an element of this dataset.
-    """
-    return self._output_classes
-
-  @property
-  def output_shapes(self):
-    """Returns the shape of each component of an element of this iterator.
-
-    Returns:
-      A nested structure of `tf.TensorShape` objects corresponding to each
-      component of an element of this dataset.
-    """
-    return self._output_shapes
-
-  @property
-  def output_types(self):
-    """Returns the type of each component of an element of this iterator.
-
-    Returns:
-      A nested structure of `tf.DType` objects corresponding to each component
-      of an element of this dataset.
-    """
-    return self._output_types
-
-  def get_next(self, name=None):
-    """Returns a nested structure of `tf.Tensor`s containing the next element.
-
-    Args:
-      name: (Optional.) A name for the created operation. Currently unused.
-
-    Returns:
-      A nested structure of `tf.Tensor` objects.
-
-    Raises:
-      `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
-    """
-    del name
-    return self._next_internal()
+      return sparse.deserialize_sparse_tensors(
+          nest.pack_sequence_as(self._output_types, ret), self._output_types,
+          self._output_shapes, self._output_classes)
+    else:
+      return super(Iterator, self)._next_internal()
index 35c3c5d..4afadd8 100644 (file)
@@ -44,6 +44,18 @@ class IteratorTest(test.TestCase):
       got.append(t.numpy())
     self.assertAllEqual([0, 1, 2, 3], got)
 
+  def testBasicOneShotIterator(self):
+    got = []
+    for t in Dataset.range(4).make_one_shot_iterator():
+      got.append(t.numpy())
+    self.assertAllEqual([0, 1, 2, 3], got)
+
+  def testBasicImplicitIterator(self):
+    got = []
+    for t in Dataset.range(4):
+      got.append(t.numpy())
+    self.assertAllEqual([0, 1, 2, 3], got)
+
   def testGetNext(self):
     iterator = datasets.Iterator(Dataset.range(4))
     self.assertEqual(0, iterator.get_next().numpy())
@@ -53,6 +65,15 @@ class IteratorTest(test.TestCase):
     with self.assertRaises(errors.OutOfRangeError):
       iterator.get_next()
 
+  def testGetNextOneShotIterator(self):
+    iterator = Dataset.range(4).make_one_shot_iterator()
+    self.assertEqual(0, iterator.get_next().numpy())
+    self.assertEqual(1, iterator.get_next().numpy())
+    self.assertEqual(2, iterator.get_next().numpy())
+    self.assertEqual(3, iterator.get_next().numpy())
+    with self.assertRaises(errors.OutOfRangeError):
+      iterator.get_next()
+
   def testMultipleIteratorsOnTheSameDataset(self):
     ds = Dataset.range(4)
     it1 = datasets.Iterator(ds)
index a8f2154..3119ab0 100644 (file)
@@ -52,9 +52,11 @@ py_library(
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python/data/util:nest",
         "//tensorflow/python/data/util:sparse",
+        "//tensorflow/python/eager:context",
     ],
 )
 
index e0d63b5..390ce85 100644 (file)
@@ -111,11 +111,11 @@ class Dataset(object):
                                  self.output_types, self.output_shapes,
                                  self.output_classes)
 
-  def make_one_shot_iterator(self):
+  def __iter__(self):
     """Creates an `Iterator` for enumerating the elements of this dataset.
 
-    Note: The returned iterator will be initialized automatically.
-    A "one-shot" iterator does not currently support re-initialization.
+    The returned iterator implements the Python iterator protocol and therefore
+    can only be used in eager mode.
 
     Returns:
       An `Iterator` over the elements of this dataset.
@@ -124,9 +124,22 @@ class Dataset(object):
       RuntimeError: If eager execution is enabled.
     """
     if context.executing_eagerly():
-      raise RuntimeError(
-          "dataset.make_one_shot_iterator is not supported when eager "
-          "execution is enabled.")
+      return iterator_ops.EagerIterator(self)
+    else:
+      raise RuntimeError("dataset.__iter__() is only supported when eager "
+                         "execution is enabled.")
+
+  def make_one_shot_iterator(self):
+    """Creates an `Iterator` for enumerating the elements of this dataset.
+
+    Note: The returned iterator will be initialized automatically.
+    A "one-shot" iterator does not currently support re-initialization.
+
+    Returns:
+      An `Iterator` over the elements of this dataset.
+    """
+    if context.executing_eagerly():
+      return iterator_ops.EagerIterator(self)
     # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
     # a 0-argument function.
     @function.Defun(capture_by_value=True)
index 4756ec7..d79b9d6 100644 (file)
@@ -17,14 +17,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import threading
 import warnings
 
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -412,3 +416,147 @@ class Iterator(object):
       of an element of this dataset.
     """
     return self._output_types
+
+
+_uid_counter = 0
+_uid_lock = threading.Lock()
+
+
+def _generate_shared_name(prefix):
+  with _uid_lock:
+    global _uid_counter
+    uid = _uid_counter
+    _uid_counter += 1
+  return "{}{}".format(prefix, uid)
+
+
+class EagerIterator(object):
+  """An iterator producing tf.Tensor objects from a tf.data.Dataset."""
+
+  def __init__(self, dataset):
+    """Creates a new iterator over the given dataset.
+
+    For example:
+    ```python
+    dataset = tf.data.Dataset.range(4)
+    for x in Iterator(dataset):
+      print(x)
+    ```
+
+    Tensors produced will be placed on the device on which this iterator object
+    was created.
+
+    Args:
+      dataset: A `tf.data.Dataset` object.
+
+    Raises:
+      RuntimeError: When invoked without eager execution enabled.
+    """
+
+    if not context.executing_eagerly():
+      raise RuntimeError(
+          "{} objects can only be used when eager execution is enabled, use "
+          "tf.data.Dataset.make_initializable_iterator or "
+          "tf.data.Dataset.make_one_shot_iterator for graph construction".
+          format(type(self)))
+    with ops.device("/device:CPU:0"):
+      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
+      self._output_classes = dataset.output_classes
+      self._output_types = dataset.output_types
+      self._output_shapes = dataset.output_shapes
+      self._flat_output_types = nest.flatten(
+          sparse.as_dense_types(self._output_types, self._output_classes))
+      self._flat_output_shapes = nest.flatten(
+          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+      self._resource = gen_dataset_ops.iterator(
+          shared_name="",
+          container=_generate_shared_name("eageriterator"),
+          output_types=self._flat_output_types,
+          output_shapes=self._flat_output_shapes)
+      gen_dataset_ops.make_iterator(ds_variant, self._resource)
+      # Delete the resource when this object is deleted
+      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+          handle=self._resource, handle_device="/device:CPU:0")
+    self._device = context.context().device_name
+
+  def __iter__(self):
+    return self
+
+  def __next__(self):  # For Python 3 compatibility
+    return self.next()
+
+  def _next_internal(self):
+    """Returns a nested structure of `tf.Tensor`s containing the next element.
+    """
+    with ops.device(self._device):
+      # TODO(ashankar): Consider removing this ops.device() contextmanager
+      # and instead mimic ops placement in graphs: Operations on resource
+      # handles execute on the same device as where the resource is placed.
+      # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
+      # because in eager mode this code will run synchronously on the calling
+      # thread. Therefore we do not need to make a defensive context switch
+      # to a background thread, and can achieve a small constant performance
+      # boost by invoking the iterator synchronously.
+      ret = gen_dataset_ops.iterator_get_next_sync(
+          self._resource,
+          output_types=self._flat_output_types,
+          output_shapes=self._flat_output_shapes)
+
+    return sparse.deserialize_sparse_tensors(
+        nest.pack_sequence_as(self._output_types, ret), self._output_types,
+        self._output_shapes, self._output_classes)
+
+  def next(self):
+    """Returns a nested structure of `tf.Tensor`s containing the next element.
+    """
+    try:
+      return self._next_internal()
+    except errors.OutOfRangeError:
+      raise StopIteration
+
+  @property
+  def output_classes(self):
+    """Returns the class of each component of an element of this iterator.
+
+    The expected values are `tf.Tensor` and `tf.SparseTensor`.
+
+    Returns:
+      A nested structure of Python `type` objects corresponding to each
+      component of an element of this dataset.
+    """
+    return self._output_classes
+
+  @property
+  def output_shapes(self):
+    """Returns the shape of each component of an element of this iterator.
+
+    Returns:
+      A nested structure of `tf.TensorShape` objects corresponding to each
+      component of an element of this dataset.
+    """
+    return self._output_shapes
+
+  @property
+  def output_types(self):
+    """Returns the type of each component of an element of this iterator.
+
+    Returns:
+      A nested structure of `tf.DType` objects corresponding to each component
+      of an element of this dataset.
+    """
+    return self._output_types
+
+  def get_next(self, name=None):
+    """Returns a nested structure of `tf.Tensor`s containing the next element.
+
+    Args:
+      name: (Optional.) A name for the created operation. Currently unused.
+
+    Returns:
+      A nested structure of `tf.Tensor` objects.
+
+    Raises:
+      `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
+    """
+    del name
+    return self._next_internal()