Extracts the 'remove random shuffle node' optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 26 May 2018 15:25:12 +0000 (08:25 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 26 May 2018 15:27:46 +0000 (08:27 -0700)
PiperOrigin-RevId: 198169790

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index df32d4a..fed5d87 100644 (file)
@@ -1695,17 +1695,9 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       }
     }
   }
-  // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
-  if (use_shape_info && IsRandomShuffle(*node) &&
-      !properties->GetInputProperties(node->name()).empty()) {
-    const auto& shape = properties->GetInputProperties(node->name())[0].shape();
-    // The node is replaceable iff
-    // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
-    if (!shape.unknown_rank() &&
-        (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
-      ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
-      return Status::OK();
-    }
+
+  if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
+    return Status::OK();
   }
 
   bool remove_reverse_successful = false;
@@ -1831,6 +1823,24 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
+                                          bool use_shape_info,
+                                          GraphDef* optimized_graph,
+                                          NodeDef* node) {
+  if (use_shape_info && IsRandomShuffle(*node) &&
+      !properties.GetInputProperties(node->name()).empty()) {
+    const auto& shape = properties.GetInputProperties(node->name())[0].shape();
+    // The node is replaceable iff
+    // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
+    if (!shape.unknown_rank() &&
+        (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
+      ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+      return true;
+    }
+  }
+  return false;
+}
+
 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
                                       bool use_shape_info,
                                       GraphDef* optimized_graph, NodeDef* node,
index 9a3ea03..c760b05 100644 (file)
@@ -194,6 +194,11 @@ class ConstantFolding : public GraphOptimizer {
   // Removes Reverse op over dimensions with size 1.
   Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                        GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+  // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
+  bool RemoveRandomShuffle(const GraphProperties& properties,
+                           bool use_shape_info, GraphDef* optimized_graph,
+                           NodeDef* node);
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;