[tf.data] Raise error when window size is 0 in `tf.contrib.data.group_by_window()`.
authorDerek Murray <mrry@google.com>
Tue, 27 Mar 2018 21:23:28 +0000 (14:23 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 21:25:41 +0000 (14:25 -0700)
PiperOrigin-RevId: 190673466

tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
tensorflow/core/kernels/data/group_by_window_dataset_op.cc

index d013189..6002cc7 100644 (file)
@@ -104,6 +104,21 @@ class GroupByWindowTest(test.TestCase):
       self.assertAllEqual([0, 0, 0], sess.run(get_next))
       self.assertAllEqual([1], sess.run(get_next))
 
+  def testEmpty(self):
+    iterator = (
+        dataset_ops.Dataset.range(4).apply(
+            grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
+        .make_initializable_iterator())
+    init_op = iterator.initializer
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      sess.run(init_op)
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          "Window size must be greater than zero, but got 0."):
+        print(sess.run(get_next))
+
   def testReduceFuncError(self):
     components = np.random.randint(100, size=(200,)).astype(np.int64)
 
index 834c06b..46f43dd 100644 (file)
@@ -263,6 +263,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
                 }
                 const int64 window_size =
                     window_size_func_output[0].scalar<int64>()();
+                if (window_size <= 0) {
+                  return errors::InvalidArgument(
+                      "Window size must be greater than zero, but got ",
+                      window_size, ".");
+                }
                 window_sizes_[key] = window_size;
               }