From 395428bcaf02c9a9e8067083993d7e6b5afdc0a6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 14:01:45 -0700 Subject: [PATCH] Move RemodeRedundantReshape optimization to a separate stage. PiperOrigin-RevId: 198775276 --- .../grappler/optimizers/arithmetic_optimizer.cc | 114 +++++++++++---------- .../grappler/optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 90 ++++++++-------- 3 files changed, 111 insertions(+), 94 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index e7f385c..0edea16 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -196,22 +196,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) { bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); } -// Returns whether `reshape` is an identity op. The tensor that `reshape` -// reshapes is the `output_pos`-th output of node `input`. -bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input, - const int output_pos, - const GraphProperties& graph_properties) { - const std::vector& reshape_props = - graph_properties.GetOutputProperties(reshape.name()); - const std::vector& input_props = - graph_properties.GetOutputProperties(input.name()); - if (reshape_props.empty() || input_props.size() <= output_pos) { - return false; - } - - return ShapesSymbolicallyEqual(input_props[output_pos], reshape_props[0]); -} - NodeDef* GetTailOfValuePreservingChain( const NodeDef& node, const NodeMap& node_map, const std::unordered_set& nodes_to_preserve) { @@ -1823,6 +1807,65 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { } }; +// Bypass redundant reshape nodes: +// +// Reshape Reshape <-+ +// ^ | +// | | +// Reshape becomes Reshape | +// ^ | +// | | +// input input ---+ +class RemoveRedundantReshape : public ArithmeticOptimizerStage { + public: + explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {} + ~RemoveRedundantReshape() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsReshape(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + + // 1. Bypass reshape followed by reshape. + if (IsReshape(*input) && !HasControlInputs(*input)) { + node->set_input(0, input->input(0)); + ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0)); + *simplified_node_name = node->name(); + AddToOptimizationQueue(node); + return Status::OK(); + } + + // 2. If the reshape is a no-op, forward its input to its consumers, unless + // it anchors a control dependency since we want to make sure that control + // dependency is triggered. + if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) { + *simplified_node_name = node->input(0); + return Status::OK(); + } + + return Status::OK(); + } + + private: + // Returns whether `reshape` is an identity op. + bool ReshapeIsIdentity(const NodeDef& reshape) { + OpInfo::TensorProperties reshape_props; + OpInfo::TensorProperties input_props; + + if (!GetTensorProperties(reshape.name(), &reshape_props).ok() || + !GetTensorProperties(reshape.input(0), &input_props).ok()) { + return false; + } + + return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape()); + } +}; + } // namespace class UniqueNodes { @@ -2076,43 +2119,6 @@ void ArithmeticOptimizer::ForwardControlDependencies( string ArithmeticOptimizer::TrySimplifyAndReplaceUses( const NodeDef* node, SetVector* nodes_to_simplify) { - if (node->op() == "Reshape") { - // Reshape - // ^ - // | - // Reshape - // ^ - // | - // input - // - // becomes - // - // Reshape <-+ - // | - // Reshape | - // ^ | - // | | - // input ---+ - NodeDef* reshape = const_cast(node); - int output_pos = 0; - string input_node_name = ParseNodeName(reshape->input(0), &output_pos); - const NodeDef* input = node_map_->GetNode(input_node_name); - if (input->op() == "Reshape" && !HasControlInputs(*input)) { - reshape->set_input(0, input->input(0)); - node_map_->UpdateInput(reshape->name(), input->name(), input->input(0)); - nodes_to_simplify->PushBack(reshape); - return reshape->name(); - } - - // If the reshape is a no-op, forward its input to its consumers, unless it - // anchors a control dependency since we want to make sure that control - // dependency is triggered. - if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) && - !HasControlInputs(*reshape)) { - return reshape->input(0); - } - } - if (node->op() == "Transpose") { // Reorder Cast and Transpose if beneficial. // @@ -2450,6 +2456,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.remove_redundant_cast) pipeline.AddStage(ctx, ctx_ext); + if (options_.remove_redundant_reshape) + pipeline.AddStage(ctx, ctx_ext); if (options_.remove_negation) pipeline.AddStage(ctx, ctx_ext); if (options_.remove_logical_not) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 9623991..9f8ec85 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -71,6 +71,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_negation = true; bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; + bool remove_redundant_reshape = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index f678ea7..43355ef 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -124,6 +124,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.remove_idempotent = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; + options.remove_redundant_reshape = false; options.remove_negation = false; options.remove_logical_not = false; optimizer->options_ = options; @@ -168,6 +169,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_redundant_cast = true; } + void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_reshape = true; + } + void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_negation = true; @@ -955,7 +961,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, IdentityReshape) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28})); @@ -977,11 +983,11 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); @@ -989,7 +995,8 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1})); @@ -1009,27 +1016,28 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshapeBetweenSymbolicShapes) { Output reshape = ops::Reshape(s, inputs, target_shape); Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); + auto x_t = GenerateRandomTensor(TensorShape({3, 3, 28, 28})); GrapplerItem item; item.fetch = {"outputs"}; + item.feed = {{"Placeholder", x_t}}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto x_t = GenerateRandomTensor(TensorShape({3, 3, 28, 28})); - auto tensors_expected = - EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}}); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + // Assume valid feed shape in aggressive mode. + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); - auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}}); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); @@ -1047,10 +1055,9 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { EXPECT_EQ(1, tensors_expected.size()); GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); // The reshape is preserved because the shape of the placeholder can be // different from the shape of the actual feed. @@ -1061,7 +1068,8 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28})); @@ -1077,12 +1085,11 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE) - .Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(0, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); @@ -1090,7 +1097,7 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) { // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can // be from [4,3,28,28] to [8,6,28,28]. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -1106,11 +1113,11 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { item.feed = {{"Placeholder", x_t}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); @@ -1118,7 +1125,8 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { +TEST_F(ArithmeticOptimizerTest, + RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output inputs = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3})); @@ -1128,16 +1136,16 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { GrapplerItem item; item.fetch = {"outputs"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); } -TEST_F(ArithmeticOptimizerTest, CombineReshapes) { +TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) { // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two // reshapes should be combined. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -1162,11 +1170,11 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) { item.feed = {{"nchw_vect_c", x_t}}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); EXPECT_EQ(1, tensors_expected.size()); - GraphDef output; - TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); - item.graph.Swap(&output); - TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveRedundantReshape(&optimizer); + OptimizeTwiceAndPrune(&optimizer, &item, &output); EXPECT_EQ(1, CountOpNodes(output, "Reshape")); auto tensors = EvaluateNodes(output, item.fetch, item.feed); -- 2.7.4