Enable arithmetic optimizations for Fill nodes that are all zeros or ones.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 13 Mar 2018 17:39:33 +0000 (10:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 13 Mar 2018 17:43:41 +0000 (10:43 -0700)
PiperOrigin-RevId: 188893722

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index a4d8376..21037ff 100644 (file)
@@ -1377,6 +1377,10 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const {
   if (node.op() == "OnesLike") {
     return true;
   }
+  if (node.op() == "Fill") {
+    NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
+    return values != nullptr && IsOnes(*values);
+  }
   if (node.op() != "Const") {
     return false;
   }
@@ -1408,6 +1412,10 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   if (node.op() == "ZerosLike") {
     return true;
   }
+  if (node.op() == "Fill") {
+    NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
+    return values != nullptr && IsZeros(*values);
+  }
   if (!IsConstant(node)) {
     return false;
   }
@@ -1846,7 +1854,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       const TensorShapeProto& y_shape =
           properties->GetInputProperties(node->name())[1].shape();
       const bool x_is_zero = IsZeros(*x);
-      const bool x_is_one = IsOnes(*x);
+      const bool x_is_one = x_is_zero ? false : IsOnes(*x);
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
       if (y_matches_output_shape &&
           ((is_mul && x_is_one) || (is_add && x_is_zero))) {
@@ -1873,7 +1881,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       const TensorShapeProto& x_shape =
           properties->GetInputProperties(node->name())[0].shape();
       const bool y_is_zero = IsZeros(*y);
-      const bool y_is_one = IsOnes(*y);
+      const bool y_is_one = y_is_zero ? false : IsOnes(*y);
       const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
       if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                      ((is_add || is_sub) && y_is_zero))) {
index 724fb84..cf151d4 100644 (file)
@@ -152,7 +152,10 @@ TEST_F(ConstantFoldingTest, AddTree) {
 }
 
 TEST_F(ConstantFoldingTest, NeutralElement) {
-  for (bool use_const : {true, false}) {
+  int kConst = 0;
+  int kLike = 1;
+  int kFill = 2;
+  for (int const_type : {kConst, kLike, kFill}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
     Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                 ops::Placeholder::Shape(TensorShape({2, 2})));
@@ -164,11 +167,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
                                 ops::Placeholder::Shape(TensorShape({2, 3})));
     Output bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                                    ops::Placeholder::Shape(TensorShape({2})));
-    Output zeros = !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
-                              : ops::Const(s.WithOpName("zeros"), 0.0f, {2, 2});
     Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
-    Output ones = !use_const ? ops::OnesLike(s.WithOpName("ones"), x)
-                             : ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
+    Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
+    Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
+    Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
+    Output zeros = const_type == kConst
+                       ? zeros_const
+                       : (const_type == kLike ? zeros_like : zeros_fill);
+    Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
+    Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
+    Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
+    Output ones = const_type == kConst
+                      ? ones_const
+                      : (const_type == kLike ? ones_like : ones_fill);
     Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
     Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
     Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
@@ -201,6 +212,13 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
 
+    const string suffix =
+        (const_type == kConst ? "_const"
+                              : (const_type == kLike ? "_like" : "_fill"));
+    const string zeros_name = strings::StrCat("zeros", suffix);
+    const string ones_name = strings::StrCat("ones", suffix);
+    const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
+    const string ctrl_ones_name = strings::StrCat("^ones", suffix);
     EXPECT_EQ(28, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
@@ -208,19 +226,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
       if (name == "mul1") {
         EXPECT_EQ("Const", node.op());
         EXPECT_EQ("^x", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "mul2") {
         EXPECT_EQ("Const", node.op());
-        EXPECT_EQ("^zeros", node.input(0));
+        EXPECT_EQ(ctrl_zeros_name, node.input(0));
         EXPECT_EQ("^y", node.input(1));
       } else if (name == "mul3") {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
-        EXPECT_EQ("^ones", node.input(1));
+        EXPECT_EQ(ctrl_ones_name, node.input(1));
       } else if (name == "mul4") {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("y", node.input(0));
-        EXPECT_EQ("^ones", node.input(1));
+        EXPECT_EQ(ctrl_ones_name, node.input(1));
       } else if (name == "mul5") {
         EXPECT_EQ("Const", node.op());
         EXPECT_EQ("^x", node.input(0));
@@ -232,23 +250,23 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
       } else if (name == "div1") {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
-        EXPECT_EQ("^ones", node.input(1));
+        EXPECT_EQ(ctrl_ones_name, node.input(1));
       } else if (name == "div2") {
         EXPECT_EQ("Reciprocal", node.op());
         EXPECT_EQ("y", node.input(0));
-        EXPECT_EQ("^ones", node.input(1));
+        EXPECT_EQ(ctrl_ones_name, node.input(1));
       } else if (name == "matmul1") {
         EXPECT_EQ("Const", node.op());
         EXPECT_EQ("^x", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "matmul2") {
         EXPECT_EQ("Const", node.op());
-        EXPECT_EQ("^zeros", node.input(0));
+        EXPECT_EQ(ctrl_zeros_name, node.input(0));
         EXPECT_EQ("^y", node.input(1));
       } else if (name == "matmul3") {
         EXPECT_EQ("Const", node.op());
         EXPECT_EQ("^a", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(1));
         TensorProto t = node.attr().at("value").tensor();
         EXPECT_EQ(1, t.float_val_size());
         EXPECT_EQ(0, t.float_val(0));
@@ -257,7 +275,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ(2, t.tensor_shape().dim(1).size());
       } else if (name == "matmul4") {
         EXPECT_EQ("Const", node.op());
-        EXPECT_EQ("^zeros", node.input(0));
+        EXPECT_EQ(ctrl_zeros_name, node.input(0));
         EXPECT_EQ("^b", node.input(1));
         TensorProto t = node.attr().at("value").tensor();
         EXPECT_EQ(1, t.float_val_size());
@@ -268,11 +286,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
       } else if (name == "add1") {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "add2") {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("y", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "bias_add1") {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
@@ -280,16 +298,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
       } else if (name == "bias_add2") {
         // We don't eliminate this one, because it requires broadcasting.
         EXPECT_EQ("BiasAdd", node.op());
-        EXPECT_EQ("zeros", node.input(0));
+        EXPECT_EQ(zeros_name, node.input(0));
         EXPECT_EQ("bias", node.input(1));
       } else if (name == "sub1") {
         EXPECT_EQ("Snapshot", node.op());
         EXPECT_EQ("x", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(1));
       } else if (name == "sub2") {
         EXPECT_EQ("Neg", node.op());
         EXPECT_EQ("y", node.input(0));
-        EXPECT_EQ("^zeros", node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(1));
       }
       const std::set<string> square_zero_const{"mul1", "mul2",    "mul5",
                                                "mul6", "matmul1", "matmul2"};