Improved shape inference for reshape
authorBenoit Steiner <bsteiner@google.com>
Tue, 27 Mar 2018 19:09:59 +0000 (12:09 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 19:12:24 +0000 (12:12 -0700)
PiperOrigin-RevId: 190651873

tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
tensorflow/contrib/signal/python/ops/shape_ops.py
tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/array_ops_test.cc

index 1c05235..bc4663f 100644 (file)
@@ -338,7 +338,10 @@ class FrameTest(test.TestCase):
 
   def test_constant_folding(self):
     """frame should be constant foldable for constant inputs."""
-    for pad_end in [False, True]:
+    # Padding is incorrectly defined in shape_ops.py (the rank of the padding
+    # tensor should be equal to the rank of the input tensor + 1): only test
+    # with padding set to False to avoid this.
+    for pad_end in [False]:
       g = ops.Graph()
       with g.as_default():
         frame_length, frame_step = 32, 16
index 1ddc294..97fe208 100644 (file)
@@ -139,6 +139,8 @@ def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
            [[0, pad_samples]],
            array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)],
           0)
+      # TODO(rjryan): the paddings tensor must of rank tf.rank(signal) + 1. This
+      # isn't the case here and should be fixed.
       signal = array_ops.pad(signal, paddings, constant_values=pad_value)
 
       signal_shape = array_ops.shape(signal)
index 39b9246..88d2aa3 100644 (file)
@@ -178,46 +178,88 @@ Status SetOutputShapeForReshape(InferenceContext* c) {
     c->set_output(0, out);
     return Status::OK();
   }
-  DimensionHandle num_in_elems = c->NumElements(in);
-  if (c->FullyDefined(out)) {
-    DimensionHandle num_out_elems = c->NumElements(out);
-    if (c->ValueKnown(num_in_elems) &&
-        c->Value(num_in_elems) != c->Value(num_out_elems)) {
-      return errors::InvalidArgument(
-          "Cannot reshape a tensor with ", c->DebugString(num_in_elems),
-          " elements to shape ", c->DebugString(out), " (",
-          c->DebugString(num_out_elems), " elements)");
-    }
-    c->set_output(0, out);
-    return Status::OK();
-  }
 
-  if (c->ValueKnown(num_in_elems)) {
+  if (c->RankKnown(out) && c->RankKnown(in)) {
     // We don't know the number of output elements, but we can try to infer
     // the missing dimension.
-    int32 unknown_idx = -1;
     bool too_many_unknown = false;
-    DimensionHandle known_elems = c->MakeDim(1);
-    for (int32 i = 0; i < c->Rank(out); ++i) {
-      DimensionHandle dim = c->Dim(out, i);
-      if (!c->ValueKnown(dim)) {
-        if (unknown_idx >= 0) {
-          too_many_unknown = true;
-          break;
+    int32 out_unknown_idx = -1;
+
+    DimensionHandle known_out_elems = c->NumElements(out);
+    if (!c->ValueKnown(known_out_elems)) {
+      known_out_elems = c->MakeDim(1);
+      for (int32 i = 0; i < c->Rank(out); ++i) {
+        DimensionHandle dim = c->Dim(out, i);
+        if (!c->ValueKnown(dim)) {
+          if (out_unknown_idx >= 0) {
+            too_many_unknown = true;
+            break;
+          }
+          out_unknown_idx = i;
+        } else {
+          TF_RETURN_IF_ERROR(
+              c->Multiply(known_out_elems, dim, &known_out_elems));
         }
-        unknown_idx = i;
-      } else {
-        TF_RETURN_IF_ERROR(c->Multiply(known_elems, dim, &known_elems));
       }
     }
-    if (!too_many_unknown && c->Value(known_elems) != 0) {
-      DimensionHandle inferred_dim;
-      TF_RETURN_IF_ERROR(c->Divide(num_in_elems, c->Value(known_elems),
-                                   true /* evenly_divisible */, &inferred_dim));
-      TF_RETURN_IF_ERROR(c->ReplaceDim(out, unknown_idx, inferred_dim, &out));
+    int32 in_unknown_idx = -1;
+    DimensionHandle known_in_elems = c->NumElements(in);
+    if (!c->ValueKnown(known_in_elems)) {
+      known_in_elems = c->MakeDim(1);
+      for (int32 i = 0; i < c->Rank(in); ++i) {
+        DimensionHandle dim = c->Dim(in, i);
+        if (!c->ValueKnown(dim)) {
+          if (in_unknown_idx >= 0) {
+            too_many_unknown = true;
+            break;
+          }
+          in_unknown_idx = i;
+        } else {
+          TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
+        }
+      }
     }
-  }
 
+    if (!too_many_unknown) {
+      if (in_unknown_idx < 0 && out_unknown_idx < 0) {
+        // Just check that the dimensions match.
+        if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
+          return errors::InvalidArgument(
+              "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
+              " elements to shape ", c->DebugString(out), " (",
+              c->DebugString(known_out_elems), " elements)");
+        }
+      } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
+                 c->Value(known_out_elems) > 0) {
+        // Input fully known, infer the one missing output dim
+        DimensionHandle inferred_dim;
+        TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
+                                     true /* evenly_divisible */,
+                                     &inferred_dim));
+        TF_RETURN_IF_ERROR(
+            c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
+
+      } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
+                 c->Value(known_in_elems) != 0) {
+        // Output fully known, infer the one missing input dim
+        DimensionHandle inferred_dim;
+        TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
+                                     true /* evenly_divisible */,
+                                     &inferred_dim));
+        DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
+        TF_RETURN_IF_ERROR(
+            c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
+      } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
+        // Exactly one unknown dimension in both input and output. These 2 are
+        // equal iff the known elements are equal.
+        if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
+          DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
+          TF_RETURN_IF_ERROR(
+              c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
+        }
+      }
+    }
+  }
   c->set_output(0, out);
   return Status::OK();
 }
index cf5bb5a..b146333 100644 (file)
@@ -838,7 +838,7 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) {
   // Unknown dimensions.
   // Flatten:
   new_shape = test::AsTensor<int32>({-1});
-  INFER_OK(op, "[?];[1]", "[?]");
+  INFER_OK(op, "[?];[1]", "[d0_0]");
   INFER_OK(op, "[2,2];[1]", "[4]");
   // The first dimension is inferred:
   new_shape = test::AsTensor<int32>({2, -1});
@@ -851,6 +851,10 @@ TEST(ArrayOpsTest, Reshape_ShapeFn) {
   new_shape = test::AsTensor<int32>({-1, -1, 2});
   INFER_OK(op, "[8];[3]", "[?,?,2]");
 
+  // Symbolic shape propagation
+  new_shape = test::AsTensor<int32>({-1, 2, 3});
+  INFER_OK(op, "[?,2,3];[3]", "[d0_0,2,3]");
+
   // Reshaping to a scalar.
   new_shape = test::AsTensor<int32>({});
   INFER_OK(op, "[1];[0]", "[]");