From c7334fef9d1173525f6111b8ab50360b6531d76b Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Wed, 21 Mar 2018 18:02:01 -0700 Subject: [PATCH] [tf.data] Do not crash when combining .cache().take().repeat() Currently, if the .cache() iterator is not fully consumed before being repeated, it will cause an exception to be raised to Python. Instead, cache should act as an identity transformation and log an error, as this will not affect the correctness of the user's program (at the cost of an unexpected performance cost: i.e. not actually caching). PiperOrigin-RevId: 189999552 --- tensorflow/core/kernels/data/cache_dataset_ops.cc | 17 ++++++++++++++++- .../python/data/kernel_tests/cache_dataset_op_test.py | 15 +++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index f0a2192..4b4728d 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -308,6 +308,21 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { input_impl_(params.dataset->input_->MakeIterator(params.prefix)), cache_(new std::vector<std::vector<Tensor>>) {} + ~MemoryWriterIterator() override { + mutex_lock l(mu_); + if (cache_) { + LOG(ERROR) + << "The calling iterator did not fully read the dataset we were " "attempting to cache. In order to avoid unexpected truncation " "of the sequence, the current [partially cached] sequence " "will be dropped. This can occur if you have a sequence " "similar to `dataset.cache().take(k).repeat()`. Instead, swap " "the order (i.e. 
`dataset.take(k).cache().repeat()`)"; + mutex_lock l2(dataset()->mu_); + dataset()->writer_iterator_created_ = false; + } + } + Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { @@ -318,7 +333,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { // Guard on cache_ to not crash if GetNext is called a second time // after *end_of_sequence == true if (cache_) { - mutex_lock l2(dataset()->mu_); + mutex_lock l(dataset()->mu_); DCHECK(dataset()->writer_iterator_created_); DCHECK(!dataset()->cache_); cache_.swap(dataset()->cache_); diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index 02720a2..25269dc 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -297,6 +297,21 @@ class MemoryCacheDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(i2.get_next()) + def testCacheTakeRepeat(self): + dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2) + itr = dataset.make_one_shot_iterator() + n = itr.get_next() + + expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + + with self.test_session() as sess: + for i, expected in enumerate(expected_values): + self.assertEqual(expected, sess.run(n), + "Unexpected value at index %s" % i) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + if __name__ == "__main__": test.main() -- 2.7.4