dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
with self.test_session() as sess:
itr = dataset.make_one_shot_iterator()
+ next_element = itr.get_next()
with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
+ sess.run(next_element)
def testSimpleDirectory(self):
  """list_files over a directory of plain files yields each file exactly once."""
  filenames = ['a', 'b', 'c']
  # NOTE(review): the diff hunk elided the fixture call; restored to match the
  # pattern used by testNoShuffle — without it the directory is empty and the
  # first sess.run() would raise OutOfRangeError.
  self._touchTempFiles(filenames)

  dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
  with self.test_session() as sess:
    itr = dataset.make_one_shot_iterator()
    # Build the get_next() op once, outside the loop: each call to
    # Iterator.get_next() adds a new op to the graph.
    next_element = itr.get_next()

    full_filenames = []
    produced_filenames = []
    for filename in filenames:
      full_filenames.append(
          compat.as_bytes(path.join(self.tmp_dir, filename)))
      produced_filenames.append(compat.as_bytes(sess.run(next_element)))
    # Order is not guaranteed, so compare as multisets.
    self.assertItemsEqual(full_filenames, produced_filenames)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(itr.get_next())
with self.test_session() as sess:
itr = dataset.make_initializable_iterator()
+ next_element = itr.get_next()
sess.run(
itr.initializer,
feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
+ sess.run(next_element)
def testSimpleDirectoryInitializer(self):
filenames = ['a', 'b', 'c']
with self.test_session() as sess:
itr = dataset.make_initializable_iterator()
+ next_element = itr.get_next()
sess.run(
itr.initializer,
feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
for filename in filenames:
full_filenames.append(
compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+ produced_filenames.append(compat.as_bytes(sess.run(next_element)))
self.assertItemsEqual(full_filenames, produced_filenames)
with self.test_session() as sess:
itr = dataset.make_initializable_iterator()
+ next_element = itr.get_next()
sess.run(
itr.initializer,
feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')})
for filename in filenames[1:-1]:
full_filenames.append(
compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+ produced_filenames.append(compat.as_bytes(sess.run(next_element)))
self.assertItemsEqual(full_filenames, produced_filenames)
with self.assertRaises(errors.OutOfRangeError):
with self.test_session() as sess:
itr = dataset.make_initializable_iterator()
+ next_element = itr.get_next()
sess.run(
itr.initializer,
feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')})
for filename in filenames[1:]:
full_filenames.append(
compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+ produced_filenames.append(compat.as_bytes(sess.run(next_element)))
self.assertItemsEqual(full_filenames, produced_filenames)
with self.assertRaises(errors.OutOfRangeError):
sess.run(itr.get_next())
def testNoShuffle(self):
  """list_files(shuffle=False) yields files in the same order on each repeat."""
  filenames = ['a', 'b', 'c']
  self._touchTempFiles(filenames)

  # Repeat the list twice and ensure that the order is the same each time.
  # NOTE(mrry): This depends on an implementation detail of `list_files()`,
  # which is that the list of files is captured when the iterator is
  # initialized. Otherwise, or if e.g. the iterator were initialized more than
  # once, it's possible that the non-determinism of `tf.matching_files()`
  # would cause this test to fail. However, it serves as a useful confirmation
  # that the `shuffle=False` argument is working as intended.
  # TODO(b/73959787): Provide some ordering guarantees so that this test is
  # more meaningful.
  dataset = dataset_ops.Dataset.list_files(
      path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
  with self.test_session() as sess:
    itr = dataset.make_one_shot_iterator()
    next_element = itr.get_next()

    full_filenames = []
    produced_filenames = []
    for filename in filenames * 2:
      full_filenames.append(
          compat.as_bytes(path.join(self.tmp_dir, filename)))
      produced_filenames.append(compat.as_bytes(sess.run(next_element)))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(itr.get_next())
    # Same set of files overall...
    self.assertItemsEqual(full_filenames, produced_filenames)
    # ...and the second pass reproduced the first pass's order exactly.
    self.assertEqual(produced_filenames[:len(filenames)],
                     produced_filenames[len(filenames):])
+
if __name__ == '__main__':
test.main()
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
return PrefetchDataset(self, buffer_size)
@staticmethod
def list_files(file_pattern, shuffle=None):
  """A dataset of all files matching a pattern.

  Example:
    If we had the following files on our filesystem:
      - /path/to/dir/b.py
      - /path/to/dir/c.py
    then passing "/path/to/dir/*.py" as the pattern would yield both files.

  NOTE: The order of the file names returned can be non-deterministic even
  when `shuffle` is `False`.

  Args:
    file_pattern: A string or scalar string `tf.Tensor`, representing
      the filename pattern that will be matched.
    shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
      Defaults to `True`.

  Returns:
    Dataset: A `Dataset` of strings corresponding to file names.
  """
  # TODO(b/73959787): Add a `seed` argument and make the `shuffle=False`
  # behavior deterministic (e.g. by sorting the filenames).
  if shuffle is None:
    # Default preserves the historical behavior of shuffling the match list.
    shuffle = True
  matching_files = gen_io_ops.matching_files(file_pattern)
  dataset = Dataset.from_tensor_slices(matching_files)
  if shuffle:
    # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
    # list of files might be empty.
    buffer_size = math_ops.maximum(
        array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
    dataset = dataset.shuffle(buffer_size)
  return dataset
def repeat(self, count=None):
"""Repeats this dataset `count` times.