[tf.data] Add usage note to `tf.data.Iterator.get_next()` docstring.
authorDerek Murray <mrry@google.com>
Wed, 14 Feb 2018 21:48:36 +0000 (13:48 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 21:52:38 +0000 (13:52 -0800)
Fixes #16954.

PiperOrigin-RevId: 185738793

tensorflow/python/data/ops/iterator_ops.py

index e573fe0..4756ec7 100644 (file)
@@ -44,8 +44,9 @@ GET_NEXT_CALL_WARNING_MESSAGE = (
     "This often indicates that `Iterator.get_next()` is being called inside "
     "a training loop, which will cause gradual slowdown and eventual resource "
     "exhaustion. If this is the case, restructure your code to call "
-    "`next_element = iterator.get_next() once outside the loop, and use "
-    "`next_element` inside the loop.")
+    "`next_element = iterator.get_next()` once outside the loop, and use "
+    "`next_element` as the input to some computation that is invoked inside "
+    "the loop.")
 
 
 @tf_export("data.Iterator")
@@ -303,7 +304,42 @@ class Iterator(object):
           dataset._as_variant_tensor(), self._iterator_resource, name=name)  # pylint: disable=protected-access
 
   def get_next(self, name=None):
-    """Returns a nested structure of `tf.Tensor`s containing the next element.
+    """Returns a nested structure of `tf.Tensor`s representing the next element.
+
+    In graph mode, you should typically call this method *once* and use its
+    result as the input to another computation. A typical loop will then call
+    @{tf.Session.run} on the result of that computation. The loop will terminate
+    when the `Iterator.get_next()` operation raises
+    @{tf.errors.OutOfRangeError}. The following skeleton shows how to use
+    this method when building a training loop:
+
+    ```python
+    dataset = ...  # A `tf.data.Dataset` object.
+    iterator = dataset.make_initializable_iterator()
+    next_element = iterator.get_next()
+
+    # Build a TensorFlow graph that does something with each element.
+    loss = model_function(next_element)
+    optimizer = ...  # A `tf.train.Optimizer` object.
+    train_op = optimizer.minimize(loss)
+
+    with tf.Session() as sess:
+      try:
+        while True:
+          sess.run(train_op)
+      except tf.errors.OutOfRangeError:
+        pass
+    ```
+
+    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
+    when you are distributing different elements to multiple devices in a single
+    step. However, a common pitfall arises when users call `Iterator.get_next()`
+    in each iteration of their training loop. `Iterator.get_next()` adds ops to
+    the graph, and executing each op allocates resources (including threads); as
+    a consequence, invoking it in every iteration of a training loop causes
+    slowdown and eventual resource exhaustion. To guard against this outcome, we
+    log a warning when the number of uses crosses a fixed threshold of
+    suspiciousness.
 
     Args:
       name: (Optional.) A name for the created operation.