Extracts the 'switch with same input' optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 May 2018 16:05:41 +0000 (09:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 16:08:12 +0000 (09:08 -0700)
PiperOrigin-RevId: 197900929

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index bf606fb..a71f83b 100644 (file)
@@ -2035,22 +2035,66 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
     }
   }
 
-  // Switch(x, x) will always feed false to its false branch and true to
-  // its true branch. By rewriting the graph a bit, we can propagate these
-  // constants down the two output branches, and just use control dependencies
-  // to trigger the selected one at runtime. For example,
-  //
-  //     +------+
-  // x-->|Switch|-->a  (in practice there may be multiple consumers of each
-  // x-->|      |-->b   output branch.)
-  //     +------+
-  //
-  // Is rewritten as
-  //
-  //     +------+
-  // x-->|Switch|-->Identity--^>Const(false)-->a
-  // x-->|      |-->Identity--^>Const(true)-->b
-  //     +------+
+  if (SimplifySwitch(optimized_graph, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (SimplifyReduction(*properties, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (SimplifyReshape(*properties, use_shape_info, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  bool arithmetic_simplification_succeed = false;
+  Status simplify_arithmetic_status =
+      SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
+                                   node, &arithmetic_simplification_succeed);
+  if (!simplify_arithmetic_status.ok()) {
+    return simplify_arithmetic_status;
+  } else if (arithmetic_simplification_succeed) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (ReduceDivToReciprocalMul(optimized_graph, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (ConstantPushDown(node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (MulConvPushDown(node, *properties)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialConstPropThroughIdentityN(node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  return Status::OK();
+}
+
+bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
   if (node->op() == "Switch" && node->input(0) == node->input(1) &&
       !OptimizedNodeExists(*node, "_const_false") &&
       !OptimizedNodeExists(*node, "_const_true")) {
@@ -2087,7 +2131,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       false_node->set_name(OptimizedNodeName(*node, "_const_false"));
       if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
                .ok()) {
-        return Status::OK();
+        return false;
       }
       false_node->set_device(node->device());
 
@@ -2095,7 +2139,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       true_node->set_name(OptimizedNodeName(*node, "_const_true"));
       if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
                .ok()) {
-        return Status::OK();
+        return false;
       }
       true_node->set_device(node->device());
 
@@ -2129,63 +2173,10 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
           }
         }
       }
-      graph_modified_ = true;
-      return Status::OK();
+      return true;
     }
   }
-
-  if (SimplifyReduction(*properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (SimplifyReshape(*properties, use_shape_info, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  bool arithmetic_simplification_succeed = false;
-  Status simplify_arithmetic_status =
-      SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
-                                   node, &arithmetic_simplification_succeed);
-  if (!simplify_arithmetic_status.ok()) {
-    return simplify_arithmetic_status;
-  } else if (arithmetic_simplification_succeed) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (ReduceDivToReciprocalMul(optimized_graph, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (ConstantPushDown(node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (MulConvPushDown(node, *properties)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialConstPropThroughIdentityN(node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  return Status::OK();
+  return false;
 }
 
 bool ConstantFolding::SimplifyReduction(const GraphProperties& properties,
index 07a2e01..88f03b3 100644 (file)
@@ -146,6 +146,24 @@ class ConstantFolding : public GraphOptimizer {
   // Simplifies a Reduction operation to an Identity operation if applicable.
   bool SimplifyReduction(const GraphProperties& properties, NodeDef* node);
 
+  // Switch(x, x) will always feed false to its false branch and true to
+  // its true branch. By rewriting the graph a bit, we can propagate these
+  // constants down the two output branches, and just use control dependencies
+  // to trigger the selected one at runtime. For example,
+  //
+  //     +------+
+  // x-->|Switch|-->a  (in practice there may be multiple consumers of each
+  // x-->|      |-->b   output branch.)
+  //     +------+
+  //
+  // Is rewritten as
+  //
+  //     +------+
+  // x-->|Switch|-->Identity--^>Const(false)-->a
+  // x-->|      |-->Identity--^>Const(true)-->b
+  //     +------+
+  bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;