Make tf.clip_by_value not crash on empty tensors
authorGeoffrey Irving <irving@naml.us>
Wed, 16 May 2018 22:47:58 +0000 (15:47 -0700)
committerGeoffrey Irving <irving@naml.us>
Wed, 16 May 2018 22:47:58 +0000 (15:47 -0700)
Also rearrange the code to remove duplication.  No tests yet; I'll leave
refactoring the test cases for empty tensor coverage to someone else.

Fixes #19337.

tensorflow/core/kernels/cwise_op_clip.cc

index 14d889e..49b90e8 100644 (file)
@@ -33,52 +33,41 @@ class ClipOp : public OpKernel {
     const Tensor& in0 = ctx->input(0);
     const Tensor& in1 = ctx->input(1);
     const Tensor& in2 = ctx->input(2);
+    OP_REQUIRES(ctx, (in0.shape() == in1.shape() ||
+                      TensorShapeUtils::IsScalar(in1.shape())) &&
+                     (in0.shape() == in2.shape() ||
+                      TensorShapeUtils::IsScalar(in2.shape())),
+                errors::InvalidArgument(
+                    "clip_value_min and clip_value_max must be either of "
+                    "the same shape as input, or a scalar. ",
+                    "input shape: ", in0.shape().DebugString(),
+                    "clip_value_min shape: ", in1.shape().DebugString(),
+                    "clip_value_max shape: ", in2.shape().DebugString()));
+
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
+    if (out->NumElements() == 0) return;  // Nothing to do for empty output
 
     auto in0_flat = in0.flat<T>();
     auto in1_flat = in1.flat<T>();
     auto in2_flat = in2.flat<T>();
+    auto out_flat = out->flat<T>();
     const Device& d = ctx->eigen_device<Device>();
 
-    Tensor* out = nullptr;
-    OP_REQUIRES_OK(
-        ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
-    auto out_flat = out->flat<T>();
     if (in1.shape() == in2.shape()) {
       if (in0.shape() == in1.shape()) {
         functor::TernaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                             out_flat);
       } else {
-        OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in1.shape()),
-                    errors::InvalidArgument(
-                        "clip_value_min and clip_value_max must be either of "
-                        "the same shape as input, or a scalar. ",
-                        "input shape: ", in0.shape().DebugString(),
-                        "clip_value_min shape: ", in1.shape().DebugString(),
-                        "clip_value_max shape: ", in2.shape().DebugString()));
         functor::UnaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                           out_flat);
       }
     } else {
       if (in0.shape() == in1.shape()) {
-        OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(in2.shape()),
-                    errors::InvalidArgument(
-                        "clip_value_min and clip_value_max must be either of "
-                        "the same shape as input, or a scalar. ",
-                        "input shape: ", in0.shape().DebugString(),
-                        "clip_value_min shape: ", in1.shape().DebugString(),
-                        "clip_value_max shape: ", in2.shape().DebugString()));
         functor::BinaryLeftClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                                out_flat);
       } else {
-        OP_REQUIRES(ctx,
-                    (in0.shape() == in2.shape() &&
-                     TensorShapeUtils::IsScalar(in1.shape())),
-                    errors::InvalidArgument(
-                        "clip_value_min and clip_value_max must be either of "
-                        "the same shape as input, or a scalar. ",
-                        "input shape: ", in0.shape().DebugString(),
-                        "clip_value_min shape: ", in1.shape().DebugString(),
-                        "clip_value_max shape: ", in2.shape().DebugString()));
         functor::BinaryRightClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
                                                 out_flat);
       }