[tf.data] Add optional `args` argument to `Dataset.from_generator()`.
authorDerek Murray <mrry@google.com>
Tue, 15 May 2018 01:04:31 +0000 (18:04 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 15 May 2018 01:07:16 +0000 (18:07 -0700)
The new argument allows you to parameterize the generator with the value of a tf.Tensor,
enabling `Dataset.from_generator()` to be initialized from a placeholder or used in a
nested expression (such as `flat_map()` or `parallel_interleave()`). For example:

```python
def generator(n):
  for _ in range(n):
    yield n

# Define a generator based on a placeholder.
placeholder = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.from_generator(generator, tf.int64, args=(placeholder,))

# Define a generator based on the value of a nested dataset element.
dataset = tf.data.Dataset.range(10).flat_map(
    lambda i: tf.data.Dataset.from_generator(generator, tf.int64, args=(i,)))
```

Fixes #19269. Partially addresses issue #13101.

PiperOrigin-RevId: 196598650

tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
tensorflow/python/data/ops/dataset_ops.py
tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt

index 6aabad2..9fcdf1b 100644 (file)
@@ -357,6 +357,65 @@ class DatasetConstructorTest(test.TestCase):
       # iterator terminates (and the generator iterator is deleted).
       self.assertTrue(event.is_set())
 
+  def testFromGeneratorWithArgs(self):
+
+    def flat_map_fn(elem):
+
+      def generator_with_arg(n):
+        for _ in range(n):
+          yield np.array(n, dtype=np.int64)
+
+      return dataset_ops.Dataset.from_generator(
+          generator_with_arg, output_types=dtypes.int64, output_shapes=(),
+          args=(elem,))
+
+    iterator = (dataset_ops.Dataset
+                .range(5)
+                .flat_map(flat_map_fn)
+                .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
+      for x in expected:
+        self.assertEqual(x, sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  def testFromGeneratorWithTwoArgs(self):
+
+    def flat_map_fn(elem, message):
+
+      def generator_with_arg(n, msg):
+        for i in range(n):
+          yield i, msg
+
+      return dataset_ops.Dataset.from_generator(
+          generator_with_arg, output_types=(dtypes.int64, dtypes.string),
+          output_shapes=((), ()), args=(elem, message))
+
+    iterator = (
+        dataset_ops.Dataset.zip(
+            (dataset_ops.Dataset.range(5),
+             dataset_ops.Dataset.from_tensors("Hi!").repeat(None)))
+        .flat_map(flat_map_fn)
+        .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      expected = [(0, b"Hi!"),
+                  (0, b"Hi!"), (1, b"Hi!"),
+                  (0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
+                  (0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
+      for x in expected:
+        self.assertEqual(x, sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
   def testGeneratorDatasetFinalizeFunctionCalled(self):
     # NOTE(mrry): This test tests the internal `_GeneratorDataset`,
     # which affords more control over what the finalize function can do than
index bd9686f..8b3c2fa 100644 (file)
@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
-import collections
 import threading
 
 import numpy as np
@@ -259,25 +258,32 @@ class Dataset(object):
       self._generator = generator
       self._lock = threading.Lock()
       self._next_id = 0  # GUARDED_BY(self._lock)
-      self._iterators = collections.defaultdict(lambda: iter(generator()))
+      self._args = {}
+      self._iterators = {}
 
-    def get_next_id(self):
+    def get_next_id(self, *args):
       with self._lock:
         ret = self._next_id
         self._next_id += 1
+      self._args[ret] = args
       # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
       # casting in `py_func()` will create an array of `np.int32` on Windows,
       # leading to a runtime error.
       return np.array(ret, dtype=np.int64)
 
     def get_iterator(self, iterator_id):
-      return self._iterators[iterator_id]
+      try:
+        return self._iterators[iterator_id]
+      except KeyError:
+        iterator = iter(self._generator(*self._args.pop(iterator_id)))
+        self._iterators[iterator_id] = iterator
+        return iterator
 
     def iterator_completed(self, iterator_id):
       del self._iterators[iterator_id]
 
   @staticmethod
-  def from_generator(generator, output_types, output_shapes=None):
+  def from_generator(generator, output_types, output_shapes=None, args=None):
     """Creates a `Dataset` whose elements are generated by `generator`.
 
     The `generator` argument must be a callable object that returns
@@ -320,13 +326,17 @@ class Dataset(object):
     `Dataset.from_generator()`.
 
     Args:
-      generator: A callable object that takes no arguments and returns an
-        object that supports the `iter()` protocol.
+      generator: A callable object that returns an object that supports the
+        `iter()` protocol. If `args` is not specified, `generator` must take
+        no arguments; otherwise it must take as many arguments as there are
+        values in `args`.
       output_types: A nested structure of `tf.DType` objects corresponding to
         each component of an element yielded by `generator`.
       output_shapes: (Optional.) A nested structure of `tf.TensorShape`
         objects corresponding to each component of an element yielded by
         `generator`.
+      args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
+        and passed to `generator` as NumPy-array arguments.
 
     Returns:
       Dataset: A `Dataset`.
@@ -339,6 +349,10 @@ class Dataset(object):
     else:
       output_shapes = nest.map_structure_up_to(
           output_types, tensor_shape.as_shape, output_shapes)
+    if args is None:
+      args = ()
+    else:
+      args = tuple(ops.convert_n_to_tensor(args, name="args"))
 
     flattened_types = nest.flatten(output_types)
     flattened_shapes = nest.flatten(output_shapes)
@@ -359,7 +373,7 @@ class Dataset(object):
         `generator_state`.
       """
       return script_ops.py_func(
-          generator_state.get_next_id, [], dtypes.int64, stateful=True)
+          generator_state.get_next_id, args, dtypes.int64, stateful=True)
 
     def generator_next_fn(iterator_id_t):
       """Generates the next element from iterator with ID `iterator_id_t`.
index cbbd077..8e7e945 100644 (file)
@@ -44,7 +44,7 @@ tf_class {
   }
   member_method {
     name: "from_generator"
-    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "from_sparse_tensor_slices"
index 9a56ae8..5cfb2fd 100644 (file)
@@ -45,7 +45,7 @@ tf_class {
   }
   member_method {
     name: "from_generator"
-    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "from_sparse_tensor_slices"
index e5ec824..3327e5b 100644 (file)
@@ -45,7 +45,7 @@ tf_class {
   }
   member_method {
     name: "from_generator"
-    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "from_sparse_tensor_slices"
index 0082397..9d59375 100644 (file)
@@ -45,7 +45,7 @@ tf_class {
   }
   member_method {
     name: "from_generator"
-    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "from_sparse_tensor_slices"