Shape validation with random/shuffle related Dataset ops (#18682)
authorYong Tang <yong.tang.github@outlook.com>
Thu, 19 Apr 2018 16:13:35 +0000 (09:13 -0700)
committerDerek Murray <derek.murray@gmail.com>
Thu, 19 Apr 2018 16:13:35 +0000 (09:13 -0700)
* Add shape check for CacheDataset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add shape check for ShuffleAndRepeatDataset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add check for ShuffleDataset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add shape check for RandomDataset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add RangeDataset shape check

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Sanitize with clang-format -i --style=Google

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/core/ops/dataset_ops.cc

index c63e485..dae0c0e 100644 (file)
@@ -357,7 +357,14 @@ REGISTER_OP("RangeDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                       // stateful to inhibit constant folding.
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // start, stop, and step should be scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("RandomDataset")
     .Input("seed: int64")
@@ -367,7 +374,13 @@ REGISTER_OP("RandomDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                       // stateful to inhibit constant folding.
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // buffer_size, seed, and seed2 should be scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("ShuffleDataset")
     .Input("input_dataset: variant")
@@ -378,7 +391,14 @@ REGISTER_OP("ShuffleDataset")
     .Attr("reshuffle_each_iteration: bool = true")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // buffer_size, seed, and seed2 should be scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("ShuffleAndRepeatDataset")
     .Input("input_dataset: variant")
@@ -389,7 +409,15 @@ REGISTER_OP("ShuffleAndRepeatDataset")
     .Output("handle: variant")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // buffer_size, seed, seed2, and count should be scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("CacheDataset")
     .Input("input_dataset: variant")
@@ -397,7 +425,12 @@ REGISTER_OP("CacheDataset")
     .Output("handle: variant")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // filename should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("TextLineDataset")
     .Input("filenames: string")