Extracts the 'simplify squeeze node' optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 May 2018 22:53:44 +0000 (15:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 22:56:30 +0000 (15:56 -0700)
PiperOrigin-RevId: 197968452

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index b8b8088..3b56f10 100644 (file)
@@ -1937,22 +1937,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
     }
   }
 
-  if (use_shape_info && IsSqueeze(*node) &&
-      !properties->GetInputProperties(node->name()).empty()) {
-    // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
-    // error to squeeze a dimension that is not 1, so we only need to check
-    // whether the input has > 1 size for each dimension.
-    const auto& shape = properties->GetInputProperties(node->name())[0].shape();
-    // The node is replaceable iff
-    // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
-    bool replaceable = !shape.unknown_rank();
-    for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
-      replaceable &= shape.dim(j).size() > 1;
-    }
-    if (replaceable) {
-      ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
-      return Status::OK();
-    }
+  if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
+    return Status::OK();
   }
 
   if (SimplifyPack(optimized_graph, node)) {
@@ -2024,6 +2010,30 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
+                                      bool use_shape_info,
+                                      GraphDef* optimized_graph,
+                                      NodeDef* node) {
+  if (use_shape_info && IsSqueeze(*node) &&
+      !properties.GetInputProperties(node->name()).empty()) {
+    // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
+    // error to squeeze a dimension that is not 1, so we only need to check
+    // whether the input has > 1 size for each dimension.
+    const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+    // The node is replaceable iff
+    // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
+    bool replaceable = !shape.unknown_rank();
+    for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+      replaceable &= shape.dim(j).size() > 1;
+    }
+    if (replaceable) {
+      ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+      return true;
+    }
+  }
+  return false;
+}
+
 bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
   if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
       !OptimizedNodeExists(*node, "_const_axis")) {
index be78004..55ad686 100644 (file)
@@ -170,6 +170,10 @@ class ConstantFolding : public GraphOptimizer {
   // Simplifies Pack operation if applicable.
   bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
 
+  // Simplifies a Squeeze operation to an Identity operation if applicable.
+  bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
+                       GraphDef* optimized_graph, NodeDef* node);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;