From: A. Unique TensorFlower Date: Thu, 26 Apr 2018 21:59:29 +0000 (-0700) Subject: Adds optimization to convert division of sqrt to multiplication of rsqrt X-Git-Tag: upstream/v1.9.0_rc1~206^2~1^2~21 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=4386296d48d84aceb485c09361f7b80745806a61;p=platform%2Fupstream%2Ftensorflow.git Adds optimization to convert division of sqrt to multiplication of rsqrt PiperOrigin-RevId: 194459152 --- diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index c024303..7a89c26 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -289,6 +289,8 @@ bool IsReverse(const NodeDef& node) { bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; } +bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; } + bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; } bool IsSelect(const NodeDef& node) { return node.op() == "Select"; } @@ -317,6 +319,8 @@ bool IsSplit(const NodeDef& node) { return node.op() == "Split"; } bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; } +bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; } + bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; } bool IsSquare(const NodeDef& node) { return node.op() == "Square"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 3cba6b8..976d23e 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -110,6 +110,7 @@ bool IsReshape(const NodeDef& node); bool IsRestore(const NodeDef& node); bool IsReverse(const NodeDef& node); bool IsReverseV2(const NodeDef& node); +bool IsRsqrt(const NodeDef& node); bool IsRsqrtGrad(const NodeDef& node); bool IsSelect(const NodeDef& node); bool IsSeluGrad(const NodeDef& node); @@ -123,6 +124,7 @@ bool IsSoftplusGrad(const NodeDef& node); bool IsSoftsignGrad(const NodeDef& node); bool 
IsSplit(const NodeDef& node); bool IsSplitV(const NodeDef& node); +bool IsSqrt(const NodeDef& node); bool IsSqrtGrad(const NodeDef& node); bool IsSquare(const NodeDef& node); bool IsSquaredDifference(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index c0bd0bd..18076ee 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1515,6 +1515,36 @@ class HoistCWiseUnaryFromConcatStage : public ArithmeticOptimizerStage { } }; +// Performs the conversion: +// Div(x, Sqrt(y)) => Mul(x, Rsqrt(y)) +// TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)). +class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage { + public: + explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {} + ~SqrtDivToRsqrtMulStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsAnyDiv(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* y; + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); + // Optimize only if divisor is a Sqrt whose output is not being consumed + // elsewhere. 
+ if (IsSqrt(*y) && (NumNonControlOutputs(*y, *ctx().node_map) == 1)) { + // a / sqrt(b) = a * rsqrt(b) + node->set_op("Mul"); + y->set_op("Rsqrt"); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } + return Status::OK(); + } +}; + } // namespace class UniqueNodes { @@ -2172,6 +2202,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.hoist_unary_out_of_concat) pipeline.AddStage<HoistCWiseUnaryFromConcatStage>(ctx, ctx_ext); + if (options_.convert_sqrt_div_to_rsqrt_mul) + pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 689ffd4..24a2a50 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -66,6 +66,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_cast = true; bool remove_negation = true; bool hoist_unary_out_of_concat = false; + bool convert_sqrt_div_to_rsqrt_mul = false; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. 
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index df10dbd..7485d99 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -148,10 +148,16 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.remove_negation = true; } + void EnableOnlyHoistCWiseUnaryFromConcat(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.hoist_unary_out_of_concat = true; } + + void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -1936,6 +1942,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { EXPECT_EQ(5, found); } +TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2}); + Output sqrt_y = ops::Sqrt(s.WithOpName("sqrt_y"), y); + Output div_x_sqrt_y = ops::Div(s.WithOpName("output"), x, sqrt_y); + + GrapplerItem item; + item.fetch = {"output"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlySqrtDivToRsqrtMul(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); + EXPECT_EQ(item.graph.node_size(), output.node_size()); + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "output") { + EXPECT_EQ("Mul", 
node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("sqrt_y", node.input(1)); + } else if (node.name() == "sqrt_y") { + EXPECT_EQ("Rsqrt", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("y", node.input(0)); + } + } +} + TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) { tensorflow::Scope s = tensorflow::Scope::NewRootScope();