[tf.data] Handle a function-raised OutOfRange error correctly in ParallelMapDataset.
author Derek Murray <mrry@google.com>
Thu, 22 Feb 2018 23:06:45 +0000 (15:06 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 23:10:48 +0000 (15:10 -0800)
PiperOrigin-RevId: 186680982

tensorflow/core/kernels/data/parallel_map_dataset_op.cc
tensorflow/python/data/kernel_tests/map_dataset_op_test.py

index bc4426a..33053b1 100644 (file)
@@ -199,7 +199,14 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
           }
         }
         ++num_outputs_consumed_;
-        return result->status;
+        if (errors::IsOutOfRange(result->status)) {
+          // `f` may deliberately raise `errors::OutOfRange` to indicate
+          // that we should terminate the iteration early.
+          *end_of_sequence = true;
+          return Status::OK();
+        } else {
+          return result->status;
+        }
       }
 
      protected:
index 04d1abd..0791c61 100644 (file)
@@ -602,6 +602,28 @@ class MapDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  def testParallelMapOutOfRangeError(self):
+    def raising_py_func(i):
+      if i == 100:
+        raise StopIteration()
+      else:
+        return i
+
+    iterator = (
+        dataset_ops.Dataset.range(105)
+        .map(lambda x: script_ops.py_func(raising_py_func, [x], dtypes.int64),
+             num_parallel_calls=2)
+        .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      for i in range(100):
+        self.assertEqual(i, sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
 
 class MapDatasetBenchmark(test.Benchmark):