Extracts the 'remove reverse node' optimization into its own method.
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 May 2018 23:43:29 +0000 (16:43 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 25 May 2018 23:46:28 +0000 (16:46 -0700)
PiperOrigin-RevId: 198122165

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index d38f5a9..df32d4a 100644 (file)
@@ -1695,7 +1695,6 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       }
     }
   }
-
   // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
   if (use_shape_info && IsRandomShuffle(*node) &&
       !properties->GetInputProperties(node->name()).empty()) {
@@ -1709,47 +1708,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
     }
   }
 
-  // Remove Reverse op over dimensions with size 1.
-  if (use_shape_info && node->op() == "ReverseV2" &&
-      properties->GetInputProperties(node->name()).size() >= 2) {
-    const auto& shape = properties->GetInputProperties(node->name())[0].shape();
-    if (shape.unknown_rank()) {
-      // Not optimizable.
-      return Status::OK();
-    }
-    const auto& a = properties->GetInputProperties(node->name())[1];
-    if (TensorShape::IsValid(a.shape()) && a.has_value()) {
-      Tensor axis(a.dtype(), a.shape());
-      if (!axis.FromProto(a.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       a.value().DebugString());
-      }
-      std::set<int> target_axes;
-      for (int j = 0; j < axis.NumElements(); ++j) {
-        // value of axis can be negative.
-        if (axis.dtype() == DT_INT64) {
-          target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
-                             shape.dim_size());
-        } else {
-          target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
-                             shape.dim_size());
-        }
-      }
-
-      // The node is replaceable iff
-      // unknown_rank == false &&
-      // (dim_size == 0 || all dims have size 1 ||
-      //  all dims with > 1 size are not in target_axes)
-      bool replaceable = !shape.unknown_rank();
-      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
-        replaceable &= shape.dim(j).size() == 1 ||
-                       target_axes.find(j) == target_axes.end();
-      }
-      if (replaceable) {
-        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
-        return Status::OK();
-      }
-    }
+  bool remove_reverse_successful = false;
+  Status remove_reverse_status =
+      RemoveReverse(*properties, use_shape_info, optimized_graph, node,
+                    &remove_reverse_successful);
+  if (!remove_reverse_status.ok()) {
+    return remove_reverse_status;
+  } else if (remove_reverse_successful) {
+    return Status::OK();
   }
 
   bool simplify_slice_successful = false;
@@ -1865,6 +1831,56 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
+                                      bool use_shape_info,
+                                      GraphDef* optimized_graph, NodeDef* node,
+                                      bool* success) {
+  if (use_shape_info && node->op() == "ReverseV2" &&
+      properties.GetInputProperties(node->name()).size() >= 2) {
+    const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+    if (shape.unknown_rank()) {
+      // Not optimizable.
+      return Status::OK();
+    }
+    const auto& a = properties.GetInputProperties(node->name())[1];
+    if (TensorShape::IsValid(a.shape()) && a.has_value()) {
+      Tensor axis(a.dtype(), a.shape());
+      if (!axis.FromProto(a.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       a.value().DebugString());
+      }
+      std::set<int> target_axes;
+      for (int j = 0; j < axis.NumElements(); ++j) {
+        // value of axis can be negative.
+        if (axis.dtype() == DT_INT64) {
+          target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
+                             shape.dim_size());
+        } else {
+          target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
+                             shape.dim_size());
+        }
+      }
+
+      // The node is replaceable iff
+      // unknown_rank == false &&
+      // (dim_size == 0 || all dims have size 1 ||
+      //  all dims with > 1 size are not in target_axes)
+      bool replaceable = !shape.unknown_rank();
+      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+        replaceable &= shape.dim(j).size() == 1 ||
+                       target_axes.find(j) == target_axes.end();
+      }
+      if (replaceable) {
+        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+        *success = true;
+        return Status::OK();
+      }
+    }
+  }
+  *success = false;
+  return Status::OK();
+}
+
 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
                                       bool use_shape_info,
                                       GraphDef* optimized_graph, NodeDef* node,
index 2da6395..9a3ea03 100644 (file)
@@ -190,6 +190,10 @@ class ConstantFolding : public GraphOptimizer {
   // Simplifies a Slice operation to an Identity operation if applicable.
   Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
                        GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+  // Removes Reverse op over dimensions with size 1.
+  Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
+                       GraphDef* optimized_graph, NodeDef* node, bool* success);
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;