The new `args` argument allows you to parameterize the generator with the values of
`tf.Tensor` objects, enabling `Dataset.from_generator()` to be initialized from a placeholder
or used in a nested expression (such as `flat_map()` or `parallel_interleave()`). For example:
```python
def generator(n):
  for _ in range(n):
    yield n

# Define a generator based on a placeholder.
placeholder = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.from_generator(generator, tf.int64, args=(placeholder,))

# Define a generator based on the value of a nested dataset element.
dataset = tf.data.Dataset.range(10).flat_map(
    lambda i: tf.data.Dataset.from_generator(generator, tf.int64, args=(i,)))
```
Fixes #19269. Partially addresses issue #13101.
PiperOrigin-RevId: 196598650
# iterator terminates (and the generator iterator is deleted).
self.assertTrue(event.is_set())
+ def testFromGeneratorWithArgs(self):
+
+ def flat_map_fn(elem):
+
+ def generator_with_arg(n):
+ for _ in range(n):
+ yield np.array(n, dtype=np.int64)
+
+ return dataset_ops.Dataset.from_generator(
+ generator_with_arg, output_types=dtypes.int64, output_shapes=(),
+ args=(elem,))
+
+ iterator = (dataset_ops.Dataset
+ .range(5)
+ .flat_map(flat_map_fn)
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
+ for x in expected:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testFromGeneratorWithTwoArgs(self):
+
+ def flat_map_fn(elem, message):
+
+ def generator_with_arg(n, msg):
+ for i in range(n):
+ yield i, msg
+
+ return dataset_ops.Dataset.from_generator(
+ generator_with_arg, output_types=(dtypes.int64, dtypes.string),
+ output_shapes=((), ()), args=(elem, message))
+
+ iterator = (
+ dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(5),
+ dataset_ops.Dataset.from_tensors("Hi!").repeat(None)))
+ .flat_map(flat_map_fn)
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ expected = [(0, b"Hi!"),
+ (0, b"Hi!"), (1, b"Hi!"),
+ (0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
+ (0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
+ for x in expected:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testGeneratorDatasetFinalizeFunctionCalled(self):
# NOTE(mrry): This test tests the internal `_GeneratorDataset`,
# which affords more control over what the finalize function can do than
from __future__ import print_function
import abc
-import collections
import threading
import numpy as np
self._generator = generator
self._lock = threading.Lock()
self._next_id = 0 # GUARDED_BY(self._lock)
- self._iterators = collections.defaultdict(lambda: iter(generator()))
+ self._args = {}
+ self._iterators = {}
- def get_next_id(self):
+ def get_next_id(self, *args):
with self._lock:
ret = self._next_id
self._next_id += 1
+ self._args[ret] = args
# NOTE(mrry): Explicitly create an array of `np.int64` because implicit
# casting in `py_func()` will create an array of `np.int32` on Windows,
# leading to a runtime error.
return np.array(ret, dtype=np.int64)
def get_iterator(self, iterator_id):
- return self._iterators[iterator_id]
+ try:
+ return self._iterators[iterator_id]
+ except KeyError:
+ iterator = iter(self._generator(*self._args.pop(iterator_id)))
+ self._iterators[iterator_id] = iterator
+ return iterator
def iterator_completed(self, iterator_id):
del self._iterators[iterator_id]
@staticmethod
- def from_generator(generator, output_types, output_shapes=None):
+ def from_generator(generator, output_types, output_shapes=None, args=None):
"""Creates a `Dataset` whose elements are generated by `generator`.
The `generator` argument must be a callable object that returns
`Dataset.from_generator()`.
Args:
- generator: A callable object that takes no arguments and returns an
- object that supports the `iter()` protocol.
+ generator: A callable object that returns an object that supports the
+ `iter()` protocol. If `args` is not specified, `generator` must take
+ no arguments; otherwise it must take as many arguments as there are
+ values in `args`.
output_types: A nested structure of `tf.DType` objects corresponding to
each component of an element yielded by `generator`.
output_shapes: (Optional.) A nested structure of `tf.TensorShape`
objects corresponding to each component of an element yielded by
`generator`.
+ args: (Optional.) A tuple of `tf.Tensor` objects that will be evaluated
+ and passed to `generator` as NumPy-array arguments.
Returns:
Dataset: A `Dataset`.
else:
output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
+ if args is None:
+ args = ()
+ else:
+ args = tuple(ops.convert_n_to_tensor(args, name="args"))
flattened_types = nest.flatten(output_types)
flattened_shapes = nest.flatten(output_shapes)
`generator_state`.
"""
return script_ops.py_func(
- generator_state.get_next_id, [], dtypes.int64, stateful=True)
+ generator_state.get_next_id, args, dtypes.int64, stateful=True)
def generator_next_fn(iterator_id_t):
"""Generates the next element from iterator with ID `iterator_id_t`.
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"
}
member_method {
name: "from_generator"
- argspec: "args=[\'generator\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "from_sparse_tensor_slices"