for i, (q, bs) in enumerate(zip(bucket_queues, batch_size))
]
- for i, q in enumerate(bucket_queues):
- queue_runner.add_queue_runner(
- queue_runner.QueueRunner(
- q, [enqueues_to_top[i]],
- queue_closed_exception_types=(errors.OutOfRangeError,
- errors.CancelledError)))
+    # A single queue runner services every "bucket queue -> top queue"
+    # enqueue op; once all bucket queues have been closed and drained, it
+    # closes the top queue so downstream dequeues raise OutOfRangeError
+    # instead of blocking forever.
+    queue_runner.add_queue_runner(
+        queue_runner.QueueRunner(
+            bucket_queues[0], enqueues_to_top,
+            close_op=top_queue.close(),
+            cancel_op=top_queue.close(cancel_pending_enqueues=True),
+            queue_closed_exception_types=(errors.OutOfRangeError,
+                                          errors.CancelledError)))
queue_runner.add_queue_runner(
queue_runner.QueueRunner(
top_queue,
bucket_enqueue_ops,
+            # When the input source is exhausted, close (or cancel) all of
+            # the bucket queues, so the runner above can in turn close the
+            # top queue.
+            close_op=control_flow_ops.group(
+                *[q.close() for q in bucket_queues]),
+            cancel_op=control_flow_ops.group(
+                *[q.close(cancel_pending_enqueues=True)
+                  for q in bucket_queues]),
queue_closed_exception_types=(errors.OutOfRangeError,
errors.CancelledError)))
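
The hunk above replaces the per-bucket queue runners, which never closed top_queue (so a reader could block forever after the input ended), with runners whose close_op/cancel_op arguments chain shutdown downstream. A minimal standalone sketch of that pattern follows, assuming the TF 1.x queue API; the two-stage pipeline and the names q_in, q_out, and transfer_op are illustrative, not part of this patch:

import tensorflow as tf

# Two-stage pipeline: q_in -> (queue runner) -> q_out.
q_in = tf.FIFOQueue(10, tf.int32, shapes=[()])
q_out = tf.FIFOQueue(10, tf.int32, shapes=[()])
transfer_op = q_out.enqueue(q_in.dequeue())

# Once q_in is closed and drained, transfer_op raises OutOfRangeError; the
# runner catches it and runs close_op, which closes q_out in turn.
tf.train.add_queue_runner(
    tf.train.QueueRunner(
        q_in, [transfer_op],
        close_op=q_out.close(),
        cancel_op=q_out.close(cancel_pending_enqueues=True),
        queue_closed_exception_types=(tf.errors.OutOfRangeError,
                                      tf.errors.CancelledError)))

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  sess.run(q_in.enqueue_many([[1, 2, 3]]))
  sess.run(q_in.close())  # shutdown cascades from here
  try:
    while True:
      print(sess.run(q_out.dequeue()))  # 1, 2, 3, then OutOfRangeError
  except tf.errors.OutOfRangeError:
    pass  # q_out was closed for us by the runner's close_op
  coord.request_stop()
  coord.join(threads)

Without the close_op here, the final q_out.dequeue() would block instead of raising OutOfRangeError; that is the class of hang this change addresses, and the new drain_entire_queue=False test below exercises the shutdown path.
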
from tensorflow.contrib.training.python.training import bucket_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
def _testBucketBySequenceLength(self,
allow_small_batch,
- bucket_capacities=None):
+ bucket_capacities=None,
+ drain_entire_queue=True):
ops.reset_default_graph()
# All inputs must be identical lengths across tuple index.
batch_size = 8
bucket_boundaries = [3, 4, 5, 10]
+ num_pairs_to_enqueue = 50 * batch_size + 100
# Make capacity very large so we can feed all the inputs in the
# main thread without blocking
[expected_batch_size, labels_len])
def _read_test(sess):
- for _ in range(50):
- (out_lengths, (data, labels)) = sess.run(
- (out_lengths_t, data_and_labels_t))
+ num_pairs_dequeued = 0
+      try:
+        # When not draining, stop early (after 40 * batch_size of the
+        # enqueued pairs) so shutdown happens while pairs are still queued.
+        while drain_entire_queue or num_pairs_dequeued < 40 * batch_size:
+ (out_lengths, (data, labels)) = sess.run(
+ (out_lengths_t, data_and_labels_t))
+ num_pairs_dequeued += out_lengths.shape[0]
+ if allow_small_batch:
+ self.assertEqual(data_len, data.shape[1])
+ self.assertEqual(labels_len, labels.shape[1])
+ self.assertGreaterEqual(batch_size, out_lengths.shape[0])
+ self.assertGreaterEqual(batch_size, data.shape[0])
+ self.assertGreaterEqual(batch_size, labels.shape[0])
+ else:
+ self.assertEqual((batch_size, data_len), data.shape)
+ self.assertEqual((batch_size, labels_len), labels.shape)
+ self.assertEqual((batch_size,), out_lengths.shape)
+ for (lr, dr, tr) in zip(out_lengths, data, labels):
+ # Make sure length matches data (here it's the same value).
+ self.assertEqual(dr[0], lr)
+ # Make sure data & labels match.
+ self.assertEqual(dr[0], int(tr[0].decode("ascii")))
+ # Make sure for each row, data came from the same bucket.
+ self.assertEqual(
+ _which_bucket(bucket_boundaries, dr[0]),
+ _which_bucket(bucket_boundaries, dr[1]))
+ except errors.OutOfRangeError:
if allow_small_batch:
- self.assertEqual(data_len, data.shape[1])
- self.assertEqual(labels_len, labels.shape[1])
- self.assertGreaterEqual(batch_size, out_lengths.shape[0])
- self.assertGreaterEqual(batch_size, data.shape[0])
- self.assertGreaterEqual(batch_size, labels.shape[0])
+ self.assertEqual(num_pairs_to_enqueue, num_pairs_dequeued)
else:
- self.assertEqual((batch_size, data_len), data.shape)
- self.assertEqual((batch_size, labels_len), labels.shape)
- self.assertEqual((batch_size,), out_lengths.shape)
- for (lr, dr, tr) in zip(out_lengths, data, labels):
- # Make sure length matches data (here it's the same value).
- self.assertEqual(dr[0], lr)
- # Make sure data & labels match.
- self.assertEqual(dr[0], int(tr[0].decode("ascii")))
- # Make sure for each row, data came from the same bucket.
- self.assertEqual(
- _which_bucket(bucket_boundaries, dr[0]),
- _which_bucket(bucket_boundaries, dr[1]))
+          # Each queue can be left holding at most batch_size - 1 pairs:
+          # there is one queue per bucket (len(bucket_boundaries) + 1 of
+          # them) plus the top queue, hence the + 2.
+          num_buckets = len(bucket_boundaries) + 2
+          self.assertLessEqual(
+              num_pairs_to_enqueue - (batch_size - 1) * num_buckets,
+              num_pairs_dequeued)
with self.test_session() as sess:
coord = coordinator.Coordinator()
# Feed the inputs, then close the input thread.
- for _ in range(50 * batch_size + 100):
+ for _ in range(num_pairs_to_enqueue):
which = random.randint(0, len(input_pairs) - 1)
length, pair = input_pairs[which]
sess.run(input_enqueue_op,
self._testBucketBySequenceLength(allow_small_batch=True,
bucket_capacities=capacities)
+ def testBucketBySequenceLengthShutdown(self):
+ self._testBucketBySequenceLength(allow_small_batch=True,
+ drain_entire_queue=False)
+
if __name__ == "__main__":
test.main()
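
For concreteness, unrolling the leftover bound _read_test asserts on the allow_small_batch=False path, using this test's constants: num_pairs_to_enqueue = 50 * 8 + 100 = 500; bucket_boundaries = [3, 4, 5, 10] yields 5 buckets, and with the top queue num_buckets = 6; each queue can strand at most batch_size - 1 = 7 pairs, so the assert requires num_pairs_dequeued >= 500 - 7 * 6 = 458.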