Turn the following ops into Identity.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 10 Mar 2018 20:03:19 +0000 (12:03 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 10 Mar 2018 20:07:01 +0000 (12:07 -0800)
 * Slice when the Size input matches the size of the input tensor
 * Tile when the multiples input is a tensor of '1'
 * Pad/PadV2 when the paddings input is a tensor of 0
 * Squeeze when the squeeze dimensions are known to be > 1

PiperOrigin-RevId: 188609800

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index 31dc1b7..39cc4a9 100644 (file)
@@ -1524,7 +1524,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       // The node is replaceable iff
       // unknown_rank == false && (dim_size == 0 || all dims have size 1)
       bool replaceable = !shape.unknown_rank();
-      for (int j = 0; j < shape.dim_size(); ++j) {
+      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
         replaceable &= shape.dim(j).size() == 1;
       }
       if (replaceable) {
@@ -1532,6 +1532,116 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
     }
 
+    if (use_shape_info && IsSlice(*node) &&
+        properties.GetInputProperties(node->name()).size() == 3) {
+      const auto& input = properties.GetInputProperties(node->name())[0];
+      const auto& b = properties.GetInputProperties(node->name())[1];
+      const auto& s = properties.GetInputProperties(node->name())[2];
+      if (TensorShape::IsValid(b.shape()) && b.has_value() &&
+          TensorShape::IsValid(s.shape()) && s.has_value()) {
+        Tensor begin(b.dtype(), b.shape());
+        if (!begin.FromProto(b.value())) {
+          return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                         b.value().DebugString());
+        }
+        Tensor size(s.dtype(), s.shape());
+        if (!size.FromProto(s.value())) {
+          return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                         s.value().DebugString());
+        }
+        // The node is replaceable iff unknown_rank == false &&
+        // begin == 0 && (size == -1 || size == input_shape) for all dimensions
+        bool replaceable = !input.shape().unknown_rank();
+        for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
+          if (begin.dtype() == DT_INT32) {
+            replaceable &= begin.vec<int>()(j) == 0;
+          } else {
+            replaceable &= begin.vec<int64>()(j) == 0;
+          }
+          if (size.dtype() == DT_INT32) {
+            replaceable &= (size.vec<int>()(j) == -1 ||
+                            size.vec<int>()(j) == input.shape().dim(j).size());
+          } else {
+            replaceable &=
+                (size.vec<int64>()(j) == -1 ||
+                 size.vec<int64>()(j) == input.shape().dim(j).size());
+          }
+        }
+        if (replaceable) {
+          ReplaceOperationWithIdentity(0, node, output);
+        }
+      }
+    }
+
+    if (IsTile(*node) &&
+        properties.GetInputProperties(node->name()).size() == 2) {
+      const auto& m = properties.GetInputProperties(node->name())[1];
+      if (TensorShape::IsValid(m.shape()) && m.has_value()) {
+        Tensor multiplies(m.dtype(), m.shape());
+        if (!multiplies.FromProto(m.value())) {
+          return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                         m.value().DebugString());
+        }
+        // The node is replaceable iff all values in multiplies are 1.
+        bool replaceable = true;
+        if (multiplies.dtype() == DT_INT32) {
+          for (int j = 0; replaceable && j < multiplies.vec<int>().size();
+               ++j) {
+            replaceable &= multiplies.vec<int>()(j) == 1;
+          }
+        } else {
+          for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
+               ++j) {
+            replaceable &= multiplies.vec<int64>()(j) == 1;
+          }
+        }
+        if (replaceable) {
+          ReplaceOperationWithIdentity(0, node, output);
+        }
+      }
+    }
+
+    if (IsPad(*node) &&
+        properties.GetInputProperties(node->name()).size() >= 2) {
+      const auto& p = properties.GetInputProperties(node->name())[1];
+      if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+        Tensor paddings(p.dtype(), p.shape());
+        if (!paddings.FromProto(p.value())) {
+          return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                         p.value().DebugString());
+        }
+        // The node is replaceable iff all values in paddings are 0.
+        bool replaceable = true;
+        // The operation requires it to be int32 value so we don't check for
+        // 1nt64.
+        const auto flatten = paddings.flat<int32>();
+        for (int j = 0; replaceable && j < flatten.size(); ++j) {
+          replaceable &= flatten(j) == 0;
+        }
+        if (replaceable) {
+          ReplaceOperationWithIdentity(0, node, output);
+        }
+      }
+    }
+
+    if (use_shape_info && IsSqueeze(*node) &&
+        !properties.GetInputProperties(node->name()).empty()) {
+      // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
+      // error to squeeze a dimension that is not 1, so we only need to check
+      // whether the input has > 1 size for each dimension.
+      const auto& shape =
+          properties.GetInputProperties(node->name())[0].shape();
+      // The node is replaceable iff
+      // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
+      bool replaceable = !shape.unknown_rank();
+      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+        replaceable &= shape.dim(j).size() > 1;
+      }
+      if (replaceable) {
+        ReplaceOperationWithIdentity(0, node, output);
+      }
+    }
+
     // Switch(x, x) will always feed false to its false branch and true to
     // its true branch. By rewriting the graph a bit, we can propagate these
     // constants down the two output branches, and just use control dependencies
@@ -2027,7 +2137,6 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
     TF_RETURN_IF_ERROR(MaterializeShapes(properties));
     TF_RETURN_IF_ERROR(MaterializeConstants(properties));
   }
-
   TF_RETURN_IF_ERROR(FoldGraph(output));
   node_map_.reset(new NodeMap(output));
   TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
index 4b97708..f421a59 100644 (file)
@@ -1261,6 +1261,187 @@ TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
   CompareGraphs(want, got);
 }
 
+TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
+  {  // size = {3, 5}
+    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+    auto in1 = ops::Variable(scope.WithOpName("in1"), {3, 5}, DT_FLOAT);
+    auto begin = ops::Const(scope.WithOpName("begin"), {0, 0}, {2});
+    auto size = ops::Const(scope.WithOpName("size"), {3, 5}, {2});
+    Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+    ops::Slice s1(scope.WithOpName("s1"), in1, begin, size);
+    ops::Slice s2(scope.WithOpName("s2"), in2, begin, size);
+
+    ops::Add out(scope.WithOpName("out"), s1, s2);
+
+    GrapplerItem item;
+    item.fetch = {"out"};
+    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+    ConstantFolding fold(nullptr /* cpu_device */);
+    GraphDef got;
+    Status status = fold.Optimize(nullptr, item, &got);
+    TF_EXPECT_OK(status);
+
+    GraphDef want;
+    AddNode("in1", "VariableV2", {}, &want);
+    AddNode("in2", "VariableV2", {}, &want);
+    AddNode("begin", "Const", {}, &want);
+    AddNode("size", "Const", {}, &want);
+    AddNode("s1", "Identity",
+            {"in1", AsControlDependency("begin"), AsControlDependency("size")},
+            &want);
+    AddNode("s2", "Slice", {"in2", "begin", "size"}, &want);
+    AddNode("out", "Add", {"s1", "s2"}, &want);
+
+    CompareGraphs(want, got);
+  }
+  {  // size = {-1, -1}
+    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+    auto in1 =
+        ops::Variable(scope.WithOpName("in1"), {3, 5}, DataType::DT_FLOAT);
+    auto begin1 = ops::Const(scope.WithOpName("begin1"), {0, 0}, {2});
+    auto begin2 = ops::Const(scope.WithOpName("begin2"), {1, 1}, {2});
+    auto size = ops::Const(scope.WithOpName("size"), {-1, -1}, {2});
+    Output in2 =
+        ops::Variable(scope.WithOpName("in2"), {4, 6}, DataType::DT_FLOAT);
+    ops::Slice s1(scope.WithOpName("s1"), in1, begin1, size);
+    ops::Slice s2(scope.WithOpName("s2"), in2, begin2, size);
+
+    ops::Add out(scope.WithOpName("out"), s1, s2);
+
+    GrapplerItem item;
+    item.fetch = {"out"};
+    TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+    ConstantFolding fold(nullptr /* cpu_device */);
+    GraphDef got;
+    Status status = fold.Optimize(nullptr, item, &got);
+    TF_EXPECT_OK(status);
+
+    GraphDef want;
+    AddNode("in1", "VariableV2", {}, &want);
+    AddNode("in2", "VariableV2", {}, &want);
+    AddNode("begin1", "Const", {}, &want);
+    AddNode("begin2", "Const", {}, &want);
+    AddNode("size", "Const", {}, &want);
+    AddNode("s1", "Identity",
+            {"in1", AsControlDependency("begin1"), AsControlDependency("size")},
+            &want);
+    AddNode("s2", "Slice", {"in2", "begin2", "size"}, &want);
+    AddNode("out", "Add", {"s1", "s2"}, &want);
+
+    CompareGraphs(want, got);
+  }
+}
+
+TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+  auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+  auto in2 = ops::Variable(scope.WithOpName("in2"), {4, 3}, DT_FLOAT);
+  auto multiplies1 = ops::Const(scope.WithOpName("multiplies1"), {1, 1}, {2});
+  auto multiplies2 = ops::Const(scope.WithOpName("multiplies2"), {1, 2}, {2});
+
+  ops::Tile t1(scope.WithOpName("t1"), in1, multiplies1);
+  ops::Tile t2(scope.WithOpName("t2"), in2, multiplies2);
+
+  ops::Add out(scope.WithOpName("out"), t1, t2);
+
+  GrapplerItem item;
+  item.fetch = {"out"};
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef got;
+  Status status = fold.Optimize(nullptr, item, &got);
+  TF_EXPECT_OK(status);
+
+  GraphDef want;
+  AddNode("in1", "VariableV2", {}, &want);
+  AddNode("in2", "VariableV2", {}, &want);
+  AddNode("multiplies1", "Const", {}, &want);
+  AddNode("multiplies2", "Const", {}, &want);
+  AddNode("t1", "Identity", {"in1", AsControlDependency("multiplies1")}, &want);
+  AddNode("t2", "Tile", {"in2", "multiplies2"}, &want);
+  AddNode("out", "Add", {"t1", "t2"}, &want);
+
+  CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+  auto in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_INT32);
+  auto in2 = ops::Variable(scope.WithOpName("in2"), {2, 2}, DT_INT32);
+  auto paddings1 =
+      ops::Const(scope.WithOpName("paddings1"), {0, 0, 0, 0}, {2, 2});
+  auto paddings2 =
+      ops::Const(scope.WithOpName("paddings2"), {1, 1, 2, 2}, {2, 2});
+  auto c1 = ops::Const(scope.WithOpName("c1"), 1);
+  auto c2 = ops::Const(scope.WithOpName("c2"), 1);
+
+  ops::PadV2 p1(scope.WithOpName("p1"), in1, paddings1, c1);
+  ops::PadV2 p2(scope.WithOpName("p2"), in2, paddings2, c2);
+
+  ops::Add out(scope.WithOpName("out"), p1, p2);
+
+  GrapplerItem item;
+  item.fetch = {"out"};
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef got;
+  Status status = fold.Optimize(nullptr, item, &got);
+  TF_EXPECT_OK(status);
+
+  GraphDef want;
+  AddNode("in1", "VariableV2", {}, &want);
+  AddNode("in2", "VariableV2", {}, &want);
+  AddNode("paddings1", "Const", {}, &want);
+  AddNode("paddings2", "Const", {}, &want);
+  AddNode("c1", "Const", {}, &want);
+  AddNode("c2", "Const", {}, &want);
+  AddNode("p1", "Identity",
+          {"in1", AsControlDependency("paddings1"), AsControlDependency("c1")},
+          &want);
+  AddNode("p2", "PadV2", {"in2", "paddings2", "c2"}, &want);
+  AddNode("out", "Add", {"p1", "p2"}, &want);
+
+  CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, SqueezeWithAllDimesionsGreaterThanOne) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+  auto in1 = ops::Variable(scope.WithOpName("in1"), {2, 3}, DT_INT32);
+  auto in2 = ops::Variable(scope.WithOpName("in2"), {1, 2, 3, 1}, DT_INT32);
+
+  ops::Squeeze s1(scope.WithOpName("s1"), in1);
+  ops::Squeeze s2(scope.WithOpName("s2"), in2);
+
+  ops::Add out(scope.WithOpName("out"), s1, s2);
+
+  GrapplerItem item;
+  item.fetch = {"out"};
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef got;
+  Status status = fold.Optimize(nullptr, item, &got);
+  TF_EXPECT_OK(status);
+
+  GraphDef want;
+  AddNode("in1", "VariableV2", {}, &want);
+  AddNode("in2", "VariableV2", {}, &want);
+  AddNode("s1", "Identity", {"in1"}, &want);
+  AddNode("s2", "Squeeze", {"in2"}, &want);
+  AddNode("out", "Add", {"s1", "s2"}, &want);
+
+  CompareGraphs(want, got);
+}
+
 TEST_F(ConstantFoldingTest, NoOpReduction) {
   // Build a simple graph with a reduction that can be reduced to the
   // identity.