Extracts the 'simplify pad node' optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 May 2018 01:55:30 +0000 (18:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 25 May 2018 01:58:27 +0000 (18:58 -0700)
PiperOrigin-RevId: 197989813

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index 3b56f10..8cd1968 100644 (file)
@@ -1913,28 +1913,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
     }
   }
 
-  if (use_shape_info && IsPad(*node) &&
-      properties->GetInputProperties(node->name()).size() >= 2) {
-    const auto& p = properties->GetInputProperties(node->name())[1];
-    if (TensorShape::IsValid(p.shape()) && p.has_value()) {
-      Tensor paddings(p.dtype(), p.shape());
-      if (!paddings.FromProto(p.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       p.value().DebugString());
-      }
-      // The node is replaceable iff all values in paddings are 0.
-      bool replaceable = true;
-      // The operation requires it to be int32 value so we don't check for
-      // 1nt64.
-      const auto flatten = paddings.flat<int32>();
-      for (int j = 0; replaceable && j < flatten.size(); ++j) {
-        replaceable &= flatten(j) == 0;
-      }
-      if (replaceable) {
-        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
-        return Status::OK();
-      }
-    }
+  bool simplify_pad_successful = false;
+  Status simplify_pad_status =
+      SimplifyPad(*properties, use_shape_info, optimized_graph, node,
+                  &simplify_pad_successful);
+  if (!simplify_pad_status.ok()) {
+    return simplify_pad_status;
+  } else if (simplify_pad_successful) {
+    return Status::OK();
   }
 
   if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
@@ -2010,6 +1996,38 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
+                                    bool use_shape_info,
+                                    GraphDef* optimized_graph, NodeDef* node,
+                                    bool* success) {
+  if (use_shape_info && IsPad(*node) &&
+      properties.GetInputProperties(node->name()).size() >= 2) {
+    const auto& p = properties.GetInputProperties(node->name())[1];
+    if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+      Tensor paddings(p.dtype(), p.shape());
+      if (!paddings.FromProto(p.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       p.value().DebugString());
+      }
+      // The node is replaceable iff all values in paddings are 0.
+      bool replaceable = true;
+      // The operation requires it to be int32 value so we don't check for
+      // 1nt64.
+      const auto flatten = paddings.flat<int32>();
+      for (int j = 0; replaceable && j < flatten.size(); ++j) {
+        replaceable &= flatten(j) == 0;
+      }
+      if (replaceable) {
+        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+        *success = true;
+        return Status::OK();
+      }
+    }
+  }
+  *success = false;
+  return Status::OK();
+}
+
 bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
                                       bool use_shape_info,
                                       GraphDef* optimized_graph,
index 55ad686..fa9249f 100644 (file)
@@ -174,6 +174,10 @@ class ConstantFolding : public GraphOptimizer {
   bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
                        GraphDef* optimized_graph, NodeDef* node);
 
+  // Simplifies a Pad operation to an Identity operation if applicable.
+  Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
+                     GraphDef* optimized_graph, NodeDef* node, bool* success);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;