bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
+// True iff `node` is an elementwise reciprocal-square-root (Rsqrt) op.
+bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }
+
bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }
bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
+// True iff `node` is an elementwise square-root (Sqrt) op.
+bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }
+
bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
bool IsRestore(const NodeDef& node);
bool IsReverse(const NodeDef& node);
bool IsReverseV2(const NodeDef& node);
+// True iff `node` is an Rsqrt op (matches the definition added alongside).
+bool IsRsqrt(const NodeDef& node);
bool IsRsqrtGrad(const NodeDef& node);
bool IsSelect(const NodeDef& node);
bool IsSeluGrad(const NodeDef& node);
bool IsSoftsignGrad(const NodeDef& node);
bool IsSplit(const NodeDef& node);
bool IsSplitV(const NodeDef& node);
+// True iff `node` is a Sqrt op (matches the definition added alongside).
+bool IsSqrt(const NodeDef& node);
bool IsSqrtGrad(const NodeDef& node);
bool IsSquare(const NodeDef& node);
bool IsSquaredDifference(const NodeDef& node);
}
};
+// Performs the conversion:
+// Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
+// TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
+class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
+ public:
+  explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
+                                  const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
+  ~SqrtDivToRsqrtMulStage() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    // Only exact divisions are safe to rewrite: floor(x / sqrt(y)) or
+    // trunc(x / sqrt(y)) are generally not equal to x * rsqrt(y), so
+    // rounding division ops must be excluded even if IsAnyDiv matches them.
+    return IsAnyDiv(*node) && node->op() != "FloorDiv" &&
+           node->op() != "TruncateDiv";
+  }
+
+  // Rewrites a division node in place to Mul and its Sqrt divisor to Rsqrt.
+  // Leaves *simplified_node_name unset because the rewrite reuses the
+  // existing nodes (names and inputs are unchanged).
+  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+    NodeDef* y;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+    // Optimize only if divisor is a Sqrt whose output is not being consumed
+    // elsewhere; rewriting a shared Sqrt would corrupt its other consumers.
+    if (IsSqrt(*y) && (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
+      // a / sqrt(b) = a * rsqrt(b)
+      node->set_op("Mul");
+      y->set_op("Rsqrt");
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(y);
+    }
+    return Status::OK();
+  }
+};
+
} // namespace
class UniqueNodes {
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
if (options_.hoist_unary_out_of_concat)
pipeline.AddStage<HoistCWiseUnaryFromConcatStage>(ctx, ctx_ext);
+ if (options_.convert_sqrt_div_to_rsqrt_mul)
+ pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
bool remove_redundant_cast = true;
bool remove_negation = true;
bool hoist_unary_out_of_concat = false;
+ bool convert_sqrt_div_to_rsqrt_mul = false;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
DisableAllStages(optimizer);
optimizer->options_.remove_negation = true;
}
+
// Test helper: disables all arithmetic optimizer stages, then enables only
// the HoistCWiseUnaryFromConcat stage so tests exercise it in isolation.
void EnableOnlyHoistCWiseUnaryFromConcat(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_unary_out_of_concat = true;
}
+
+ // Test helper: disables all arithmetic optimizer stages, then enables only
+ // the SqrtDivToRsqrtMul stage so tests exercise it in isolation.
+ void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
EXPECT_EQ(5, found);
}
+TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
+ // Build a graph computing x / sqrt(y) so the SqrtDivToRsqrtMul stage can
+ // rewrite it into x * rsqrt(y).
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ Output sqrt_y = ops::Sqrt(s.WithOpName("sqrt_y"), y);
+ Output div_x_sqrt_y = ops::Div(s.WithOpName("output"), x, sqrt_y);
+
+ GrapplerItem item;
+ item.fetch = {"output"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ // Evaluate the unoptimized graph to get the reference output value.
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlySqrtDivToRsqrtMul(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ // The rewrite must be numerically equivalent (up to float tolerance) and
+ // must not add or remove nodes: it only swaps the ops in place.
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "output") {
+ // The Div node keeps its name and inputs but becomes a Mul.
+ EXPECT_EQ("Mul", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("sqrt_y", node.input(1));
+ } else if (node.name() == "sqrt_y") {
+ // The Sqrt node keeps its name and input but becomes a Rsqrt.
+ EXPECT_EQ("Rsqrt", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ }
+ }
+}
+
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();