Extracts the SimplifyReduction optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 May 2018 01:13:23 +0000 (18:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 01:16:07 +0000 (18:16 -0700)
PiperOrigin-RevId: 197823183

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index 4ebe1ca..bf606fb 100644 (file)
@@ -2133,20 +2133,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       return Status::OK();
     }
   }
-  if (IsSimplifiableReduction(*node, *properties)) {
-    // Replace the reduction node with an identity node, that can be further
-    // optimized by the model pruner.
-    DataType output_type;
-    if (node->attr().count("T") > 0) {
-      output_type = node->attr().at("T").type();
-    } else {
-      // This is an 'any' or 'all' reduction. The output is always boolean.
-      output_type = DT_BOOL;
-    }
-    node->set_op("Identity");
-    node->clear_attr();
-    (*node->mutable_attr())["T"].set_type(output_type);
-    *node->mutable_input(1) = AsControlDependency(node->input(1));
+
+  if (SimplifyReduction(*properties, node)) {
     graph_modified_ = true;
     return Status::OK();
   }
@@ -2200,6 +2188,27 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+bool ConstantFolding::SimplifyReduction(const GraphProperties& properties,
+                                        NodeDef* node) {
+  if (IsSimplifiableReduction(*node, properties)) {
+    // Replace the reduction node with an identity node, that can be further
+    // optimized by the model pruner.
+    DataType output_type;
+    if (node->attr().count("T") > 0) {
+      output_type = node->attr().at("T").type();
+    } else {
+      // This is an 'any' or 'all' reduction. The output is always boolean.
+      output_type = DT_BOOL;
+    }
+    node->set_op("Identity");
+    node->clear_attr();
+    (*node->mutable_attr())["T"].set_type(output_type);
+    *node->mutable_input(1) = AsControlDependency(node->input(1));
+    return true;
+  }
+  return false;
+}
+
 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
                                       bool use_shape_info, NodeDef* node) {
   if (!use_shape_info) return false;
index 3cf379f..07a2e01 100644 (file)
@@ -139,11 +139,13 @@ class ConstantFolding : public GraphOptimizer {
                                       GraphDef* optimized_graph, NodeDef* node,
                                       bool* success);
 
-  // Simplifies a Reshape operation to an Identity operation if the input node
-  // to the operation is a constant.
+  // Simplifies a Reshape operation to an Identity operation if applicable.
   bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
                        NodeDef* node);
 
+  // Simplifies a Reduction operation to an Identity operation if applicable.
+  bool SimplifyReduction(const GraphProperties& properties, NodeDef* node);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;