Extracts the following optimizations into methods:
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 18 May 2018 13:27:13 +0000 (06:27 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 18 May 2018 13:30:21 +0000 (06:30 -0700)
SimplifyArithmeticOperations
ReduceDivToReciprocalMul

PiperOrigin-RevId: 197137281

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index 782ccff..9137b9d 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
+
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -1566,9 +1567,13 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
 
 Status ConstantFolding::ReplaceOperationWithConstant(
     double value, const GraphProperties& properties,
-    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
+    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
+    bool* success) {
   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
-  if (dtype == DT_INVALID) return Status::OK();
+  if (dtype == DT_INVALID) {
+    *success = false;
+    return Status::OK();
+  }
 
   AttrValue tensor_attr;
   TF_RETURN_IF_ERROR(
@@ -1587,7 +1592,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
     node->set_input(i, ctrl_dep);
   }
-  graph_modified_ = true;
+  *success = true;
   return Status::OK();
 }
 
@@ -1605,7 +1610,6 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
 Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
                                      GraphProperties* properties,
                                      bool use_shape_info) {
-  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
     ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
     return Status::OK();
@@ -2029,6 +2033,48 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
     return Status::OK();
   }
 
+  bool arithmetic_simplification_succeed = false;
+  Status simplify_arithmetic_status = SimplifyArithmeticOperations(
+      optimized_graph, properties, node, use_shape_info,
+      &arithmetic_simplification_succeed);
+  if (!simplify_arithmetic_status.ok()) {
+    return simplify_arithmetic_status;
+  } else if (arithmetic_simplification_succeed) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (ReduceDivToReciprocalMul(optimized_graph, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (ConstantPushDown(node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialConstPropThroughIdentityN(node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
+    graph_modified_ = true;
+    return Status::OK();
+  }
+
+  return Status::OK();
+}
+
+Status ConstantFolding::SimplifyArithmeticOperations(
+    GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node,
+    bool use_shape_info, bool* success) {
   const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
   const bool is_matmul = IsMatMul(*node);
   const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
@@ -2059,12 +2105,14 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
         ((is_mul && x_is_one) || (is_add && x_is_zero))) {
       // 1 * y = y or 0 + y = y.
       ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph);
+      *success = true;
       return Status::OK();
     }
 
     if (y_matches_output_shape && (is_sub && x_is_zero)) {
       // Replace 0 - y with Neg(y).
       ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
+      *success = true;
       return Status::OK();
     }
 
@@ -2073,6 +2121,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
       DataType type = node->attr().at("T").type();
       if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
         ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
+        *success = true;
         return Status::OK();
       }
     }
@@ -2086,40 +2135,68 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
                                    ((is_add || is_sub) && y_is_zero))) {
       // x * 1 = x or x / 1 = x or x +/- 0 = x
       ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph);
+      *success = true;
       return Status::OK();
     }
 
     // x OR true = true OR y = true.
+    bool updated_graph = false;
     const PartialTensorShape shp(output_shape);
     if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
-      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-          1, *properties, output_shape, node, optimized_graph));
+      bool replace_succeed = false;
+      Status replace_op_status =
+          ReplaceOperationWithConstant(1, *properties, output_shape, node,
+                                       optimized_graph, &replace_succeed);
+      if (!replace_op_status.ok()) {
+        return replace_op_status;
+      } else if (replace_succeed) {
+        updated_graph = true;
+      }
     }
 
     // Simplify multiplication and matmul by zeros.
     // Also optimize zeros divided by a tensor, but only if we are in
     // aggressive mode, since we might get rid of divisions by zero.
+    const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
     bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
     if ((x_is_zero || y_is_zero) &&
         (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
       if (shp.IsFullyDefined()) {
-        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-            0, *properties, output_shape, node, optimized_graph));
-        return Status::OK();
+        bool replace_succeed = false;
+        Status replace_op_status =
+            ReplaceOperationWithConstant(0, *properties, output_shape, node,
+                                         optimized_graph, &replace_succeed);
+        if (!replace_op_status.ok()) {
+          return replace_op_status;
+        } else if (replace_succeed) {
+          *success = true;
+          return Status::OK();
+        }
       }
       // Even if an input shape is only partially known, we may known that it
       // matches the output shape and thus forward the corresponding zero
       // input.
       if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
         ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
+        *success = true;
         return Status::OK();
       } else if (is_mul && y_is_zero && y_matches_output_shape) {
         ReplaceOperationWithIdentity(1, *properties, node, optimized_graph);
+        *success = true;
         return Status::OK();
       }
     }
+    if (updated_graph) {
+      *success = true;
+      return Status::OK();
+    }
   }
+  *success = false;
+  return Status::OK();
+}
 
+bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
+                                               NodeDef* node) {
   // Strength reduce floating point division by a constant Div(x, const) to
   // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
   // will be constant folded to Mul(x, 1.0/const).
@@ -2128,15 +2205,15 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
     const NodeDef* denom = node_map_->GetNode(const_input);
     CHECK(denom != nullptr);
     if (!IsReallyConstant(*denom)) {
-      return Status::OK();
+      return false;
     }
     if (node->attr().count("T") == 0) {
-      return Status::OK();
+      return false;
     }
     DataType type = node->attr().at("T").type();
     if (IsDiv(*node) &&
         !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
-      return Status::OK();
+      return false;
     }
     // Insert new reciprocal op and change node from Div to Mul.
     NodeDef* reciprocal_node = optimized_graph->add_node();
@@ -2150,31 +2227,9 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
     node->set_input(1, reciprocal_node->name());
     node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
     node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (ConstantPushDown(node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialConstPropThroughIdentityN(node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
+    return true;
   }
-
-  return Status::OK();
+  return false;
 }
 
 bool ConstantFolding::ConstantPushDown(NodeDef* node) {
index 227caba..6c99120 100644 (file)
@@ -88,7 +88,8 @@ class ConstantFolding : public GraphOptimizer {
   Status ReplaceOperationWithConstant(double value,
                                       const GraphProperties& properties,
                                       const TensorShapeProto& shape,
-                                      NodeDef* node, GraphDef* graph);
+                                      NodeDef* node, GraphDef* graph,
+                                      bool* success);
   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
   Status FoldGraph(GraphDef* output);
 
@@ -121,6 +122,18 @@ class ConstantFolding : public GraphOptimizer {
   // the transformation applied successfully.
   bool ConstantPushDown(NodeDef* node);
 
+  // Strength reduces floating point division by a constant Div(x, const) to
+  // multiplication by the reciprocal Mul(x, Reciprocal(const)).
+  bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);
+
+  // Simplifies arithmetic operations with ones or zeros. Returns the status,
+  // and updates the success input argument that denotes if any simplification
+  // was applied.
+  Status SimplifyArithmeticOperations(GraphDef* optimized_graph,
+                                      GraphProperties* properties,
+                                      NodeDef* node, bool use_shape_info,
+                                      bool* success);
+
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;