[tf.data] Do not crash when combining .cache().take().repeat()
author Brennan Saeta <saeta@google.com>
Thu, 22 Mar 2018 01:02:01 +0000 (18:02 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 01:04:21 +0000 (18:04 -0700)
Currently, if the .cache() iterator is not fully consumed before
being repeated, it will cause an exception to be raised to Python.
Instead, cache should act as an identity transformation and log
an error, as this will not affect the correctness of the user's
program (though it comes at an unexpected performance cost: the
data is not actually cached).

PiperOrigin-RevId: 189999552

tensorflow/core/kernels/data/cache_dataset_ops.cc
tensorflow/python/data/kernel_tests/cache_dataset_op_test.py

index f0a2192..4b4728d 100644 (file)
@@ -308,6 +308,21 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
             input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             cache_(new std::vector<std::vector<Tensor>>) {}
 
+      ~MemoryWriterIterator() override {
+        mutex_lock l(mu_);
+        if (cache_) {
+          LOG(ERROR)
+              << "The calling iterator did not fully read the dataset we were "
+                 "attempting to cache. In order to avoid unexpected truncation "
+                 "of the sequence, the current [partially cached] sequence "
+                 "will be dropped. This can occur if you have a sequence "
+                 "similar to `dataset.cache().take(k).repeat()`. Instead, swap "
+                 "the order (i.e. `dataset.take(k).cache().repeat()`)";
+          mutex_lock l2(dataset()->mu_);
+          dataset()->writer_iterator_created_ = false;
+        }
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
@@ -318,7 +333,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
           // Guard on cache_ to not crash if GetNext is called a second time
           // after *end_of_sequence == true
           if (cache_) {
-            mutex_lock l2(dataset()->mu_);
+            mutex_lock l(dataset()->mu_);
             DCHECK(dataset()->writer_iterator_created_);
             DCHECK(!dataset()->cache_);
             cache_.swap(dataset()->cache_);
index 02720a2..25269dc 100644 (file)
@@ -297,6 +297,21 @@ class MemoryCacheDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(i2.get_next())
 
+  def testCacheTakeRepeat(self):
+    dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)
+    itr = dataset.make_one_shot_iterator()
+    n = itr.get_next()
+
+    expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
+
+    with self.test_session() as sess:
+      for i, expected in enumerate(expected_values):
+        self.assertEqual(expected, sess.run(n),
+                         "Unexpected value at index %s" % i)
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(itr.get_next())
+
 
 if __name__ == "__main__":
   test.main()