From ea9e65c94ad71ca86d2be91c4109c62269b42cf8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 13 Mar 2018 10:39:33 -0700 Subject: [PATCH] Enable arithmetic optimizations for Fill nodes that are all zeros or ones. PiperOrigin-RevId: 188893722 --- .../core/grappler/optimizers/constant_folding.cc | 12 ++++- .../grappler/optimizers/constant_folding_test.cc | 58 ++++++++++++++-------- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index a4d8376..21037ff 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1377,6 +1377,10 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const { if (node.op() == "OnesLike") { return true; } + if (node.op() == "Fill") { + NodeDef* values = node_map_->GetNode(NodeName(node.input(1))); + return values != nullptr && IsOnes(*values); + } if (node.op() != "Const") { return false; } @@ -1408,6 +1412,10 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { if (node.op() == "ZerosLike") { return true; } + if (node.op() == "Fill") { + NodeDef* values = node_map_->GetNode(NodeName(node.input(1))); + return values != nullptr && IsZeros(*values); + } if (!IsConstant(node)) { return false; } @@ -1846,7 +1854,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const TensorShapeProto& y_shape = properties->GetInputProperties(node->name())[1].shape(); const bool x_is_zero = IsZeros(*x); - const bool x_is_one = IsOnes(*x); + const bool x_is_one = x_is_zero ? false : IsOnes(*x); const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); if (y_matches_output_shape && ((is_mul && x_is_one) || (is_add && x_is_zero))) { @@ -1873,7 +1881,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const TensorShapeProto& x_shape = properties->GetInputProperties(node->name())[0].shape(); const bool y_is_zero = IsZeros(*y); - const bool y_is_one = IsOnes(*y); + const bool y_is_one = y_is_zero ? false : IsOnes(*y); const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || ((is_add || is_sub) && y_is_zero))) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 724fb84..cf151d4 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -152,7 +152,10 @@ TEST_F(ConstantFoldingTest, AddTree) { } TEST_F(ConstantFoldingTest, NeutralElement) { - for (bool use_const : {true, false}) { + int kConst = 0; + int kLike = 1; + int kFill = 2; + for (int const_type : {kConst, kLike, kFill}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, ops::Placeholder::Shape(TensorShape({2, 2}))); @@ -164,11 +167,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) { ops::Placeholder::Shape(TensorShape({2, 3}))); Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT, ops::Placeholder::Shape(TensorShape({2}))); - Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x) - : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2}); Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2}); - Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x) - : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2}); + Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2}); + Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x); + Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f); + Output zeros = const_type == kConst + ? zeros_const + : (const_type == kLike ? zeros_like : zeros_fill); + Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2}); + Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x); + Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f); + Output ones = const_type == kConst + ? ones_const + : (const_type == kLike ? ones_like : ones_fill); Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y); Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones); @@ -201,6 +212,13 @@ TEST_F(ConstantFoldingTest, NeutralElement) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + const string suffix = + (const_type == kConst ? "_const" + : (const_type == kLike ? "_like" : "_fill")); + const string zeros_name = strings::StrCat("zeros", suffix); + const string ones_name = strings::StrCat("ones", suffix); + const string ctrl_zeros_name = strings::StrCat("^zeros", suffix); + const string ctrl_ones_name = strings::StrCat("^ones", suffix); EXPECT_EQ(28, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); @@ -208,19 +226,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) { if (name == "mul1") { EXPECT_EQ("Const", node.op()); EXPECT_EQ("^x", node.input(0)); - EXPECT_EQ("^zeros", node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(1)); } else if (name == "mul2") { EXPECT_EQ("Const", node.op()); - EXPECT_EQ("^zeros", node.input(0)); + EXPECT_EQ(ctrl_zeros_name, node.input(0)); EXPECT_EQ("^y", node.input(1)); } else if (name == "mul3") { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); - EXPECT_EQ("^ones", node.input(1)); + EXPECT_EQ(ctrl_ones_name, node.input(1)); } else if (name == "mul4") { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("y", node.input(0)); - EXPECT_EQ("^ones", node.input(1)); + EXPECT_EQ(ctrl_ones_name, node.input(1)); } else if (name == "mul5") { EXPECT_EQ("Const", node.op()); EXPECT_EQ("^x", node.input(0)); @@ -232,23 +250,23 @@ TEST_F(ConstantFoldingTest, NeutralElement) { } else if (name == "div1") { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); - EXPECT_EQ("^ones", node.input(1)); + EXPECT_EQ(ctrl_ones_name, node.input(1)); } else if (name == "div2") { EXPECT_EQ("Reciprocal", node.op()); EXPECT_EQ("y", node.input(0)); - EXPECT_EQ("^ones", node.input(1)); + EXPECT_EQ(ctrl_ones_name, node.input(1)); } else if (name == "matmul1") { EXPECT_EQ("Const", node.op()); EXPECT_EQ("^x", node.input(0)); - EXPECT_EQ("^zeros", node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(1)); } else if (name == "matmul2") { EXPECT_EQ("Const", node.op()); - EXPECT_EQ("^zeros", node.input(0)); + EXPECT_EQ(ctrl_zeros_name, node.input(0)); EXPECT_EQ("^y", node.input(1)); } else if (name == "matmul3") { EXPECT_EQ("Const", node.op()); EXPECT_EQ("^a", node.input(0)); - EXPECT_EQ("^zeros", node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(1)); TensorProto t = node.attr().at("value").tensor(); EXPECT_EQ(1, t.float_val_size()); EXPECT_EQ(0, t.float_val(0)); @@ -257,7 +275,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ(2, t.tensor_shape().dim(1).size()); } else if (name == "matmul4") { EXPECT_EQ("Const", node.op()); - EXPECT_EQ("^zeros", node.input(0)); + EXPECT_EQ(ctrl_zeros_name, node.input(0)); EXPECT_EQ("^b", node.input(1)); TensorProto t = node.attr().at("value").tensor(); EXPECT_EQ(1, t.float_val_size()); @@ -268,11 +286,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) { } else if (name == "add1") { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); - EXPECT_EQ("^zeros", node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(1)); } else if (name == "add2") { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("y", node.input(0)); - EXPECT_EQ("^zeros", node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(1)); } else if (name == "bias_add1") { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); @@ -280,16 +298,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) { } else if (name == "bias_add2") { // We don't eliminate this one, because it requires broadcasting. EXPECT_EQ("BiasAdd", node.op()); - EXPECT_EQ("zeros", node.input(0)); + EXPECT_EQ(zeros_name, node.input(0)); EXPECT_EQ("bias", node.input(1)); } else if (name == "sub1") { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); - EXPECT_EQ("^zeros", node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(1)); } else if (name == "sub2") { EXPECT_EQ("Neg", node.op()); EXPECT_EQ("y", node.input(0)); - EXPECT_EQ("^zeros", node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(1)); } const std::set square_zero_const{"mul1", "mul2", "mul5", "mul6", "matmul1", "matmul2"}; -- 2.7.4