Extracts the 'remove split or splitv nodes' optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 29 May 2018 18:55:17 +0000 (11:55 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 29 May 2018 18:58:17 +0000 (11:58 -0700)
PiperOrigin-RevId: 198432976

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index 90862b0..1ea916a 100644 (file)
@@ -1645,13 +1645,7 @@ Status ConstantFolding::SimplifyGraph(bool use_shape_info,
 Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
                                      GraphDef* optimized_graph,
                                      GraphProperties* properties) {
-  if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
-    ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
-    return Status::OK();
-  }
-
-  if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
-    ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+  if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
     return Status::OK();
   }
 
@@ -1792,6 +1786,21 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
+                                          GraphDef* optimized_graph,
+                                          NodeDef* node) {
+  if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
+    ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
+    return true;
+  }
+
+  if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
+    ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+    return true;
+  }
+  return false;
+}
+
 Status ConstantFolding::RemoveShuffleOrTranspose(
     const GraphProperties& properties, bool use_shape_info,
     GraphDef* optimized_graph, NodeDef* node, bool* success) {
index 43cabb4..b42d5f2 100644 (file)
@@ -205,6 +205,10 @@ class ConstantFolding : public GraphOptimizer {
                                   bool use_shape_info,
                                   GraphDef* optimized_graph, NodeDef* node,
                                   bool* success);
+
+  // Removes Split or SplitV node if possible.
+  bool RemoveSplitOrSplitV(const GraphProperties& properties,
+                           GraphDef* optimized_graph, NodeDef* node);
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;