Automated g4 rollback of changelist 192536085
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 26 Apr 2018 18:24:26 +0000 (11:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 26 Apr 2018 18:26:13 +0000 (11:26 -0700)
PiperOrigin-RevId: 194426650

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index f595cf6..c024303 100644 (file)
@@ -250,6 +250,10 @@ bool IsPrint(const NodeDef& node) { return node.op() == "Print"; }
 
 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
 
+bool IsRandomShuffle(const NodeDef& node) {
+  return node.op() == "RandomShuffle";
+}
+
 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
 
 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
@@ -299,9 +303,7 @@ bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
 
 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
 
-bool IsShuffle(const NodeDef& node) {
-  return node.op() == "Shuffle" || node.op() == "RandomShuffle";
-}
+bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
 
 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
 
index b25ba19..3cba6b8 100644 (file)
@@ -98,6 +98,7 @@ bool IsPolygamma(const NodeDef& node);
 bool IsPrint(const NodeDef& node);
 bool IsProd(const NodeDef& node);
 bool IsPow(const NodeDef& node);
+bool IsRandomShuffle(const NodeDef& node);
 bool IsReal(const NodeDef& node);
 bool IsRealDiv(const NodeDef& node);
 bool IsRelu6Grad(const NodeDef& node);
index 45bb188..4801f18 100644 (file)
@@ -1575,24 +1575,106 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       continue;
     }
 
-    // Remove Shuffle or Reverse op over scalar values.
-    if (use_shape_info &&
-        !properties->GetInputProperties(node->name()).empty() &&
-        (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
+    // Remove Shuffle or Transpose op over dimensions of size 1.
+    if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
+        properties->GetInputProperties(node->name()).size() >= 2) {
       const auto& shape =
           properties->GetInputProperties(node->name())[0].shape();
-      // The node is replaceable iff
-      // unknown_rank == false && (dim_size == 0 || all dims have size 1)
-      bool replaceable = !shape.unknown_rank();
-      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
-        replaceable &= shape.dim(j).size() == 1;
+      if (shape.unknown_rank()) {
+        // Not optimizable.
+        continue;
       }
-      if (replaceable) {
+      const auto& p = properties->GetInputProperties(node->name())[1];
+      if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+        Tensor perm(p.dtype(), p.shape());
+        if (!perm.FromProto(p.value())) {
+          return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                         p.value().DebugString());
+        }
+        std::vector<int> permutation;
+        for (int j = 0; j < perm.NumElements(); ++j) {
+          if (perm.dtype() == DT_INT64) {
+            permutation.push_back(perm.vec<int64>()(j));
+          } else {
+            permutation.push_back(perm.vec<int>()(j));
+          }
+        }
+        if (permutation.size() != shape.dim_size()) {
+          // Number of elements in perm should be same as dim_size. Skip if not.
+          continue;
+        }
+        // The node is replaceable iff
+        // dim_size == 0 || all dims have size 1 ||
+        // all dims with > 1 size are not permuted.
+        bool replaceable = true;
+        for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+          replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
+        }
+        if (replaceable) {
+          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          continue;
+        }
+      }
+    }
+
+    // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
+    if (use_shape_info && IsRandomShuffle(*node) &&
+        !properties->GetInputProperties(node->name()).empty()) {
+      const auto& shape =
+          properties->GetInputProperties(node->name())[0].shape();
+      // The node is replaceable iff
+      // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
+      if (!shape.unknown_rank() &&
+          (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
         ReplaceOperationWithIdentity(0, node, optimized_graph);
         continue;
       }
     }
 
+    // Remove Reverse op over dimensions with size 1.
+    if (use_shape_info && node->op() == "ReverseV2" &&
+        properties->GetInputProperties(node->name()).size() >= 2) {
+      const auto& shape =
+          properties->GetInputProperties(node->name())[0].shape();
+      if (shape.unknown_rank()) {
+        // Not optimizable.
+        continue;
+      }
+      const auto& a = properties->GetInputProperties(node->name())[1];
+      if (TensorShape::IsValid(a.shape()) && a.has_value()) {
+        Tensor axis(a.dtype(), a.shape());
+        if (!axis.FromProto(a.value())) {
+          return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                         a.value().DebugString());
+        }
+        std::set<int> target_axes;
+        for (int j = 0; j < axis.NumElements(); ++j) {
+          // value of axis can be negative.
+          if (axis.dtype() == DT_INT64) {
+            target_axes.insert(
+                (axis.vec<int64>()(j) + shape.dim_size()) % shape.dim_size());
+          } else {
+            target_axes.insert(
+                (axis.vec<int>()(j) + shape.dim_size()) % shape.dim_size());
+          }
+        }
+
+        // The node is replaceable iff
+        // unknown_rank == false &&
+        // (dim_size == 0 || all dims have size 1 ||
+        //  all dims with > 1 size are not in target_axes)
+        bool replaceable = !shape.unknown_rank();
+        for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+          replaceable &= shape.dim(j).size() == 1 ||
+                         target_axes.find(j) == target_axes.end();
+        }
+        if (replaceable) {
+          ReplaceOperationWithIdentity(0, node, optimized_graph);
+          continue;
+        }
+      }
+    }
+
     if (use_shape_info && IsSlice(*node) &&
         properties->GetInputProperties(node->name()).size() == 3) {
       const auto& input = properties->GetInputProperties(node->name())[0];
index 25693c5..306ddd2 100644 (file)
@@ -1522,8 +1522,6 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) {
   ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
   ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
 
-  LOG(INFO) << s1.output.size();
-  LOG(INFO) << s2.output.size();
   ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
 
   GrapplerItem item;
@@ -1561,7 +1559,45 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) {
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
 }
 
-TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
+TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
+                             DT_FLOAT);
+  Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
+  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
+                             DT_FLOAT);
+  Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
+  ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
+  ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
+                    p2);
+
+  ops::Add out1(scope.WithOpName("out1"), t1, t2);
+
+  GrapplerItem item;
+  item.fetch = {"out1"};
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding optimizer(nullptr /* cpu_device */);
+  GraphDef got;
+  Status status = optimizer.Optimize(nullptr, item, &got);
+  TF_EXPECT_OK(status);
+
+  GraphDef want;
+  AddNode("in1", "VariableV2", {}, {}, &want);
+  AddNode("in2", "VariableV2", {}, {}, &want);
+  AddNode("p1", "Const", {}, {}, &want);
+  AddNode("p2", "Const", {}, {}, &want);
+  AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
+  AddNode("t2", "Identity",
+          {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
+          &want);
+  AddNode("out1", "Add", {"t1", "t2"}, {}, &want);
+
+  CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
 
   Output in1 =
@@ -1606,6 +1642,44 @@ TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
     test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
 }
 
+TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
+                             DT_FLOAT);
+  Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
+  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
+                             DT_FLOAT);
+  Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
+  ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
+  ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
+                  a2);
+
+  ops::Add out1(scope.WithOpName("out1"), r1, r2);
+
+  GrapplerItem item;
+  item.fetch = {"out1"};
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding optimizer(nullptr /* cpu_device */);
+  GraphDef got;
+  Status status = optimizer.Optimize(nullptr, item, &got);
+  TF_EXPECT_OK(status);
+
+  GraphDef want;
+  AddNode("in1", "VariableV2", {}, {}, &want);
+  AddNode("in2", "VariableV2", {}, {}, &want);
+  AddNode("a1", "Const", {}, {}, &want);
+  AddNode("a2", "Const", {}, {}, &want);
+  AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
+  AddNode("r2", "Identity",
+          {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
+          &want);
+  AddNode("out1", "Add", {"r1", "r2"}, {}, &want);
+
+  CompareGraphs(want, got);
+}
+
 TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
   {  // size = {3, 5}
     tensorflow::Scope scope = tensorflow::Scope::NewRootScope();