Improve ReshapeIsIdentity to work with symbolic shapes.
authorJingyue Wu <jingyue@google.com>
Thu, 31 May 2018 05:00:32 +0000 (22:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 05:03:11 +0000 (22:03 -0700)
For example, with this CL, ArithmeticOptimizer can optimize the Reshape below
into a no-op.

  s = Shape(t)
  Reshape(t, Concat(s[0], s[1], s[2], s[3]))

PiperOrigin-RevId: 198668726

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index 9c18c45..e7f385c 100644 (file)
@@ -209,40 +209,7 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
     return false;
   }
 
-  const PartialTensorShape& src_shape = input_props[output_pos].shape();
-  const PartialTensorShape& dst_shape = reshape_props[0].shape();
-
-  if (src_shape.unknown_rank() || dst_shape.unknown_rank()) {
-    return false;
-  }
-
-  if (!dst_shape.IsCompatibleWith(src_shape)) {
-    return false;
-  }
-
-  // Returns false when src_shape or dst_shape has >=2 dimensions with unknown
-  // sizes.
-  auto num_unknown_dim_sizes = [](const PartialTensorShape& partial_shape) {
-    auto dim_sizes = partial_shape.dim_sizes();
-    return std::count_if(dim_sizes.begin(), dim_sizes.end(),
-                         [](int dim) { return dim < 0; });
-  };
-  int src_num_unknown_dim_sizes = num_unknown_dim_sizes(src_shape);
-  int dst_num_unknown_dim_sizes = num_unknown_dim_sizes(dst_shape);
-  if (src_num_unknown_dim_sizes > 1 || dst_num_unknown_dim_sizes > 1) {
-    return false;
-  }
-
-  // If dst_num_unknown_dim_sizes != src_num_unknown_dim_sizes we would weaken
-  // shape inference in subsequent passes if we removed this reshape.
-  if (src_num_unknown_dim_sizes != dst_num_unknown_dim_sizes) {
-    return false;
-  }
-
-  // Remove the reshape if both are fully defined or partially defined and the
-  // unknown or symbolic shape appears on the same dimension, i.e., if
-  // IsIdenticalTo returns true.
-  return dst_shape.IsIdenticalTo(src_shape);
+  return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]);
 }
 
 NodeDef* GetTailOfValuePreservingChain(
index a908416..f678ea7 100644 (file)
@@ -989,6 +989,46 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
+TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output inputs =
+      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
+  Output inputs_shape = ops::Shape(s, inputs);
+  // The target shape of the reshape is the concatenation of `batch_size`, 3,
+  // `height, and `width`.
+  Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
+                                 ops::Const(s, {1}, {1}));
+  Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
+                             ops::Const(s, {1}, {1}));
+  Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
+                            ops::Const(s, {1}, {1}));
+  Output target_shape =
+      ops::Concat(s.WithOpName("target_shape"),
+                  {batch_size, ops::Const(s, {3}, {1}), height, width},
+                  ops::Const(s, {0}, {}));
+  Output reshape = ops::Reshape(s, inputs, target_shape);
+  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+  GrapplerItem item;
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
+  auto tensors_expected =
+      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
+  EXPECT_EQ(1, tensors_expected.size());
+  GraphDef output;
+  TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
+                   .Optimize(nullptr, item, &output));
+
+  item.graph.Swap(&output);
+  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+
+  EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
 TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output inputs =