Extracts the following optimizations into methods:
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 14 May 2018 16:45:42 +0000 (09:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 14 May 2018 16:48:38 +0000 (09:48 -0700)
PartialConstPropThroughIdentityN
ConstantPushDown

PiperOrigin-RevId: 196520167

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index 171d492..b2dcbf9 100644 (file)
@@ -2157,6 +2157,30 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
     return Status::OK();
   }
 
+  if (ConstantPushDown(node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialConstPropThroughIdentityN(node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  return Status::OK();
+}
+
+bool ConstantFolding::ConstantPushDown(NodeDef* node) {
   // Consider the transformation
   //
   //                      +                +       = parent
@@ -2178,22 +2202,22 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
   // division/multiplication.
   // Don't touch BiasAdd since they can't handle vectors as their first
   // inputs.
-  if (has_fetch_ && (IsAdd(*node) || is_mul) &&
+  if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
       NumNonControlInputs(*node) == 2) {
     NodeDef* left_child = node_map_->GetNode(node->input(0));
     NodeDef* right_child = node_map_->GetNode(node->input(1));
     // One child must be constant, and the other the same op as the parent.
     if (node->op() != left_child->op() && node->op() != right_child->op()) {
-      return Status::OK();
+      return false;
     }
     const bool left_child_is_constant = IsReallyConstant(*left_child);
     const bool right_child_is_constant = IsReallyConstant(*right_child);
     if (!left_child_is_constant && !right_child_is_constant) {
-      return Status::OK();
+      return false;
     }
     if (node->device() != left_child->device() ||
         node->device() != right_child->device()) {
-      return Status::OK();
+      return false;
     }
     NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
     NodeDef* const_child_node =
@@ -2203,7 +2227,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
         nodes_to_preserve_.find(op_child_node->name()) !=
             nodes_to_preserve_.end() ||
         NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
-      return Status::OK();
+      return false;
     }
 
     // Identify the nodes to swap.
@@ -2213,7 +2237,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
     const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
     if (left_leaf_is_constant && right_leaf_is_constant) {
       // Child is already foldable, leave it alone.
-      return Status::OK();
+      return false;
     }
     const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
     const int parent_const_input = left_child_is_constant ? 0 : 1;
@@ -2238,10 +2262,12 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
                            node->input(parent_const_input));
     std::swap(*node->mutable_input(parent_const_input),
               *op_child_node->mutable_input(non_const_leaf_input));
-    graph_modified_ = true;
-    return Status::OK();
+    return true;
   }
+  return false;
+}
 
+bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
   // Partial constant propagation through IdentityN.
   if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
     const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
@@ -2294,22 +2320,10 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
       for (NodeDef* consumer : consumers) {
         DedupControlInputs(consumer);
       }
-      graph_modified_ = true;
-      return Status::OK();
+      return true;
     }
   }
-
-  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  return Status::OK();
+  return false;
 }
 
 bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
index f92f755..227caba 100644 (file)
@@ -113,6 +113,14 @@ class ConstantFolding : public GraphOptimizer {
   bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
                                   GraphProperties* properties, NodeDef* node);
 
+  // Applies partial constant propagation through IdentityN operator.
+  // Returns true if the transformation applied successfully.
+  bool PartialConstPropThroughIdentityN(NodeDef* node);
+
+  // Pushes down constants on '+' and '*' operators if applicable. Returns true
+  // the transformation applied successfully.
+  bool ConstantPushDown(NodeDef* node);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;