Extracts the 'remove shuffle or transpose node' optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 29 May 2018 18:12:22 +0000 (11:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 18:14:56 +0000 (11:14 -0700)
PiperOrigin-RevId: 198425354

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index fed5d87..90862b0 100644 (file)
@@ -1655,45 +1655,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
     return Status::OK();
   }
 
-  // Remove Shuffle or Transpose op over dimensions of size 1.
-  if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
-      properties->GetInputProperties(node->name()).size() >= 2) {
-    const auto& shape = properties->GetInputProperties(node->name())[0].shape();
-    if (shape.unknown_rank()) {
-      // Not optimizable.
-      return Status::OK();
-    }
-    const auto& p = properties->GetInputProperties(node->name())[1];
-    if (TensorShape::IsValid(p.shape()) && p.has_value()) {
-      Tensor perm(p.dtype(), p.shape());
-      if (!perm.FromProto(p.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       p.value().DebugString());
-      }
-      std::vector<int> permutation;
-      for (int j = 0; j < perm.NumElements(); ++j) {
-        if (perm.dtype() == DT_INT64) {
-          permutation.push_back(perm.vec<int64>()(j));
-        } else {
-          permutation.push_back(perm.vec<int>()(j));
-        }
-      }
-      if (permutation.size() != shape.dim_size()) {
-        // Number of elements in perm should be same as dim_size. Skip if not.
-        return Status::OK();
-      }
-      // The node is replaceable iff
-      // dim_size == 0 || all dims have size 1 ||
-      // all dims with > 1 size are not permuted.
-      bool replaceable = true;
-      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
-        replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
-      }
-      if (replaceable) {
-        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
-        return Status::OK();
-      }
-    }
+  bool remove_shuffle_transpose_successful = false;
+  Status remove_shuffle_transpose_status =
+      RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
+                               node, &remove_shuffle_transpose_successful);
+  if (!remove_shuffle_transpose_status.ok()) {
+    return remove_shuffle_transpose_status;
+  } else if (remove_shuffle_transpose_successful) {
+    return Status::OK();
   }
 
   if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
@@ -1823,6 +1792,52 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+Status ConstantFolding::RemoveShuffleOrTranspose(
+    const GraphProperties& properties, bool use_shape_info,
+    GraphDef* optimized_graph, NodeDef* node, bool* success) {
+  if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
+      properties.GetInputProperties(node->name()).size() >= 2) {
+    const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+    if (shape.unknown_rank()) {
+      // Not optimizable.
+      return Status::OK();
+    }
+    const auto& p = properties.GetInputProperties(node->name())[1];
+    if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+      Tensor perm(p.dtype(), p.shape());
+      if (!perm.FromProto(p.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       p.value().DebugString());
+      }
+      std::vector<int> permutation;
+      for (int j = 0; j < perm.NumElements(); ++j) {
+        if (perm.dtype() == DT_INT64) {
+          permutation.push_back(perm.vec<int64>()(j));
+        } else {
+          permutation.push_back(perm.vec<int>()(j));
+        }
+      }
+      if (permutation.size() != shape.dim_size()) {
+        // Number of elements in perm should be same as dim_size. Skip if not.
+        return Status::OK();
+      }
+      // The node is replaceable iff
+      // dim_size == 0 || all dims have size 1 ||
+      // all dims with > 1 size are not permuted.
+      bool replaceable = true;
+      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+        replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
+      }
+      if (replaceable) {
+        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+        *success = true;
+        return Status::OK();
+      }
+    }
+  }
+  *success = false;
+  return Status::OK();
+}
 bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
                                           bool use_shape_info,
                                           GraphDef* optimized_graph,
index c760b05..43cabb4 100644 (file)
@@ -199,6 +199,12 @@ class ConstantFolding : public GraphOptimizer {
   bool RemoveRandomShuffle(const GraphProperties& properties,
                            bool use_shape_info, GraphDef* optimized_graph,
                            NodeDef* node);
+
+  // Removes Shuffle or Transpose op over dimensions of size 1.
+  Status RemoveShuffleOrTranspose(const GraphProperties& properties,
+                                  bool use_shape_info,
+                                  GraphDef* optimized_graph, NodeDef* node,
+                                  bool* success);
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;