Improve error handling in strided_slice_op to fail more gracefully and return an...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Feb 2018 05:25:22 +0000 (21:25 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187126888

tensorflow/core/kernels/strided_slice_op.cc

index 7745eff..1e3e92a 100644 (file)
@@ -109,17 +109,27 @@ class StridedSliceOp : public OpKernel {
     if (is_identity) {
       VLOG(1) << "Strided slice identity ";
       Tensor tmp;
-      CHECK(tmp.CopyFrom(input, final_shape));
+      OP_REQUIRES(context, tmp.CopyFrom(input, final_shape),
+                  errors::Internal("Copy failed"));
       context->set_output(0, tmp);
       return;
     }
 
     // Optimization #2, slice is memory contiguous (only occurs in dim 0)
     if (slice_dim0 && IsDim0SliceAligned<T>(input.shape(), begin[0], end[0])) {
-      CHECK_GE(input.dims(), 1);  // Otherwise, is_identity should be true.
+      OP_REQUIRES(context, input.dims() >= 1,
+                  errors::InvalidArgument(
+                      "Input must have rank at least 1, got: ", input.dims()));
+      // Otherwise, is_identity should be true.
       VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString();
+      OP_REQUIRES(
+          context, begin[0] <= end[0],
+          errors::InvalidArgument("begin[0] (", begin[0],
+                                  ") must less or equal to end[0] (", end[0]));
+      Tensor slice = input.Slice(begin[0], end[0]);
       Tensor tmp;
-      CHECK(tmp.CopyFrom(input.Slice(begin[0], end[0]), final_shape));
+      OP_REQUIRES(context, tmp.CopyFrom(slice, final_shape),
+                  errors::Internal("Copy failed"));
       context->set_output(0, tmp);
       return;
     }
@@ -238,7 +248,8 @@ class StridedSliceGradOp : public OpKernel {
 
     if (processing_shape.dims() == 0) {
       auto in = context->input(4);
-      CHECK(result->CopyFrom(in, processing_shape));
+      OP_REQUIRES(context, result->CopyFrom(in, processing_shape),
+                  errors::Internal("Copy failed"));
       return;
     }