Turn on dead branch elimination, shape optimization, and remapping by default
authorBenoit Steiner <bsteiner@google.com>
Mon, 21 May 2018 19:43:52 +0000 (12:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 21 May 2018 19:46:26 +0000 (12:46 -0700)
PiperOrigin-RevId: 197439191

tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
tensorflow/core/grappler/optimizers/loop_optimizer.cc
tensorflow/core/grappler/optimizers/meta_optimizer.cc
tensorflow/core/protobuf/rewriter_config.proto

index 95093be..c21a1ea 100644 (file)
@@ -102,9 +102,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
         EXPECT_EQ(2, shape.dim(0).size());
         EXPECT_EQ(1, shape.dim(1).size());
         if (node->name() == y->name()) {
-          EXPECT_EQ(13, cm->AllocationId(node, 0));
+          EXPECT_EQ(21, cm->AllocationId(node, 0));
         } else {
-          EXPECT_EQ(14, cm->AllocationId(node, 0));
+          EXPECT_EQ(22, cm->AllocationId(node, 0));
         }
       }
       EXPECT_LE(0, cm->MaxExecutionTime(node));
index bfef9a6..9627ed7 100644 (file)
@@ -597,6 +597,10 @@ Status RemoveDeadBranches(const std::unordered_set<string>& nodes_to_preserve,
             }
           }
         }
+      } else if (dead.node->op() == "ControlTrigger") {
+        // Control trigger have different semantic, so don't touch them
+        found_node_to_preserve = true;
+        break;
       } else {
         if (local_dead_nodes.insert(dead.node).second) {
           for (const GraphView::InputPort& dead_fanout :
@@ -651,8 +655,7 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   if (options_.enable_stack_push_removal) {
     TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
   }
-  if (opt_level_ == RewriterConfig::AGGRESSIVE &&
-      options_.enable_dead_branch_removal) {
+  if (options_.enable_dead_branch_removal) {
     TF_RETURN_IF_ERROR(
         RemoveDeadBranches(item.NodesToPreserve(), optimized_graph));
   }
index 1ce2721..a927275 100644 (file)
@@ -110,10 +110,10 @@ Status MetaOptimizer::InitializeOptimizers(
     optimizers->emplace_back(
         new ConstantFolding(cfg_.constant_folding(), cpu_device_));
   }
-  if (cfg_.shape_optimization() == RewriterConfig::ON) {
+  if (cfg_.shape_optimization() != RewriterConfig::OFF) {
     optimizers->emplace_back(new ShapeOptimizer());
   }
-  if (cfg_.remapping() == RewriterConfig::ON) {
+  if (cfg_.remapping() != RewriterConfig::OFF) {
     optimizers->emplace_back(new Remapper(cfg_.remapping()));
   }
   if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
@@ -353,8 +353,8 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
          cfg.layout_optimizer() != RewriterConfig::OFF ||
          cfg.function_optimization() != RewriterConfig::OFF ||
          cfg.constant_folding() != RewriterConfig::OFF ||
-         cfg.shape_optimization() == RewriterConfig::ON ||
-         cfg.remapping() == RewriterConfig::ON ||
+         cfg.shape_optimization() != RewriterConfig::OFF ||
+         cfg.remapping() != RewriterConfig::OFF ||
          cfg.arithmetic_optimization() != RewriterConfig::OFF ||
          cfg.loop_optimization() != RewriterConfig::OFF ||
          cfg.dependency_optimization() != RewriterConfig::OFF ||
index ed2ba1f..10bfe30 100644 (file)
@@ -46,10 +46,10 @@ message RewriterConfig {
   // Statically infer the value of tensors when possible, and materialize the
   // result using constants.
   Toggle constant_folding = 3;
-  // Shape optimizations (default is OFF)
+  // Shape optimizations (default is ON)
   // Simplify computations made on shapes.
   Toggle shape_optimization = 13;
-  // Remapping (default is OFF)
+  // Remapping (default is ON)
   // Remap subgraphs onto more efficient implementations.
   Toggle remapping = 14;
   // Arithmetic optimizations (default is ON)