Validation in shape functions of Dataset ops (#18680)
authorYong Tang <yong.tang.github@outlook.com>
Thu, 19 Apr 2018 16:13:53 +0000 (09:13 -0700)
committerDerek Murray <derek.murray@gmail.com>
Thu, 19 Apr 2018 16:13:53 +0000 (09:13 -0700)
* Add shape check for PrependFromQueueAndPaddedBatchDataset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add comment for shape check

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add shape check for FixedLengthRecordDataset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add check for filenames as well

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Clang-format -i --style=google for file format

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add shape check for SqlDataset

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
tensorflow/core/ops/dataset_ops.cc

index dae0c0e..869bef8 100644 (file)
@@ -459,7 +459,14 @@ REGISTER_OP("SqlDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                       // stateful to inhibit constant folding.
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // driver_name, data_source_name, and query should be scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("FixedLengthRecordDataset")
     .Input("filenames: string")
@@ -470,7 +477,18 @@ REGISTER_OP("FixedLengthRecordDataset")
     .Output("handle: variant")
     .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                       // stateful to inhibit constant folding.
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // `filenames` must be a scalar or a vector.
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+      // header_bytes, record_bytes, footer_bytes, buffer_size should be
+      // scalars.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("TFRecordDataset")
     .Input("filenames: string")
@@ -609,7 +627,12 @@ REGISTER_OP("PrependFromQueueAndPaddedBatchDataset")
     // length of `output_types` is `N`, the `output_shapes` are
     // (as far as possible to tell statically) compatible with `padded_shapes`,
     // and that `padding_values` are all scalars.
-    .SetShapeFn(shape_inference::ScalarShape);
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // batch_size should be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("EnqueueInQueueDataset")
     .Input("queue: variant")