[tf.data] Add optional `shuffle` argument to `Dataset.list_files()`.
authorDerek Murray <mrry@google.com>
Thu, 1 Mar 2018 06:58:19 +0000 (22:58 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 1 Mar 2018 07:02:40 +0000 (23:02 -0800)
This option makes it easier to shuffle a set of filenames on each iteration,
and defaults to true to match the recommended best practices when training on
a large dataset.

PiperOrigin-RevId: 187434282

tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
tensorflow/python/data/ops/dataset_ops.py
tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt

index 4e7691e..6442eb9 100644 (file)
@@ -46,8 +46,9 @@ class ListFilesDatasetOpTest(test.TestCase):
     dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
     with self.test_session() as sess:
       itr = dataset.make_one_shot_iterator()
+      next_element = itr.get_next()
       with self.assertRaises(errors.OutOfRangeError):
-        sess.run(itr.get_next())
+        sess.run(next_element)
 
   def testSimpleDirectory(self):
     filenames = ['a', 'b', 'c']
@@ -56,13 +57,14 @@ class ListFilesDatasetOpTest(test.TestCase):
     dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
     with self.test_session() as sess:
       itr = dataset.make_one_shot_iterator()
+      next_element = itr.get_next()
 
       full_filenames = []
       produced_filenames = []
       for filename in filenames:
         full_filenames.append(
             compat.as_bytes(path.join(self.tmp_dir, filename)))
-        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+        produced_filenames.append(compat.as_bytes(sess.run(next_element)))
       self.assertItemsEqual(full_filenames, produced_filenames)
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(itr.get_next())
@@ -73,12 +75,13 @@ class ListFilesDatasetOpTest(test.TestCase):
 
     with self.test_session() as sess:
       itr = dataset.make_initializable_iterator()
+      next_element = itr.get_next()
       sess.run(
           itr.initializer,
           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
 
       with self.assertRaises(errors.OutOfRangeError):
-        sess.run(itr.get_next())
+        sess.run(next_element)
 
   def testSimpleDirectoryInitializer(self):
     filenames = ['a', 'b', 'c']
@@ -89,6 +92,7 @@ class ListFilesDatasetOpTest(test.TestCase):
 
     with self.test_session() as sess:
       itr = dataset.make_initializable_iterator()
+      next_element = itr.get_next()
       sess.run(
           itr.initializer,
           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
@@ -98,7 +102,7 @@ class ListFilesDatasetOpTest(test.TestCase):
       for filename in filenames:
         full_filenames.append(
             compat.as_bytes(path.join(self.tmp_dir, filename)))
-        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+        produced_filenames.append(compat.as_bytes(sess.run(next_element)))
 
       self.assertItemsEqual(full_filenames, produced_filenames)
 
@@ -114,6 +118,7 @@ class ListFilesDatasetOpTest(test.TestCase):
 
     with self.test_session() as sess:
       itr = dataset.make_initializable_iterator()
+      next_element = itr.get_next()
       sess.run(
           itr.initializer,
           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')})
@@ -123,7 +128,7 @@ class ListFilesDatasetOpTest(test.TestCase):
       for filename in filenames[1:-1]:
         full_filenames.append(
             compat.as_bytes(path.join(self.tmp_dir, filename)))
-        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+        produced_filenames.append(compat.as_bytes(sess.run(next_element)))
       self.assertItemsEqual(full_filenames, produced_filenames)
 
       with self.assertRaises(errors.OutOfRangeError):
@@ -138,6 +143,7 @@ class ListFilesDatasetOpTest(test.TestCase):
 
     with self.test_session() as sess:
       itr = dataset.make_initializable_iterator()
+      next_element = itr.get_next()
       sess.run(
           itr.initializer,
           feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')})
@@ -147,13 +153,44 @@ class ListFilesDatasetOpTest(test.TestCase):
       for filename in filenames[1:]:
         full_filenames.append(
             compat.as_bytes(path.join(self.tmp_dir, filename)))
-        produced_filenames.append(compat.as_bytes(sess.run(itr.get_next())))
+        produced_filenames.append(compat.as_bytes(sess.run(next_element)))
 
       self.assertItemsEqual(full_filenames, produced_filenames)
 
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(itr.get_next())
 
+  def testNoShuffle(self):
+    filenames = ['a', 'b', 'c']
+    self._touchTempFiles(filenames)
+
+    # Repeat the list twice and ensure that the order is the same each time.
+    # NOTE(mrry): This depends on an implementation detail of `list_files()`,
+    # which is that the list of files is captured when the iterator is
+    # initialized. Otherwise, or if e.g. the iterator were initialized more than
+    # once, it's possible that the non-determinism of `tf.matching_files()`
+    # would cause this test to fail. However, it serves as a useful confirmation
+    # that the `shuffle=False` argument is working as intended.
+    # TODO(b/73959787): Provide some ordering guarantees so that this test is
+    # more meaningful.
+    dataset = dataset_ops.Dataset.list_files(
+        path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
+    with self.test_session() as sess:
+      itr = dataset.make_one_shot_iterator()
+      next_element = itr.get_next()
+
+      full_filenames = []
+      produced_filenames = []
+      for filename in filenames * 2:
+        full_filenames.append(
+            compat.as_bytes(path.join(self.tmp_dir, filename)))
+        produced_filenames.append(compat.as_bytes(sess.run(next_element)))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+      self.assertItemsEqual(full_filenames, produced_filenames)
+      self.assertEqual(produced_filenames[:len(filenames)],
+                       produced_filenames[len(filenames):])
+
 
 if __name__ == '__main__':
   test.main()
index 5751f35..7c5aa4c 100644 (file)
@@ -36,6 +36,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.ops import math_ops
@@ -557,7 +558,7 @@ class Dataset(object):
     return PrefetchDataset(self, buffer_size)
 
   @staticmethod
-  def list_files(file_pattern):
+  def list_files(file_pattern, shuffle=None):
     """A dataset of all files matching a pattern.
 
     Example:
@@ -570,16 +571,31 @@ class Dataset(object):
         - /path/to/dir/b.py
         - /path/to/dir/c.py
 
-    NOTE: The order of the file names returned can be non-deterministic.
+    NOTE: The order of the file names returned can be non-deterministic even
+    when `shuffle` is `False`.
 
     Args:
       file_pattern: A string or scalar string `tf.Tensor`, representing
         the filename pattern that will be matched.
+      shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
+        Defaults to `True`.
 
     Returns:
      Dataset: A `Dataset` of strings corresponding to file names.
     """
-    return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))
+    # TODO(b/73959787): Add a `seed` argument and make the `shuffle=False`
+    # behavior deterministic (e.g. by sorting the filenames).
+    if shuffle is None:
+      shuffle = True
+    matching_files = gen_io_ops.matching_files(file_pattern)
+    dataset = Dataset.from_tensor_slices(matching_files)
+    if shuffle:
+      # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
+      # list of files might be empty.
+      buffer_size = math_ops.maximum(
+          array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
+      dataset = dataset.shuffle(buffer_size)
+    return dataset
 
   def repeat(self, count=None):
     """Repeats this dataset `count` times.
index 42de5c0..0900ada 100644 (file)
@@ -64,7 +64,7 @@ tf_class {
   }
   member_method {
     name: "list_files"
-    argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "make_initializable_iterator"
index e2fc8d6..7b16ac9 100644 (file)
@@ -65,7 +65,7 @@ tf_class {
   }
   member_method {
     name: "list_files"
-    argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "make_initializable_iterator"
index 709ec12..9cf5f2a 100644 (file)
@@ -65,7 +65,7 @@ tf_class {
   }
   member_method {
     name: "list_files"
-    argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "make_initializable_iterator"
index 7263230..8c3d669 100644 (file)
@@ -65,7 +65,7 @@ tf_class {
   }
   member_method {
     name: "list_files"
-    argspec: "args=[\'file_pattern\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'file_pattern\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "make_initializable_iterator"