From 6bb4f7abb03a7904fecc5b61e3ed37e9b663d6b0 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 27 Mar 2018 14:23:28 -0700 Subject: [PATCH] [tf.data] Raise error when window size is 0 in `tf.contrib.data.group_by_window()`. PiperOrigin-RevId: 190673466 --- .../contrib/data/python/kernel_tests/bucketing_test.py | 15 +++++++++++++++ .../core/kernels/data/group_by_window_dataset_op.cc | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index d013189..6002cc7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -104,6 +104,21 @@ class GroupByWindowTest(test.TestCase): self.assertAllEqual([0, 0, 0], sess.run(get_next)) self.assertAllEqual([1], sess.run(get_next)) + def testEmpty(self): + iterator = ( + dataset_ops.Dataset.range(4).apply( + grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Window size must be greater than zero, but got 0."): + print(sess.run(get_next)) + def testReduceFuncError(self): components = np.random.randint(100, size=(200,)).astype(np.int64) diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 834c06b..46f43dd 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -263,6 +263,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } const int64 window_size = window_size_func_output[0].scalar()(); + if (window_size <= 0) { + return errors::InvalidArgument( + "Window size must be greater than zero, but got ", + window_size, "."); + } window_sizes_[key] = window_size; } -- 2.7.4