Add shape check to TextLineDataset op
authorimsheridan <xiaoyudong0512@gmail.com>
Wed, 18 Apr 2018 12:12:14 +0000 (20:12 +0800)
committerimsheridan <xiaoyudong0512@gmail.com>
Wed, 18 Apr 2018 12:12:14 +0000 (20:12 +0800)
tensorflow/core/ops/dataset_ops.cc

index 7f4d63b..f3b51d0 100644 (file)
@@ -383,10 +383,12 @@ REGISTER_OP("TextLineDataset")
     .Output("handle: variant")
     .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
                       // stateful to inhibit constant folding.
-    .SetShapeFn(shape_inference::ScalarShape);  // TODO(mrry): validate
-                                                // that `filenames` is
-                                                // a scalar or a
-                                                // vector.
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      // `filenames` must be a scalar or a vector.
+      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+      return shape_inference::ScalarShape(c);
+    });
 
 REGISTER_OP("SqlDataset")
     .Input("driver_name: string")