reader_num_threads=1,
parser_num_threads=1,
shuffle=False,
- shuffle_seed=None):
+ shuffle_seed=None,
+ drop_final_batch=False):
self.filenames = filenames
self.num_epochs = num_epochs
self.batch_size = batch_size
shuffle=shuffle,
shuffle_seed=shuffle_seed,
reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ parser_num_threads=parser_num_threads,
+ drop_final_batch=drop_final_batch).make_one_shot_iterator(
).get_next()
def _record(self, f, r):
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
def testDropFinalBatch(self):
  """Verifies `drop_final_batch=True` yields batches with a static batch dim.

  When the final partial batch is dropped, every dense output tensor should
  have a fully defined leading dimension equal to `batch_size`, for a range
  of batch sizes and epoch counts.
  """
  for batch_size in [1, 2]:
    for num_epochs in [1, 10]:
      with ops.Graph().as_default():
        # Basic test: read from file 0.
        self.outputs = self._read_batch_features(
            filenames=self.test_filenames[0],
            num_epochs=num_epochs,
            batch_size=batch_size,
            drop_final_batch=True)
        # Only the tensors matter here; feature keys are irrelevant,
        # so iterate values() rather than items() and discard keys.
        for tensor in self.outputs.values():
          if isinstance(tensor, ops.Tensor):  # Guard against SparseTensor.
            self.assertEqual(tensor.shape[0], batch_size)
class MakeCsvDatasetTest(test.TestCase):
prefetch_buffer_size=1,
reader_num_threads=1,
parser_num_threads=2,
- sloppy_ordering=False):
+ sloppy_ordering=False,
+ drop_final_batch=False):
"""Returns a `Dataset` of feature dictionaries from `Example` protos.
Example:
produced is deterministic prior to shuffling (elements are still
randomized if `shuffle=True`. Note that if the seed is set, then order
of elements after shuffling is deterministic). Defaults to `False`.
+ drop_final_batch: If `True`, and the batch size does not evenly divide the
+ input dataset size, the final smaller batch will be dropped. Defaults to
+ `False`.
Returns:
A dataset of `dict` elements. Each `dict` maps feature keys to
elif shuffle:
dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
- dataset = dataset.batch(batch_size)
+ if drop_final_batch:
+ dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
+ else:
+ dataset = dataset.batch(batch_size)
# Parse `Example` tensors to a dictionary of `Feature` tensors.
dataset = dataset.map(