Move RemoveRedundantReshape optimization to a separate stage.
author A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 31 May 2018 21:01:45 +0000 (14:01 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 21:03:45 +0000 (14:03 -0700)
PiperOrigin-RevId: 198775276

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index e7f385c..0edea16 100644 (file)
@@ -196,22 +196,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
 
 bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
 
-// Returns whether `reshape` is an identity op. The tensor that `reshape`
-// reshapes is the `output_pos`-th output of node `input`.
-bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
-                       const int output_pos,
-                       const GraphProperties& graph_properties) {
-  const std::vector<OpInfo::TensorProperties>& reshape_props =
-      graph_properties.GetOutputProperties(reshape.name());
-  const std::vector<OpInfo::TensorProperties>& input_props =
-      graph_properties.GetOutputProperties(input.name());
-  if (reshape_props.empty() || input_props.size() <= output_pos) {
-    return false;
-  }
-
-  return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]);
-}
-
 NodeDef* GetTailOfValuePreservingChain(
     const NodeDef& node, const NodeMap& node_map,
     const std::unordered_set<string>& nodes_to_preserve) {
@@ -1823,6 +1807,65 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
   }
 };
 
+// Bypass redundant reshape nodes:
+//
+//   Reshape                    Reshape  <-+
+//      ^                                  |
+//      |                                  |
+//   Reshape       becomes      Reshape    |
+//      ^                                  |
+//      |                                  |
+//    input                      input  ---+
+class RemoveRedundantReshape : public ArithmeticOptimizerStage {
+ public:
+  explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx,
+                                  const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {}
+  ~RemoveRedundantReshape() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsReshape(*node);
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    NodeDef* input;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+
+    // 1. Bypass reshape followed by reshape.
+    if (IsReshape(*input) && !HasControlInputs(*input)) {
+      node->set_input(0, input->input(0));
+      ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
+      *simplified_node_name = node->name();
+      AddToOptimizationQueue(node);
+      return Status::OK();
+    }
+
+    // 2. If the reshape is a no-op, forward its input to its consumers, unless
+    // it anchors a control dependency since we want to make sure that control
+    // dependency is triggered.
+    if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) {
+      *simplified_node_name = node->input(0);
+      return Status::OK();
+    }
+
+    return Status::OK();
+  }
+
+ private:
+  // Returns whether `reshape` is an identity op.
+  bool ReshapeIsIdentity(const NodeDef& reshape) {
+    OpInfo::TensorProperties reshape_props;
+    OpInfo::TensorProperties input_props;
+
+    if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
+        !GetTensorProperties(reshape.input(0), &input_props).ok()) {
+      return false;
+    }
+
+    return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape());
+  }
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -2076,43 +2119,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
 string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
 
-  if (node->op() == "Reshape") {
-    //   Reshape
-    //      ^
-    //      |
-    //   Reshape
-    //      ^
-    //      |
-    //    input
-    //
-    // becomes
-    //
-    //   Reshape <-+
-    //             |
-    //   Reshape   |
-    //      ^      |
-    //      |      |
-    //    input ---+
-    NodeDef* reshape = const_cast<NodeDef*>(node);
-    int output_pos = 0;
-    string input_node_name = ParseNodeName(reshape->input(0), &output_pos);
-    const NodeDef* input = node_map_->GetNode(input_node_name);
-    if (input->op() == "Reshape" && !HasControlInputs(*input)) {
-      reshape->set_input(0, input->input(0));
-      node_map_->UpdateInput(reshape->name(), input->name(), input->input(0));
-      nodes_to_simplify->PushBack(reshape);
-      return reshape->name();
-    }
-
-    // If the reshape is a no-op, forward its input to its consumers, unless it
-    // anchors a control dependency since we want to make sure that control
-    // dependency is triggered.
-    if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) &&
-        !HasControlInputs(*reshape)) {
-      return reshape->input(0);
-    }
-  }
-
   if (node->op() == "Transpose") {
     // Reorder Cast and Transpose if beneficial.
     //
@@ -2450,6 +2456,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
   if (options_.remove_redundant_cast)
     pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+  if (options_.remove_redundant_reshape)
+    pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
   if (options_.remove_negation)
     pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
   if (options_.remove_logical_not)
index 9623991..9f8ec85 100644 (file)
@@ -71,6 +71,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool remove_negation = true;
     bool remove_redundant_bitcast = true;
     bool remove_redundant_cast = true;
+    bool remove_redundant_reshape = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
index f678ea7..43355ef 100644 (file)
@@ -124,6 +124,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     options.remove_idempotent = false;
     options.remove_redundant_bitcast = false;
     options.remove_redundant_cast = false;
+    options.remove_redundant_reshape = false;
     options.remove_negation = false;
     options.remove_logical_not = false;
     optimizer->options_ = options;
@@ -168,6 +169,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     optimizer->options_.remove_redundant_cast = true;
   }
 
+  void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.remove_redundant_reshape = true;
+  }
+
   void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.remove_negation = true;
@@ -955,7 +961,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
   test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output inputs =
       ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
@@ -977,11 +983,11 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
   auto tensors_expected =
       EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
   EXPECT_EQ(1, tensors_expected.size());
-  GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyRemoveRedundantReshape(&optimizer);
+  OptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
   auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
@@ -989,7 +995,8 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) {
+TEST_F(ArithmeticOptimizerTest,
+       RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output inputs =
       ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
@@ -1009,27 +1016,28 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) {
   Output reshape = ops::Reshape(s, inputs, target_shape);
   Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
 
+  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
   GrapplerItem item;
   item.fetch = {"outputs"};
+  item.feed = {{"Placeholder", x_t}};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
-  auto tensors_expected =
-      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
   EXPECT_EQ(1, tensors_expected.size());
-  GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
-                   .Optimize(nullptr, item, &output));
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  GraphDef output;
+  // Assume valid feed shape in aggressive mode.
+  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  EnableOnlyRemoveRedundantReshape(&optimizer);
+  OptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
-  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
+  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
   EXPECT_EQ(1, tensors.size());
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output inputs =
       ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
@@ -1047,10 +1055,9 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
   EXPECT_EQ(1, tensors_expected.size());
 
   GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  ArithmeticOptimizer optimizer;
+  EnableOnlyRemoveRedundantReshape(&optimizer);
+  OptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   // The reshape is preserved because the shape of the placeholder can be
   // different from the shape of the actual feed.
@@ -1061,7 +1068,8 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
+TEST_F(ArithmeticOptimizerTest,
+       RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output inputs =
       ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
@@ -1077,12 +1085,11 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
 
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
   EXPECT_EQ(1, tensors_expected.size());
-  GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
-                   .Optimize(nullptr, item, &output));
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  GraphDef output;
+  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  EnableOnlyRemoveRedundantReshape(&optimizer);
+  OptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
   auto tensors = EvaluateNodes(output, item.fetch, item.feed);
@@ -1090,7 +1097,7 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) {
   // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
   // be from [4,3,28,28] to [8,6,28,28].
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -1106,11 +1113,11 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
   item.feed = {{"Placeholder", x_t}};
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
   EXPECT_EQ(1, tensors_expected.size());
-  GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyRemoveRedundantReshape(&optimizer);
+  OptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
   auto tensors = EvaluateNodes(output, item.fetch, item.feed);
@@ -1118,7 +1125,8 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
 }
 
-TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
+TEST_F(ArithmeticOptimizerTest,
+       RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output inputs =
       ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
@@ -1128,16 +1136,16 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
   GrapplerItem item;
   item.fetch = {"outputs"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
-  GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyRemoveRedundantReshape(&optimizer);
+  OptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
 }
 
-TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) {
   // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two
   // reshapes should be combined.
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -1162,11 +1170,11 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
   item.feed = {{"nchw_vect_c", x_t}};
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
   EXPECT_EQ(1, tensors_expected.size());
-  GraphDef output;
-  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
 
-  item.graph.Swap(&output);
-  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyRemoveRedundantReshape(&optimizer);
+  OptimizeTwiceAndPrune(&optimizer, &item, &output);
 
   EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
   auto tensors = EvaluateNodes(output, item.fetch, item.feed);