Added `drop_final_batch` argument to make_batched_features_dataset. This allows the final, smaller batch to be dropped when the batch size does not evenly divide the number of input elements.
author    A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 01:37:19 +0000 (18:37 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 01:39:37 +0000 (18:39 -0700)
PiperOrigin-RevId: 191831842

tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
tensorflow/contrib/data/python/ops/readers.py

diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 6ee1b57..f3e9302 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -271,7 +271,8 @@ class ReadBatchFeaturesTest(test.TestCase):
                            reader_num_threads=1,
                            parser_num_threads=1,
                            shuffle=False,
-                           shuffle_seed=None):
+                           shuffle_seed=None,
+                           drop_final_batch=False):
     self.filenames = filenames
     self.num_epochs = num_epochs
     self.batch_size = batch_size
@@ -289,7 +290,8 @@ class ReadBatchFeaturesTest(test.TestCase):
         shuffle=shuffle,
         shuffle_seed=shuffle_seed,
         reader_num_threads=reader_num_threads,
-        parser_num_threads=parser_num_threads).make_one_shot_iterator(
+        parser_num_threads=parser_num_threads,
+        drop_final_batch=drop_final_batch).make_one_shot_iterator(
         ).get_next()
 
   def _record(self, f, r):
@@ -559,6 +561,20 @@ class ReadBatchFeaturesTest(test.TestCase):
               with self.assertRaises(errors.OutOfRangeError):
                 self._next_actual_batch(sess)
 
+  def testDropFinalBatch(self):
+    for batch_size in [1, 2]:
+      for num_epochs in [1, 10]:
+        with ops.Graph().as_default():
+          # Basic test: read from file 0.
+          self.outputs = self._read_batch_features(
+              filenames=self.test_filenames[0],
+              num_epochs=num_epochs,
+              batch_size=batch_size,
+              drop_final_batch=True)
+          for _, tensor in self.outputs.items():
+            if isinstance(tensor, ops.Tensor):  # Guard against SparseTensor.
+              self.assertEqual(tensor.shape[0], batch_size)
+
 
 class MakeCsvDatasetTest(test.TestCase):
 
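The shape assertion in `testDropFinalBatch` holds because `batch_and_drop_remainder` gives every dense output tensor a statically known leading dimension, while a plain `batch` leaves it unknown. A minimal sketch of that difference, using a toy dataset and assuming the TF 1.x `tf.contrib.data` API:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)  # 10 elements, batched by 4 below

    plain = ds.batch(4)  # keeps the final partial batch of 2
    dropped = ds.apply(tf.contrib.data.batch_and_drop_remainder(4))

    print(plain.output_shapes)    # (?,)  -- leading dim unknown at graph time
    print(dropped.output_shapes)  # (4,)  -- leading dim statically 4
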
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 9a48aa0..b8eb099 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -370,7 +370,8 @@ def make_batched_features_dataset(file_pattern,
                                   prefetch_buffer_size=1,
                                   reader_num_threads=1,
                                   parser_num_threads=2,
-                                  sloppy_ordering=False):
+                                  sloppy_ordering=False,
+                                  drop_final_batch=False):
   """Returns a `Dataset` of feature dictionaries from `Example` protos.
 
   Example:
@@ -443,6 +444,9 @@ def make_batched_features_dataset(file_pattern,
       produced is deterministic prior to shuffling (elements are still
       randomized if `shuffle=True`. Note that if the seed is set, then order
       of elements after shuffling is deterministic). Defaults to `False`.
+    drop_final_batch: If `True`, and the batch size does not evenly divide the
+      input dataset size, the final smaller batch will be dropped. Defaults to
+      `False`.
 
   Returns:
     A dataset of `dict` elements. Each `dict` maps feature keys to
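
For context, a hedged usage sketch of the new flag; the file pattern and feature spec below are made-up placeholders, and it assumes the function is exposed as `tf.contrib.data.make_batched_features_dataset` in a contemporary TF 1.x build:

    import tensorflow as tf

    # Placeholder feature spec, for illustration only.
    features = {
        "age": tf.FixedLenFeature([], tf.int64),
        "label": tf.FixedLenFeature([], tf.int64),
    }

    dataset = tf.contrib.data.make_batched_features_dataset(
        file_pattern="/tmp/records-*.tfrecord",  # placeholder pattern
        batch_size=32,
        features=features,
        num_epochs=1,
        shuffle=False,
        drop_final_batch=True)  # every emitted batch has exactly 32 rows

    # Dense feature tensors now carry a static leading dimension of 32.
    next_batch = dataset.make_one_shot_iterator().get_next()
    print(next_batch["age"].shape)  # (32,)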
@@ -481,7 +485,10 @@ def make_batched_features_dataset(file_pattern,
   elif shuffle:
     dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
 
-  dataset = dataset.batch(batch_size)
+  if drop_final_batch:
+    dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
+  else:
+    dataset = dataset.batch(batch_size)
 
   # Parse `Example` tensors to a dictionary of `Feature` tensors.
   dataset = dataset.map(
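
A small runtime illustration of what the `drop_final_batch` branch above changes, again with toy data and the TF 1.x contrib API assumed:

    import tensorflow as tf

    # 10 elements, batch size 4: two full batches; the remainder of 2 is dropped.
    ds = tf.data.Dataset.range(10).apply(
        tf.contrib.data.batch_and_drop_remainder(4))
    next_elem = ds.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        print(sess.run(next_elem))  # [0 1 2 3]
        print(sess.run(next_elem))  # [4 5 6 7]
        # A third run raises tf.errors.OutOfRangeError: the final batch of
        # 2 elements is dropped rather than emitted as a smaller batch.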