input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
cache_(new std::vector<std::vector<Tensor>>) {}
+ // Destructor: detects the case where the consumer abandoned this writer
+ // before reaching end_of_sequence. A non-null cache_ here means the
+ // dataset was only partially read, so the partial cache is dropped
+ // (never published to the dataset) rather than risk truncating the
+ // cached sequence on a later read.
+ ~MemoryWriterIterator() override {
+ mutex_lock l(mu_);
+ if (cache_) {
+ // cache_ still holds the partially-built buffer: warn loudly, since
+ // patterns like `cache().take(k).repeat()` silently hit this path.
+ LOG(ERROR)
+ << "The calling iterator did not fully read the dataset we were "
+ "attempting to cache. In order to avoid unexpected truncation "
+ "of the sequence, the current [partially cached] sequence "
+ "will be dropped. This can occur if you have a sequence "
+ "similar to `dataset.cache().take(k).repeat()`. Instead, swap "
+ "the order (i.e. `dataset.take(k).cache().repeat()`)";
+ // NOTE(review): second lock taken while holding mu_ — presumably
+ // mu_ -> dataset()->mu_ is the established lock order elsewhere in
+ // this file; confirm to rule out lock-order inversion.
+ mutex_lock l2(dataset()->mu_);
+ // Resetting this flag appears intended to let a future iterator
+ // attempt the caching pass again from scratch — TODO confirm
+ // against the MakeIterator / writer-creation logic.
+ dataset()->writer_iterator_created_ = false;
+ }
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
// Guard on cache_ to not crash if GetNext is called a second time
// after *end_of_sequence == true
if (cache_) {
- mutex_lock l2(dataset()->mu_);
+ mutex_lock l(dataset()->mu_);
DCHECK(dataset()->writer_iterator_created_);
DCHECK(!dataset()->cache_);
cache_.swap(dataset()->cache_);
with self.assertRaises(errors.OutOfRangeError):
sess.run(i2.get_next())
+ def testCacheTakeRepeat(self):
+ # Regression test for the partially-consumed-cache case: with
+ # `cache().take(5).repeat(2)` the cache is never fully read on the
+ # first pass, yet the pipeline must still yield the 5-element prefix
+ # twice and then signal exhaustion (rather than truncating or
+ # crashing on the abandoned writer iterator).
+ dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)
+ itr = dataset.make_one_shot_iterator()
+ n = itr.get_next()
+
+ # 5-element prefix of range(10), repeated twice.
+ expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
+
+ with self.test_session() as sess:
+ for i, expected in enumerate(expected_values):
+ self.assertEqual(expected, sess.run(n),
+ "Unexpected value at index %s" % i)
+
+ # After the 10 expected elements the repeated dataset is exhausted.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(itr.get_next())
+
# Standard TensorFlow test entry point: runs all test cases in this module.
if __name__ == "__main__":
test.main()