From 6006f46dd7531b112360b831aa61de6c46618166 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Thu, 22 Feb 2018 15:06:45 -0800 Subject: [PATCH] [tf.data] Handle a function-raised OutOfRange error correctly in ParallelMapDataset. PiperOrigin-RevId: 186680982 --- .../core/kernels/data/parallel_map_dataset_op.cc | 9 ++++++++- .../data/kernel_tests/map_dataset_op_test.py | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index bc4426a..33053b1 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -199,7 +199,14 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { } } ++num_outputs_consumed_; - return result->status; + if (errors::IsOutOfRange(result->status)) { + // `f` may deliberately raise `errors::OutOfRange` to indicate + // that we should terminate the iteration early. + *end_of_sequence = true; + return Status::OK(); + } else { + return result->status; + } } protected: diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 04d1abd..0791c61 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -602,6 +602,28 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testParallelMapOutOfRangeError(self): + def raising_py_func(i): + if i == 100: + raise StopIteration() + else: + return i + + iterator = ( + dataset_ops.Dataset.range(105) + .map(lambda x: script_ops.py_func(raising_py_func, [x], dtypes.int64), + num_parallel_calls=2) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(100): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + class MapDatasetBenchmark(test.Benchmark): -- 2.7.4