Adds optimization to convert division of sqrt to multiplication of rsqrt
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 26 Apr 2018 21:59:29 +0000 (14:59 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 22:03:02 +0000 (15:03 -0700)
PiperOrigin-RevId: 194459152

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index c024303..7a89c26 100644 (file)
@@ -289,6 +289,8 @@ bool IsReverse(const NodeDef& node) {
 
 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
 
+bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }
+
 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
 
 bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
@@ -317,6 +319,8 @@ bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }
 
 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
 
+bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }
+
 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
 
 bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
index 3cba6b8..976d23e 100644 (file)
@@ -110,6 +110,7 @@ bool IsReshape(const NodeDef& node);
 bool IsRestore(const NodeDef& node);
 bool IsReverse(const NodeDef& node);
 bool IsReverseV2(const NodeDef& node);
+bool IsRsqrt(const NodeDef& node);
 bool IsRsqrtGrad(const NodeDef& node);
 bool IsSelect(const NodeDef& node);
 bool IsSeluGrad(const NodeDef& node);
@@ -123,6 +124,7 @@ bool IsSoftplusGrad(const NodeDef& node);
 bool IsSoftsignGrad(const NodeDef& node);
 bool IsSplit(const NodeDef& node);
 bool IsSplitV(const NodeDef& node);
+bool IsSqrt(const NodeDef& node);
 bool IsSqrtGrad(const NodeDef& node);
 bool IsSquare(const NodeDef& node);
 bool IsSquaredDifference(const NodeDef& node);
index c0bd0bd..18076ee 100644 (file)
@@ -1515,6 +1515,36 @@ class HoistCWiseUnaryFromConcatStage : public ArithmeticOptimizerStage {
   }
 };
 
+// Performs the conversion:
+// Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
+// TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
+class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
+ public:
+  explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
+                                  const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
+  ~SqrtDivToRsqrtMulStage() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsAnyDiv(*node);
+  }
+
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    NodeDef* y;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+    // Optimize only if divisor is a Sqrt whose output is not being consumed
+    // elsewhere.
+    if (IsSqrt(*y) && (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
+      // a / sqrt(b) = a * rsqrt(b)
+      node->set_op("Mul");
+      y->set_op("Rsqrt");
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(y);
+    }
+    return Status::OK();
+  }
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -2172,6 +2202,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
   if (options_.hoist_unary_out_of_concat)
     pipeline.AddStage<HoistCWiseUnaryFromConcatStage>(ctx, ctx_ext);
+  if (options_.convert_sqrt_div_to_rsqrt_mul)
+    pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
 
   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
           << str_util::Join(pipeline.StageNames(), ", ");
index 689ffd4..24a2a50 100644 (file)
@@ -66,6 +66,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool remove_redundant_cast = true;
     bool remove_negation = true;
     bool hoist_unary_out_of_concat = false;
+    bool convert_sqrt_div_to_rsqrt_mul = false;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
index df10dbd..7485d99 100644 (file)
@@ -148,10 +148,16 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     DisableAllStages(optimizer);
     optimizer->options_.remove_negation = true;
   }
+
   void EnableOnlyHoistCWiseUnaryFromConcat(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.hoist_unary_out_of_concat = true;
   }
+
+  void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
+  }
 };
 
 TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -1936,6 +1942,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
   EXPECT_EQ(5, found);
 }
 
+TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+  auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+  Output sqrt_y = ops::Sqrt(s.WithOpName("sqrt_y"), y);
+  Output div_x_sqrt_y = ops::Div(s.WithOpName("output"), x, sqrt_y);
+
+  GrapplerItem item;
+  item.fetch = {"output"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlySqrtDivToRsqrtMul(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(1, tensors.size());
+
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  EXPECT_EQ(item.graph.node_size(), output.node_size());
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    if (node.name() == "output") {
+      EXPECT_EQ("Mul", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("sqrt_y", node.input(1));
+    } else if (node.name() == "sqrt_y") {
+      EXPECT_EQ("Rsqrt", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("y", node.input(0));
+    }
+  }
+}
+
 TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();