From e613e0844a95814457f3530eedb9baf812cf1e87 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 19 Mar 2018 14:27:06 -0700 Subject: [PATCH] Enable stack push removal optimization by default. PiperOrigin-RevId: 189641729 --- .../core/grappler/optimizers/loop_optimizer.cc | 36 +++++++++++++--------- .../grappler/optimizers/loop_optimizer_test.cc | 24 ++++++++++----- .../core/grappler/optimizers/meta_optimizer.cc | 4 +-- tensorflow/core/protobuf/rewriter_config.proto | 2 +- tensorflow/python/kernel_tests/BUILD | 2 ++ 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 2446535..f78036d 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -45,8 +45,9 @@ namespace tensorflow { namespace grappler { namespace { -std::vector GetStackPushNodesToConvert(const SimpleGraphView& graph_view, - int stack_node_idx) { +std::vector GetStackPushNodesToConvert( + const SimpleGraphView& graph_view, + const std::unordered_set& nodes_to_preserve, int stack_node_idx) { VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name(); const std::unordered_set op_types_to_traverse( {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch", @@ -64,7 +65,9 @@ std::vector GetStackPushNodesToConvert(const SimpleGraphView& graph_view, op_types_to_traverse.end()) { continue; } else if (!IsStackPopOp(fanout_node) || - !graph_view.outputs(fanout_idx).empty()) { + (!graph_view.outputs(fanout_idx).empty() || + nodes_to_preserve.find(fanout_node.name()) != + nodes_to_preserve.end())) { // The node is either a stack pop with consumers or something unexpected // so we leave the graph alone. nodes_to_convert.clear(); @@ -74,15 +77,17 @@ std::vector GetStackPushNodesToConvert(const SimpleGraphView& graph_view, return nodes_to_convert; } -Status RemoveStackOps(const GraphDef& graph, GraphDef* optimized_graph) { +Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) { + const std::unordered_set nodes_to_preserve = item.NodesToPreserve(); + const GraphDef& graph = item.graph; *optimized_graph = graph; NodeMap node_map(optimized_graph); SimpleGraphView graph_view; TF_RETURN_IF_ERROR(graph_view.Initialize(graph)); for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) { if (IsStackOp(graph.node(node_idx))) { - for (int push_node_idx : - GetStackPushNodesToConvert(graph_view, node_idx)) { + for (int push_node_idx : GetStackPushNodesToConvert( + graph_view, nodes_to_preserve, node_idx)) { // We found push nodes without corresponding pops. Convert them to // Identity passing the data through and add a control dependency from // the op supplying the stack handle. @@ -463,17 +468,18 @@ Status LoopOptimizer::LoopInvariantNodeMotion() { Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { - TF_RETURN_IF_ERROR(RemoveStackOps(item.graph, optimized_graph)); + TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph)); - optimized_graph_ = optimized_graph; - - // Set up helper data structures. - node_map_.reset(new NodeMap(optimized_graph_)); - int num_frames; - TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, - &frame_map_, &num_frames)); + if (opt_level_ == RewriterConfig::AGGRESSIVE) { + optimized_graph_ = optimized_graph; + // Set up helper data structures. + node_map_.reset(new NodeMap(optimized_graph_)); + int num_frames; + TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_, + &frame_map_, &num_frames)); + TF_RETURN_IF_ERROR(LoopInvariantNodeMotion()); + } - TF_RETURN_IF_ERROR(LoopInvariantNodeMotion()); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index 0d45ba9..a0bd335 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -81,7 +81,7 @@ TEST_F(LoopOptimizerTest, Basic) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer; + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -128,7 +128,7 @@ TEST_F(LoopOptimizerTest, Const) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer; + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -175,7 +175,7 @@ TEST_F(LoopOptimizerTest, ControlOutput) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer; + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -235,7 +235,7 @@ TEST_F(LoopOptimizerTest, NestedLoop1) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer; + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -302,7 +302,7 @@ TEST_F(LoopOptimizerTest, NestedLoop2) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer; + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -365,7 +365,7 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer; + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -429,7 +429,7 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) { GrapplerItem item; item.graph = graph; - LoopOptimizer optimizer; + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -502,6 +502,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) { AddSimpleNode("stack3", "StackV2", {}, &graph); AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph); AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph); + LoopOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); @@ -525,12 +526,19 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { AddSimpleNode("stack3", "StackV2", {}, &graph); AddSimpleNode("push3", "StackPushV2", {"stack3", "c"}, &graph); AddSimpleNode("pop3", "StackPopV2", {"stack3"}, &graph); + // Push for a Pop without consumer that is fetched should not be removed. + AddSimpleNode("stack4", "StackV2", {}, &graph); + AddSimpleNode("push4", "StackPushV2", {"stack4", "c"}, &graph); + AddSimpleNode("pop4", "StackPopV2", {"stack4"}, &graph); + + item.fetch.push_back("pop4"); LoopOptimizer optimizer; GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(10, output.node_size()); + + EXPECT_EQ(13, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); if (node.name() == "push1") { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 7b2e7a1..6eb2bbc 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -106,7 +106,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr( new ArithmeticOptimizer(cfg_.arithmetic_optimization()))); } - if (cfg_.loop_optimization() == RewriterConfig::ON) { + if (cfg_.loop_optimization() != RewriterConfig::OFF) { optimizers.push_back(std::unique_ptr( new LoopOptimizer(cfg_.loop_optimization()))); } @@ -234,7 +234,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) { cfg.function_optimization() != RewriterConfig::OFF || cfg.constant_folding() != RewriterConfig::OFF || cfg.arithmetic_optimization() != RewriterConfig::OFF || - cfg.loop_optimization() == RewriterConfig::ON || + cfg.loop_optimization() != RewriterConfig::OFF || cfg.dependency_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index b1fceaa..fdf16aa 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -42,7 +42,7 @@ message RewriterConfig { // Control dependency optimizations (default is ON). // Remove redundant control dependencies, which may enable other optimization. Toggle dependency_optimization = 8; - // Loop optimizations (default is OFF). + // Loop optimizations (default is ON). Toggle loop_optimization = 9; // Function optimizations (default is ON). Toggle function_optimization = 10; diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 5b0c38f..d9571fa 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -393,6 +393,7 @@ tf_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", ], + shard_count = 5, ) tf_py_test( @@ -408,6 +409,7 @@ tf_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", ], + shard_count = 5, ) tf_py_test( -- 2.7.4