Enable aggressive identity node pruning in dependency optimizer.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 5 Feb 2018 17:33:42 +0000 (09:33 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 17:36:53 +0000 (09:36 -0800)
PiperOrigin-RevId: 184539756

tensorflow/core/grappler/optimizers/dependency_optimizer.cc
tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
tensorflow/python/grappler/cluster_test.py

index 7b4ca14..db64e53 100644 (file)
@@ -156,7 +156,6 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
 void DependencyOptimizer::OptimizeNode(int node_idx,
                                        SetVector<int>* nodes_to_simplify,
                                        std::set<int>* nodes_to_delete) {
-  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   NodeDef* node = optimized_graph_->mutable_node(node_idx);
   const bool is_noop = IsNoOp(*node);
   const bool is_identity = IsIdentity(*node);
@@ -280,7 +279,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
   //    y --^> |          | --^> b       /\    +---+
   //           +----------+             y --^> b
 
-  if (is_noop || (is_identity && is_aggressive)) {
+  if (is_noop || is_identity) {
     const auto& output_node_set = node_map_->GetOutputs(node_name);
     const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
                                              output_node_set.end());
index b8facb9..33d6b99 100644 (file)
@@ -176,7 +176,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
   item.fetch.push_back("neg1");
   item.fetch.push_back("neg2");
 
-  DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  DependencyOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -360,7 +360,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) {
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   item.fetch = {"a_a", "a_b", "a_c", "a_d", "b_a", "c_a", "c_b"};
 
-  DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  DependencyOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -420,7 +420,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
   item.fetch.push_back("or0");
   item.fetch.push_back("or1");
   item.fetch.push_back("or2");
-  DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  DependencyOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -459,7 +459,7 @@ TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
   item.fetch.push_back("neg2");
-  DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  DependencyOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -495,7 +495,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
   item.fetch.push_back("id2");
   item.fetch.push_back("fetch");
 
-  DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  DependencyOptimizer optimizer;
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
index 2292b2c..10d515a 100644 (file)
@@ -45,7 +45,7 @@ class ClusterTest(test.TestCase):
       op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts(
           grappler_item)
       self.assertTrue(run_time > 0)
-      self.assertEqual(len(op_perfs), 9)
+      self.assertEqual(len(op_perfs), 7)
       self.assertTrue(step_stats.dev_stats)
 
   def testNoDetailedStats(self):
@@ -125,7 +125,7 @@ class ClusterTest(test.TestCase):
         disable_detailed_stats=False, disable_timeline=False) as gcluster:
       op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item)
       self.assertTrue(run_time > 0)
-      self.assertEqual(len(op_perfs), 9)
+      self.assertEqual(len(op_perfs), 7)
       self.assertTrue(step_stats.dev_stats)
 
   def testAvailableOps(self):