from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
return "{}{}".format(prefix, uid)
-class Iterator(object):
- """An iterator producing tf.Tensor objects from a tf.data.Dataset."""
+class Iterator(iterator_ops.EagerIterator):
+ """An iterator producing tf.Tensor objects from a tf.data.Dataset.
+
+ NOTE: Unlike the iterator created by the
+ @{tf.data.Dataset.make_one_shot_iterator} method, this class enables
+ additional experimental functionality, such as prefetching to the GPU.
+ """
def __init__(self, dataset):
"""Creates a new iterator over the given dataset.
Raises:
RuntimeError: When invoked without eager execution enabled.
"""
-
- if not context.executing_eagerly():
- raise RuntimeError(
- "{} objects can only be used when eager execution is enabled, use "
- "tf.data.Dataset.make_initializable_iterator or "
- "tf.data.Dataset.make_one_shot_iterator for graph construction".
- format(type(self)))
- with ops.device("/device:CPU:0"):
- ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
- self._output_classes = dataset.output_classes
- self._output_types = dataset.output_types
- self._output_shapes = dataset.output_shapes
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._output_types, self._output_classes))
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._output_shapes, self._output_classes))
- self._resource = gen_dataset_ops.iterator(
- shared_name="",
- container=_generate_shared_name("eageriterator"),
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- gen_dataset_ops.make_iterator(ds_variant, self._resource)
- # Delete the resource when this object is deleted
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device="/device:CPU:0")
- self._device = context.context().device_name
- self._buffer_resource_handle = None
+ super(Iterator, self).__init__(dataset)
if not context.context().device_spec.device_type:
is_remote_device = False
else:
is_remote_device = context.context().device_spec.device_type != "CPU"
+ self._buffer_resource_handle = None
if is_remote_device:
with ops.device("/device:CPU:0"):
iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
@function.Defun(dtypes.string)
def remote_fn(h):
remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, self._output_types, self._output_shapes)
+ h, self.output_types, self.output_shapes, self.output_classes)
return remote_iterator.get_next()
remote_fn.add_to_graph(None)
handle=self._buffer_resource_handle,
handle_device=self._device)
- def __iter__(self):
- return self
-
- def __next__(self): # For Python 3 compatibility
- return self.next()
-
  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Two paths:
    * If a prefetch buffer was set up (non-CPU device — see `__init__`), pull
      the next element out of the function-buffering resource on this
      iterator's device.
    * Otherwise fall back to the parent class's CPU-based implementation.
    """
    if self._buffer_resource_handle is not None:
      with ops.device(self._device):
        # Elements were prefetched to this device via the buffering resource,
        # so fetching here avoids a host-to-device copy on the hot path.
        ret = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)
      # The buffered op returns flat dense tensors; restore the dataset's
      # nested structure and re-materialize any serialized sparse tensors.
      return sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(self._output_types, ret), self._output_types,
          self._output_shapes, self._output_classes)
    else:
      return super(Iterator, self)._next_internal()
got.append(t.numpy())
self.assertAllEqual([0, 1, 2, 3], got)
+ def testBasicOneShotIterator(self):
+ got = []
+ for t in Dataset.range(4).make_one_shot_iterator():
+ got.append(t.numpy())
+ self.assertAllEqual([0, 1, 2, 3], got)
+
+ def testBasicImplicitIterator(self):
+ got = []
+ for t in Dataset.range(4):
+ got.append(t.numpy())
+ self.assertAllEqual([0, 1, 2, 3], got)
+
def testGetNext(self):
iterator = datasets.Iterator(Dataset.range(4))
self.assertEqual(0, iterator.get_next().numpy())
with self.assertRaises(errors.OutOfRangeError):
iterator.get_next()
+ def testGetNextOneShotIterator(self):
+ iterator = Dataset.range(4).make_one_shot_iterator()
+ self.assertEqual(0, iterator.get_next().numpy())
+ self.assertEqual(1, iterator.get_next().numpy())
+ self.assertEqual(2, iterator.get_next().numpy())
+ self.assertEqual(3, iterator.get_next().numpy())
+ with self.assertRaises(errors.OutOfRangeError):
+ iterator.get_next()
+
def testMultipleIteratorsOnTheSameDataset(self):
ds = Dataset.range(4)
it1 = datasets.Iterator(ds)
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
],
)
self.output_types, self.output_shapes,
self.output_classes)
- def make_one_shot_iterator(self):
+ def __iter__(self):
"""Creates an `Iterator` for enumerating the elements of this dataset.
- Note: The returned iterator will be initialized automatically.
- A "one-shot" iterator does not currently support re-initialization.
+ The returned iterator implements the Python iterator protocol and therefore
+ can only be used in eager mode.
Returns:
An `Iterator` over the elements of this dataset.
RuntimeError: If eager execution is enabled.
"""
if context.executing_eagerly():
- raise RuntimeError(
- "dataset.make_one_shot_iterator is not supported when eager "
- "execution is enabled.")
+ return iterator_ops.EagerIterator(self)
+ else:
+ raise RuntimeError("dataset.__iter__() is only supported when eager "
+ "execution is enabled.")
+
+ def make_one_shot_iterator(self):
+ """Creates an `Iterator` for enumerating the elements of this dataset.
+
+ Note: The returned iterator will be initialized automatically.
+ A "one-shot" iterator does not currently support re-initialization.
+
+ Returns:
+ An `Iterator` over the elements of this dataset.
+ """
+ if context.executing_eagerly():
+ return iterator_ops.EagerIterator(self)
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
# a 0-argument function.
@function.Defun(capture_by_value=True)
from __future__ import division
from __future__ import print_function
+import threading
import warnings
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util.tf_export import tf_export
of an element of this dataset.
"""
return self._output_types
+
+
# Process-wide counter used to mint unique resource container names.
_uid_counter = 0
_uid_lock = threading.Lock()


def _generate_shared_name(prefix):
  """Returns a unique name of the form `<prefix><uid>`.

  Thread-safe: a module-level counter guarded by a lock guarantees that
  concurrent callers never receive the same name.
  """
  global _uid_counter
  with _uid_lock:
    uid = _uid_counter
    _uid_counter += 1
  return "{}{}".format(prefix, uid)
+
+
class EagerIterator(object):
  """An iterator producing tf.Tensor objects from a tf.data.Dataset."""

  def __init__(self, dataset):
    """Creates a new iterator over the given dataset.

    For example:
    ```python
    dataset = tf.data.Dataset.range(4)
    for x in Iterator(dataset):
      print(x)
    ```

    Tensors produced will be placed on the device on which this iterator object
    was created.

    Args:
      dataset: A `tf.data.Dataset` object.

    Raises:
      RuntimeError: When invoked without eager execution enabled.
    """

    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.data.Dataset.make_initializable_iterator or "
          "tf.data.Dataset.make_one_shot_iterator for graph construction".
          format(type(self)))
    # The underlying iterator resource is always created on the CPU, even if
    # the tensors it produces are later consumed on another device.
    with ops.device("/device:CPU:0"):
      ds_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes
      # Sparse components cross the op boundary in dense serialized form, so
      # flatten the structure to the dense types/shapes the kernels expect.
      self._flat_output_types = nest.flatten(
          sparse.as_dense_types(self._output_types, self._output_classes))
      self._flat_output_shapes = nest.flatten(
          sparse.as_dense_shapes(self._output_shapes, self._output_classes))
      # An empty shared_name plus a process-unique container name keeps this
      # iterator's state private to this object.
      self._resource = gen_dataset_ops.iterator(
          shared_name="",
          container=_generate_shared_name("eageriterator"),
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)
      gen_dataset_ops.make_iterator(ds_variant, self._resource)
      # Delete the resource when this object is deleted
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._resource, handle_device="/device:CPU:0")
    # Device active at construction time; get_next() results are placed here.
    self._device = context.context().device_name

  def __iter__(self):
    return self

  def __next__(self):  # For Python 3 compatibility
    return self.next()

  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.
    """
    with ops.device(self._device):
      # TODO(ashankar): Consider removing this ops.device() contextmanager
      # and instead mimic ops placement in graphs: Operations on resource
      # handles execute on the same device as where the resource is placed.
      # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
      # because in eager mode this code will run synchronously on the calling
      # thread. Therefore we do not need to make a defensive context switch
      # to a background thread, and can achieve a small constant performance
      # boost by invoking the iterator synchronously.
      ret = gen_dataset_ops.iterator_get_next_sync(
          self._resource,
          output_types=self._flat_output_types,
          output_shapes=self._flat_output_shapes)

    # Restore the dataset's nested structure and re-materialize any sparse
    # tensors from their dense serialized form.
    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types, ret), self._output_types,
        self._output_shapes, self._output_classes)

  def next(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.
    """
    try:
      return self._next_internal()
    except errors.OutOfRangeError:
      # Translate TF's end-of-sequence error into the Python iterator
      # protocol's StopIteration so for-loops terminate normally.
      raise StopIteration

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return self._output_classes

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return self._output_shapes

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return self._output_types

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation. Currently unused.

    Returns:
      A nested structure of `tf.Tensor` objects.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the dataset has been reached.
    """
    del name  # Accepted only for API compatibility with the graph Iterator.
    return self._next_internal()