Strength reduce division by a constant to multiplication by the reciprocal constant.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Dec 2017 00:11:26 +0000 (16:11 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 12 Dec 2017 00:15:17 +0000 (16:15 -0800)
PiperOrigin-RevId: 178689056

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index e1935fa9b3bec8236ff5aeb4422213a526dfb447..ac94c3f81e8d1906bb844841034984e9d38f283f 100644 (file)
@@ -31,6 +31,11 @@ bool IsAdd(const NodeDef& node) {
 
 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
 
+bool IsAnyDiv(const NodeDef& node) {
+  return node.op() == "RealDiv" || node.op() == "Div" ||
+         node.op() == "FloorDiv" || node.op() == "TruncateDiv";
+}
+
 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
 
 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
@@ -74,6 +79,8 @@ bool IsDequeueOp(const NodeDef& node) {
          op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
 }
 
+bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
+
 bool IsEnter(const NodeDef& node) {
   const auto& op = node.op();
   return op == "Enter" || op == "RefEnter";
@@ -96,13 +103,13 @@ bool IsIdentity(const NodeDef& node) {
 }
 
 bool IsMatMul(const NodeDef& node) {
-  const auto op = node.op();
+  const auto& op = node.op();
   return op == "MatMul" || op == "BatchMatMul" || op == "QuantizedMatMul" ||
          op == "SparseMatMul";
 }
 
 bool IsMerge(const NodeDef& node) {
-  const auto op = node.op();
+  const auto& op = node.op();
   return op == "Merge" || op == "RefMerge";
 }
 
@@ -118,16 +125,11 @@ bool IsNextIteration(const NodeDef& node) {
 bool IsPad(const NodeDef& node) { return node.op() == "Pad"; }
 
 bool IsPlaceholder(const NodeDef& node) {
-  const auto op = node.op();
+  const auto& op = node.op();
   return op == "Placeholder" || op == "PlaceholderV2" ||
          op == "PlaceholderWithDefault";
 }
 
-bool IsAnyDiv(const NodeDef& node) {
-  return node.op() == "RealDiv" || node.op() == "Div" ||
-         node.op() == "FloorDiv" || node.op() == "TruncateDiv";
-}
-
 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
 
 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
index fc5279c1b87020c2187f287d6ccf0f31f44c9830..b8031e011cf8e77ef635dd4685459c7997e28e1c 100644 (file)
@@ -24,6 +24,7 @@ namespace grappler {
 
 bool IsAdd(const NodeDef& node);
 bool IsAddN(const NodeDef& node);
+bool IsAnyDiv(const NodeDef& node);
 bool IsAvgPoolGrad(const NodeDef& node);
 bool IsAssert(const NodeDef& node);
 bool IsBiasAdd(const NodeDef& node);
@@ -37,6 +38,7 @@ bool IsDepthwiseConv2dNative(const NodeDef& node);
 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
 bool IsDequeueOp(const NodeDef& node);
+bool IsDiv(const NodeDef& node);
 bool IsEnter(const NodeDef& node);
 bool IsExit(const NodeDef& node);
 bool IsFloorMod(const NodeDef& node);
@@ -49,7 +51,6 @@ bool IsNextIteration(const NodeDef& node);
 bool IsPad(const NodeDef& node);
 bool IsNoOp(const NodeDef& node);
 bool IsPlaceholder(const NodeDef& node);
-bool IsAnyDiv(const NodeDef& node);
 bool IsRealDiv(const NodeDef& node);
 bool IsReluGrad(const NodeDef& node);
 bool IsRecv(const NodeDef& node);
index cb9a5fde2e0d3e8df488d94789fd131ef0e16276..d90fe5704007fcff4f23fa84a0c0e858beca0da3 100644 (file)
@@ -1072,6 +1072,7 @@ Status ConstantFolding::FoldGraph(GraphDef* output) {
     }
     // We need to record a copy of output nodes before FoldNode() modifies it.
     std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
+
     Status s = FoldNode(node, output);
     processed_nodes.insert(node->name());
     if (!s.ok()) {
@@ -1305,56 +1306,59 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
                                       const GraphProperties& properties,
                                       bool use_shape_info) {
   const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
-  for (auto& node : *output->mutable_node()) {
-    if (IsSimplifiableReduction(node)) {
+  for (int i = 0; i < output->node_size(); ++i) {
+    NodeDef* node = output->mutable_node(i);
+    if (IsSimplifiableReduction(*node)) {
       // Replace the reduction node with an identity node, that can be further
       // optimized by the model pruner.
       DataType output_type;
-      if (node.attr().count("T") > 0) {
-        output_type = node.attr().at("T").type();
+      if (node->attr().count("T") > 0) {
+        output_type = node->attr().at("T").type();
       } else {
         // This is an 'any' or 'all' reduction. The output is always boolean.
         output_type = DT_BOOL;
       }
-      node.set_op("Identity");
-      node.clear_attr();
-      (*node.mutable_attr())["T"].set_type(output_type);
-      *node.mutable_input(1) = AsControlDependency(node.input(1));
+      node->set_op("Identity");
+      node->clear_attr();
+      (*node->mutable_attr())["T"].set_type(output_type);
+      *node->mutable_input(1) = AsControlDependency(node->input(1));
+      continue;
     }
     const bool safe_to_use_shapes =
         use_shape_info && (feed_nodes_.empty() || is_aggressive);
-    if (safe_to_use_shapes && IsSimplifiableReshape(node, properties)) {
-      DataType output_type = node.attr().at("T").type();
-      node.set_op("Identity");
-      node.clear_attr();
-      (*node.mutable_attr())["T"].set_type(output_type);
-      *node.mutable_input(1) = AsControlDependency(node.input(1));
+    if (safe_to_use_shapes && IsSimplifiableReshape(*node, properties)) {
+      DataType output_type = node->attr().at("T").type();
+      node->set_op("Identity");
+      node->clear_attr();
+      (*node->mutable_attr())["T"].set_type(output_type);
+      *node->mutable_input(1) = AsControlDependency(node->input(1));
+      continue;
     }
 
+    const bool is_mul = IsMul(*node);
+    const bool is_matmul = IsMatMul(*node);
+    const bool is_add = IsAdd(*node) || IsBiasAdd(*node);
+    const bool is_sub = IsSub(*node);
+    const bool is_any_div = IsAnyDiv(*node);
     // Simplify multiplication by ones or zeros, and addition/subtraction of
     // zeros.
-    // TODO(rmlarsen): Rewrite x / const  -> x * (1/const).
-    bool is_mul = IsMul(node);
-    bool is_matmul = IsMatMul(node);
-    bool is_add = IsAdd(node) || IsBiasAdd(node);
-    bool is_sub = IsSub(node);
-    bool is_div = IsAnyDiv(node);
-    if (use_shape_info && (is_mul || is_matmul || is_add || is_sub || is_div) &&
-        properties.HasInputProperties(node.name()) &&
-        properties.HasOutputProperties(node.name())) {
-      const NodeDef* x = node_map_->GetNode(node.input(0));
-      const NodeDef* y = node_map_->GetNode(node.input(1));
+    if (use_shape_info &&
+        (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
+        properties.HasInputProperties(node->name()) &&
+        properties.HasOutputProperties(node->name())) {
+      const NodeDef* x = node_map_->GetNode(node->input(0));
+      const NodeDef* y = node_map_->GetNode(node->input(1));
       if (x == nullptr || y == nullptr) {
         return errors::InvalidArgument("Invalid inputs to node: ",
-                                       node.DebugString());
+                                       node->DebugString());
       }
       const TensorShapeProto& output_shape =
-          properties.GetOutputProperties(node.name())[0].shape();
+          properties.GetOutputProperties(node->name())[0].shape();
 
       // Simplify element-wise  multiplication by ones or addition/subtraction
       // of zeros.
       const TensorShapeProto& y_shape =
-          properties.GetInputProperties(node.name())[1].shape();
+          properties.GetInputProperties(node->name())[1].shape();
       const bool x_is_zero = IsZeros(*x);
       const bool x_is_one = IsOnes(*x);
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
@@ -1362,52 +1366,91 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
           ((is_mul && x_is_one) || (is_add && x_is_zero))) {
         // TODO(rmlarsen): Handle subtraction 0 - y.
         // 1 * y = y or 0 + y = y.
-        ReplaceOperationWithIdentity(1, &node);
+        ReplaceOperationWithIdentity(1, node);
         continue;
       }
 
       // Replace 1 / y with Reciprocal op.
-      if (y_matches_output_shape && is_div && x_is_one) {
-        ReplaceDivisionOfOnesByReciprocal(&node);
+      if (y_matches_output_shape && is_any_div && x_is_one) {
+        ReplaceDivisionOfOnesByReciprocal(node);
         continue;
       }
 
       const TensorShapeProto& x_shape =
-          properties.GetInputProperties(node.name())[0].shape();
+          properties.GetInputProperties(node->name())[0].shape();
       const bool y_is_zero = IsZeros(*y);
       const bool y_is_one = IsOnes(*y);
       const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
       if (x_matches_output_shape &&
-          (((is_mul || is_div) && y_is_one) ||
+          (((is_mul || is_any_div) && y_is_one) ||
            ((is_add || is_sub) && y_is_zero && is_aggressive))) {
         // x * 1 = x or x / 1 = x or x +/- 0 = x
-        ReplaceOperationWithIdentity(0, &node);
+        ReplaceOperationWithIdentity(0, node);
         continue;
       }
 
       // Simplify multiplication and matmul by zeros.
       // Also optimize zeros divided by a tensor, but only if we are in
       // aggressive mode, since we might get rid of divisions by zero.
-      bool optimize_zeros_divided_by_y = is_div && x_is_zero && is_aggressive;
+      bool optimize_zeros_divided_by_y =
+          is_any_div && x_is_zero && is_aggressive;
       if ((x_is_zero || y_is_zero) &&
           (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
         const PartialTensorShape shp(output_shape);
         if (shp.IsFullyDefined()) {
           TF_RETURN_IF_ERROR(
-              ReplaceOperationWithConstant(0, output_shape, &node));
+              ReplaceOperationWithConstant(0, output_shape, node));
           continue;
         }
         // Even if an input shape is only partially known, we may known that it
         // matches the output shape and thus forward the corresponding zero
         // input.
-        if ((is_mul || is_div) && x_is_zero && x_matches_output_shape) {
-          ReplaceOperationWithIdentity(0, &node);
+        if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
+          ReplaceOperationWithIdentity(0, node);
+          continue;
         } else if (is_mul && y_is_zero && y_matches_output_shape) {
-          ReplaceOperationWithIdentity(1, &node);
+          ReplaceOperationWithIdentity(1, node);
+          continue;
         }
       }
     }
+
+    // Strength reduce floating point division by a constant Div(x, const) to
+    // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
+    // will be constant folded to Mul(x, 1.0/const).
+    if (node->input_size() >= 2 && (IsRealDiv(*node) || IsDiv(*node))) {
+      const string& const_input = node->input(1);
+      const NodeDef* denom = node_map_->GetNode(const_input);
+      CHECK(denom != nullptr);
+      if (!IsReallyConstant(*denom)) {
+        continue;
+      }
+      if (node->attr().count("T") == 0) {
+        continue;
+      }
+      DataType type = node->attr().at("T").type();
+      if (IsDiv(*node) && !DataTypeIsFloating(type)) {
+        continue;
+      }
+      // Insert new reciprocal op and change node from Div to Mul.
+      NodeDef* reciprocal_node = output->add_node();
+      reciprocal_node->set_name(AddPrefixToNodeName(
+          strings::StrCat(node->name(), "_recip"), kConstantFoldingConst));
+      reciprocal_node->set_op("Reciprocal");
+      reciprocal_node->set_device(node->device());
+      node->set_op("Mul");
+      // Re-wire inputs and outputs.
+      reciprocal_node->add_input(const_input);
+      (*reciprocal_node->mutable_attr())["T"].set_type(type);
+      node->set_input(1, reciprocal_node->name());
+      node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
+      node_map_->UpdateInput(node->name(), const_input,
+                             reciprocal_node->name());
+      node_map_->AddOutput(NodeName(const_input), reciprocal_node->name());
+      graph_modified_ = true;
+    }
   }
+
   return Status::OK();
 }
 
@@ -1444,6 +1487,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   }
 
   TF_RETURN_IF_ERROR(FoldGraph(output));
+  node_map_.reset(new NodeMap(output));
   TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
 
   return Status::OK();
index 7fc88cd466df6ada5148418afb7a4786c5eb5b04..813d0cdcb0d856adfd8e7c6bd72724413b435163 100644 (file)
@@ -232,6 +232,74 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
   }
 }
 
+TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output cf_half = ops::Const(s.WithOpName("cf_half"), 0.5f, {1});
+  Output xf = ops::Placeholder(s.WithOpName("xf"), DT_FLOAT,
+                               ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output xi = ops::Placeholder(s.WithOpName("xi"), DT_INT32,
+                               ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output ci = ops::Const(s.WithOpName("ci"), 2, {1});
+  Output cf = ops::Const(s.WithOpName("cf"), 2.0f, {1});
+  Output div_i = ops::Div(s.WithOpName("div_i"), xi, ci);
+  Output div_f = ops::Div(s.WithOpName("div_f"), xf, cf);
+  Output realdiv = ops::RealDiv(s.WithOpName("realdiv"), xf, cf);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"div_f", "div_i", "realdiv"};
+  ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+                            nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(8, output.node_size());
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    const string& name = node.name();
+    if (name == "div_i") {
+      // Integer division is unchanged.
+      EXPECT_EQ("Div", node.op());
+      EXPECT_EQ("xi", node.input(0));
+      EXPECT_EQ("ci", node.input(1));
+    } else if (name == "div_f") {
+      EXPECT_EQ("Mul", node.op());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("ConstantFolding/div_f_recip", node.input(1));
+    } else if (name == "realdiv") {
+      EXPECT_EQ("Mul", node.op());
+      EXPECT_EQ("xf", node.input(0));
+      EXPECT_EQ("ConstantFolding/realdiv_recip", node.input(1));
+    } else if (name == "ConstantFolding/div_f_recip") {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
+      TensorProto t = node.attr().at("value").tensor();
+      EXPECT_EQ(DT_FLOAT, t.dtype());
+      EXPECT_EQ(1, t.tensor_shape().dim_size());
+      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+    } else if (name == "ConstantFolding/realdiv_recip") {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(DT_FLOAT, node.attr().at("dtype").type());
+      TensorProto t = node.attr().at("value").tensor();
+      EXPECT_EQ(DT_FLOAT, t.dtype());
+      EXPECT_EQ(1, t.tensor_shape().dim_size());
+      EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+    }
+  }
+
+  // Check that the reciprocals have the expected value.
+  std::vector<string> fetch = {"cf_half"};
+  auto tensor_expected = EvaluateNodes(item.graph, fetch);
+  EXPECT_EQ(fetch.size(), tensor_expected.size());
+  fetch = {"ConstantFolding/div_f_recip", "ConstantFolding/realdiv_recip"};
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(fetch.size(), tensors.size());
+  for (int i = 0; i < fetch.size(); i++) {
+    test::ExpectTensorEqual<float>(tensor_expected[0], tensors[i]);
+  }
+}
+
 TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output x_known =