Extracts the SimplifyReshape optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 23 May 2018 19:35:05 +0000 (12:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 23 May 2018 19:37:32 +0000 (12:37 -0700)
PiperOrigin-RevId: 197770994

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index 8bdb164..4ebe1ca 100644 (file)
@@ -1631,20 +1631,20 @@ Status ConstantFolding::ReplaceOperationWithConstant(
   return Status::OK();
 }
 
-Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
-                                      GraphProperties* properties,
-                                      bool use_shape_info) {
+Status ConstantFolding::SimplifyGraph(bool use_shape_info,
+                                      GraphDef* optimized_graph,
+                                      GraphProperties* properties) {
   for (int i = 0; i < optimized_graph->node_size(); ++i) {
-    TF_RETURN_IF_ERROR(SimplifyNode(optimized_graph->mutable_node(i),
-                                    optimized_graph, properties,
-                                    use_shape_info));
+    TF_RETURN_IF_ERROR(SimplifyNode(use_shape_info,
+                                    optimized_graph->mutable_node(i),
+                                    optimized_graph, properties));
   }
   return Status::OK();
 }
 
-Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
-                                     GraphProperties* properties,
-                                     bool use_shape_info) {
+Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
+                                     GraphDef* optimized_graph,
+                                     GraphProperties* properties) {
   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
     ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
     return Status::OK();
@@ -2150,20 +2150,16 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
     graph_modified_ = true;
     return Status::OK();
   }
-  if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
-    DataType output_type = node->attr().at("T").type();
-    node->set_op("Identity");
-    node->clear_attr();
-    (*node->mutable_attr())["T"].set_type(output_type);
-    *node->mutable_input(1) = AsControlDependency(node->input(1));
+
+  if (SimplifyReshape(*properties, use_shape_info, node)) {
     graph_modified_ = true;
     return Status::OK();
   }
 
   bool arithmetic_simplification_succeed = false;
-  Status simplify_arithmetic_status = SimplifyArithmeticOperations(
-      optimized_graph, properties, node, use_shape_info,
-      &arithmetic_simplification_succeed);
+  Status simplify_arithmetic_status =
+      SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
+                                   node, &arithmetic_simplification_succeed);
   if (!simplify_arithmetic_status.ok()) {
     return simplify_arithmetic_status;
   } else if (arithmetic_simplification_succeed) {
@@ -2204,9 +2200,21 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
   return Status::OK();
 }
 
+bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
+                                      bool use_shape_info, NodeDef* node) {
+  if (!use_shape_info) return false;
+  if (!IsSimplifiableReshape(*node, properties)) return false;
+  DataType output_type = node->attr().at("T").type();
+  node->set_op("Identity");
+  node->clear_attr();
+  (*node->mutable_attr())["T"].set_type(output_type);
+  *node->mutable_input(1) = AsControlDependency(node->input(1));
+  return true;
+}
+
 Status ConstantFolding::SimplifyArithmeticOperations(
-    GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node,
-    bool use_shape_info, bool* success) {
+    const GraphProperties& properties, bool use_shape_info,
+    GraphDef* optimized_graph, NodeDef* node, bool* success) {
   const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
   const bool is_matmul = IsMatMul(*node);
   const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
@@ -2215,8 +2223,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
   // Simplify arithmetic operations with ones or zeros.
   if (use_shape_info &&
       (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
-      properties->HasInputProperties(node->name()) &&
-      properties->HasOutputProperties(node->name())) {
+      properties.HasInputProperties(node->name()) &&
+      properties.HasOutputProperties(node->name())) {
     const NodeDef* x = node_map_->GetNode(node->input(0));
     const NodeDef* y = node_map_->GetNode(node->input(1));
     if (x == nullptr || y == nullptr) {
@@ -2224,19 +2232,19 @@ Status ConstantFolding::SimplifyArithmeticOperations(
                                      node->DebugString());
     }
     const TensorShapeProto& output_shape =
-        properties->GetOutputProperties(node->name())[0].shape();
+        properties.GetOutputProperties(node->name())[0].shape();
 
     // Simplify element-wise multiplication by ones or addition/subtraction
     // of zeros.
     const TensorShapeProto& y_shape =
-        properties->GetInputProperties(node->name())[1].shape();
+        properties.GetInputProperties(node->name())[1].shape();
     const bool x_is_zero = IsZeros(*x);
     const bool x_is_one = x_is_zero ? false : IsOnes(*x);
     const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
     if (y_matches_output_shape &&
         ((is_mul && x_is_one) || (is_add && x_is_zero))) {
       // 1 * y = y or 0 + y = y.
-      ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
+      ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
       *success = true;
       return Status::OK();
     }
@@ -2259,14 +2267,14 @@ Status ConstantFolding::SimplifyArithmeticOperations(
     }
 
     const TensorShapeProto& x_shape =
-        properties->GetInputProperties(node->name())[0].shape();
+        properties.GetInputProperties(node->name())[0].shape();
     const bool y_is_zero = IsZeros(*y);
     const bool y_is_one = y_is_zero ? false : IsOnes(*y);
     const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
     if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                    ((is_add || is_sub) && y_is_zero))) {
       // x * 1 = x or x / 1 = x or x +/- 0 = x
-      ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
+      ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
       *success = true;
       return Status::OK();
     }
@@ -2276,9 +2284,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
     const PartialTensorShape shp(output_shape);
     if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
       bool replace_succeed = false;
-      Status replace_op_status =
-          ReplaceOperationWithConstant(1, *properties, output_shape, node,
-                                       optimized_graph, &replace_succeed);
+      Status replace_op_status = ReplaceOperationWithConstant(
+          1, properties, output_shape, node, optimized_graph, &replace_succeed);
       if (!replace_op_status.ok()) {
         return replace_op_status;
       } else if (replace_succeed) {
@@ -2296,7 +2303,7 @@ Status ConstantFolding::SimplifyArithmeticOperations(
       if (shp.IsFullyDefined()) {
         bool replace_succeed = false;
         Status replace_op_status =
-            ReplaceOperationWithConstant(0, *properties, output_shape, node,
+            ReplaceOperationWithConstant(0, properties, output_shape, node,
                                          optimized_graph, &replace_succeed);
         if (!replace_op_status.ok()) {
           return replace_op_status;
@@ -2309,11 +2316,11 @@ Status ConstantFolding::SimplifyArithmeticOperations(
       // matches the output shape and thus forward the corresponding zero
       // input.
       if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
-        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
         *success = true;
         return Status::OK();
       } else if (is_mul && y_is_zero && y_matches_output_shape) {
-        ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+        ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
         *success = true;
         return Status::OK();
       }
@@ -2855,7 +2862,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   TF_RETURN_IF_ERROR(FoldGraph(optimized_graph));
   node_map_.reset(new NodeMap(optimized_graph));
   TF_RETURN_IF_ERROR(
-      SimplifyGraph(optimized_graph, &properties, can_use_shape_info));
+      SimplifyGraph(can_use_shape_info, optimized_graph, &properties));
 
   return Status::OK();
 }
index e477934..3cf379f 100644 (file)
@@ -97,10 +97,10 @@ class ConstantFolding : public GraphOptimizer {
                                const GraphProperties& properties) const;
   bool IsSimplifiableReshape(const NodeDef& node,
                              const GraphProperties& properties) const;
-  Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
-                       bool use_shape_info);
-  Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
-                      GraphProperties* properties, bool use_shape_info);
+  Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph,
+                       GraphProperties* properties);
+  Status SimplifyNode(bool use_shape_info, NodeDef* node,
+                      GraphDef* optimized_graph, GraphProperties* properties);
 
   Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
                              GraphDef* output);
@@ -134,11 +134,16 @@ class ConstantFolding : public GraphOptimizer {
   // Simplifies arithmetic operations with ones or zeros. Returns the status,
   // and updates the success input argument that denotes if any simplification
   // was applied.
-  Status SimplifyArithmeticOperations(GraphDef* optimized_graph,
-                                      GraphProperties* properties,
-                                      NodeDef* node, bool use_shape_info,
+  Status SimplifyArithmeticOperations(const GraphProperties& properties,
+                                      bool use_shape_info,
+                                      GraphDef* optimized_graph, NodeDef* node,
                                       bool* success);
 
+  // Simplifies a Reshape operation to an Identity operation if the input node
+  // to the operation is a constant.
+  bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
+                       NodeDef* node);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;