Fix allow_smaller_final_batches for bucket_by_sequence_length.
author     Matthew Schulkind <matt@hyperscience.com>
           Fri, 22 Dec 2017 18:32:27 +0000 (10:32 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Fri, 22 Dec 2017 18:40:38 +0000 (10:40 -0800)
Closes #14420.

PiperOrigin-RevId: 179940036
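
For context, this is how the fixed code path is typically exercised from user
code. A minimal sketch (TF 1.x, tf.contrib.training); the input queue,
placeholder names, and constants below are illustrative, not part of this
commit:

import tensorflow as tf

# Illustrative input pipeline feeding (length, sequence) pairs.
length_ph = tf.placeholder(tf.int32, [])
sequence_ph = tf.placeholder(tf.int32, [None])
input_queue = tf.FIFOQueue(capacity=1024, dtypes=[tf.int32, tf.int32])
enqueue_op = input_queue.enqueue((length_ph, sequence_ph))
length, sequence = input_queue.dequeue()
length.set_shape([])
sequence.set_shape([None])

lengths, outputs = tf.contrib.training.bucket_by_sequence_length(
    input_length=length,
    tensors=[sequence],
    batch_size=8,
    bucket_boundaries=[3, 4, 5, 10],
    dynamic_pad=True,
    # The flag at issue: before this fix, the final partial batches were
    # not reliably flushed when the input queue was closed, because
    # closing did not propagate between the internal queues.
    allow_smaller_final_batch=True)
batched = outputs[0]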

tensorflow/contrib/training/python/training/bucket_ops.py
tensorflow/contrib/training/python/training/bucket_ops_test.py

diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py
index 95fbc50cba73b25b748c31ecd443eb19c0b6fc8a..e7f23edc901eacfa3a753792c2dbf738bb5a9421 100644
--- a/tensorflow/contrib/training/python/training/bucket_ops.py
+++ b/tensorflow/contrib/training/python/training/bucket_ops.py
@@ -265,16 +265,22 @@ def bucket(tensors,
         for i, (q, bs) in enumerate(zip(bucket_queues, batch_size))
     ]
 
-    for i, q in enumerate(bucket_queues):
-      queue_runner.add_queue_runner(
-          queue_runner.QueueRunner(
-              q, [enqueues_to_top[i]],
-              queue_closed_exception_types=(errors.OutOfRangeError,
-                                            errors.CancelledError)))
+    queue_runner.add_queue_runner(
+        queue_runner.QueueRunner(
+            bucket_queues[0], enqueues_to_top,
+            close_op=top_queue.close(),
+            cancel_op=top_queue.close(cancel_pending_enqueues=True),
+            queue_closed_exception_types=(errors.OutOfRangeError,
+                                          errors.CancelledError)))
     queue_runner.add_queue_runner(
         queue_runner.QueueRunner(
             top_queue,
             bucket_enqueue_ops,
+            close_op=control_flow_ops.group(
+                *[q.close() for q in bucket_queues]),
+            cancel_op=control_flow_ops.group(
+                *[q.close(cancel_pending_enqueues=True)
+                  for q in bucket_queues]),
             queue_closed_exception_types=(errors.OutOfRangeError,
                                           errors.CancelledError)))
 
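The hunk above rewires the queue runners: instead of one runner per bucket
queue, each closing only its own queue, a single runner now drives all of
enqueues_to_top and explicitly closes top_queue once the bucket queues are
exhausted, while the runner for bucket_enqueue_ops closes every bucket queue
once the input is exhausted. Closing therefore propagates across both stages,
which is what lets dequeue_up_to flush the final partial batches. A standalone
sketch of this close-propagation pattern, with made-up queue names (two
chained FIFO queues standing in for the bucket/top pair):

import tensorflow as tf

# Two chained stages; closing must propagate from upstream to downstream,
# otherwise consumers of `downstream` block forever once input stops.
upstream = tf.FIFOQueue(100, [tf.int32], shapes=[[]])
downstream = tf.FIFOQueue(100, [tf.int32], shapes=[[]])
transfer = downstream.enqueue(upstream.dequeue())

tf.train.add_queue_runner(
    tf.train.QueueRunner(
        upstream, [transfer],
        # When `transfer` raises because `upstream` is closed and empty,
        # the runner's close_op shuts down `downstream` as well, so its
        # consumers see end-of-input instead of blocking.
        close_op=downstream.close(),
        cancel_op=downstream.close(cancel_pending_enqueues=True),
        queue_closed_exception_types=(tf.errors.OutOfRangeError,
                                      tf.errors.CancelledError)))
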
diff --git a/tensorflow/contrib/training/python/training/bucket_ops_test.py b/tensorflow/contrib/training/python/training/bucket_ops_test.py
index 330bee8a3fb13cd703fb260952d33e58623ca09c..504f1fcd417f99a8aaa72504f1852e523da1a4c9 100644
--- a/tensorflow/contrib/training/python/training/bucket_ops_test.py
+++ b/tensorflow/contrib/training/python/training/bucket_ops_test.py
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.contrib.training.python.training import bucket_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
@@ -321,7 +322,8 @@ class BucketBySequenceLengthTest(test.TestCase):
 
   def _testBucketBySequenceLength(self,
                                   allow_small_batch,
-                                  bucket_capacities=None):
+                                  bucket_capacities=None,
+                                  drain_entire_queue=True):
     ops.reset_default_graph()
 
     # All inputs must be identical lengths across tuple index.
@@ -339,6 +341,7 @@ class BucketBySequenceLengthTest(test.TestCase):
 
     batch_size = 8
     bucket_boundaries = [3, 4, 5, 10]
+    num_pairs_to_enqueue = 50 * batch_size + 100
 
     # Make capacity very large so we can feed all the inputs in the
     # main thread without blocking
@@ -366,34 +369,47 @@ class BucketBySequenceLengthTest(test.TestCase):
                      [expected_batch_size, labels_len])
 
     def _read_test(sess):
-      for _ in range(50):
-        (out_lengths, (data, labels)) = sess.run(
-            (out_lengths_t, data_and_labels_t))
+      num_pairs_dequeued = 0
+      try:
+        while drain_entire_queue or num_pairs_dequeued < 40 * batch_size:
+          (out_lengths, (data, labels)) = sess.run(
+              (out_lengths_t, data_and_labels_t))
+          num_pairs_dequeued += out_lengths.shape[0]
+          if allow_small_batch:
+            self.assertEqual(data_len, data.shape[1])
+            self.assertEqual(labels_len, labels.shape[1])
+            self.assertGreaterEqual(batch_size, out_lengths.shape[0])
+            self.assertGreaterEqual(batch_size, data.shape[0])
+            self.assertGreaterEqual(batch_size, labels.shape[0])
+          else:
+            self.assertEqual((batch_size, data_len), data.shape)
+            self.assertEqual((batch_size, labels_len), labels.shape)
+            self.assertEqual((batch_size,), out_lengths.shape)
+          for (lr, dr, tr) in zip(out_lengths, data, labels):
+            # Make sure length matches data (here it's the same value).
+            self.assertEqual(dr[0], lr)
+            # Make sure data & labels match.
+            self.assertEqual(dr[0], int(tr[0].decode("ascii")))
+            # Make sure for each row, data came from the same bucket.
+            self.assertEqual(
+                _which_bucket(bucket_boundaries, dr[0]),
+                _which_bucket(bucket_boundaries, dr[1]))
+      except errors.OutOfRangeError:
         if allow_small_batch:
-          self.assertEqual(data_len, data.shape[1])
-          self.assertEqual(labels_len, labels.shape[1])
-          self.assertGreaterEqual(batch_size, out_lengths.shape[0])
-          self.assertGreaterEqual(batch_size, data.shape[0])
-          self.assertGreaterEqual(batch_size, labels.shape[0])
+          self.assertEqual(num_pairs_to_enqueue, num_pairs_dequeued)
         else:
-          self.assertEqual((batch_size, data_len), data.shape)
-          self.assertEqual((batch_size, labels_len), labels.shape)
-          self.assertEqual((batch_size,), out_lengths.shape)
-        for (lr, dr, tr) in zip(out_lengths, data, labels):
-          # Make sure length matches data (here it's the same value).
-          self.assertEqual(dr[0], lr)
-          # Make sure data & labels match.
-          self.assertEqual(dr[0], int(tr[0].decode("ascii")))
-          # Make sure for each row, data came from the same bucket.
-          self.assertEqual(
-              _which_bucket(bucket_boundaries, dr[0]),
-              _which_bucket(bucket_boundaries, dr[1]))
+          # Each bucket queue can be left holding at most batch_size - 1
+          # pairs, so bound the total number dequeued from below.
+          num_buckets = len(bucket_boundaries) + 2
+          self.assertLessEqual(
+              num_pairs_to_enqueue - (batch_size - 1) * num_buckets,
+              num_pairs_dequeued)
 
     with self.test_session() as sess:
       coord = coordinator.Coordinator()
 
       # Feed the inputs, then close the input thread.
-      for _ in range(50 * batch_size + 100):
+      for _ in range(num_pairs_to_enqueue):
         which = random.randint(0, len(input_pairs) - 1)
         length, pair = input_pairs[which]
         sess.run(input_enqueue_op,
@@ -425,6 +441,10 @@ class BucketBySequenceLengthTest(test.TestCase):
     self._testBucketBySequenceLength(allow_small_batch=True,
                                      bucket_capacities=capacities)
 
+  def testBucketBySequenceLengthShutdown(self):
+    self._testBucketBySequenceLength(allow_small_batch=True,
+                                     drain_entire_queue=False)
+
 
 if __name__ == "__main__":
   test.main()
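
For reference, here is how the two shutdown modes the test distinguishes look
in user code, continuing the illustrative sketch near the top of this page
(standard TF 1.x Coordinator/queue-runner usage). Draining fully until
OutOfRangeError corresponds to drain_entire_queue=True; stopping early via
coord.request_stop(), as the new shutdown test does, relies on the cancel_op
wiring added above:

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  # Feed a few (length, sequence) pairs, then signal end-of-input.
  for n in (1, 3, 4, 5, 6, 10):
    sess.run(enqueue_op,
             feed_dict={length_ph: n, sequence_ph: list(range(n))})
  sess.run(input_queue.close())

  # With this fix, closing propagates through the internal queues, so the
  # loop below flushes the final partial batches and then raises
  # OutOfRangeError instead of blocking forever. (Breaking out of the loop
  # early and calling coord.request_stop() is the shutdown case.)
  try:
    while True:
      sess.run([lengths, batched])
  except tf.errors.OutOfRangeError:
    pass

  coord.request_stop()
  coord.join(threads)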