Handle negative values when slicing symbolic shapes
authorBenoit Steiner <bsteiner@google.com>
Thu, 3 May 2018 00:00:16 +0000 (17:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 3 May 2018 00:04:03 +0000 (17:04 -0700)
PiperOrigin-RevId: 195176133

tensorflow/core/grappler/costs/graph_properties.cc

index 23d25cb..eaf7634 100644 (file)
@@ -804,11 +804,16 @@ class SymbolicShapeRefiner {
           int64 start = slice_offset->dtype() == DT_INT32
                             ? slice_offset->flat<int32>()(0)
                             : slice_offset->flat<int64>()(0);
-          int64 end = start + (slice_size->dtype() == DT_INT32
-                                   ? slice_size->flat<int32>()(0)
-                                   : slice_size->flat<int64>()(0));
+          int64 size =
+              (slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
+                                               : slice_size->flat<int64>()(0));
           ShapeHandle result;
-          TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
+          if (size == -1) {
+            TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
+          } else {
+            int64 end = start + size;
+            TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
+          }
           c->output_tensors_as_shapes.resize(1);
           c->output_tensors_as_shapes[0] = result;
         }