Automated g4 rollback of changelist 192516190
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 12 Apr 2018 00:29:32 +0000 (17:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 12 Apr 2018 00:31:36 +0000 (17:31 -0700)
PiperOrigin-RevId: 192536085

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index cfe1329..9c45aed 100644 (file)
@@ -249,10 +249,6 @@ bool IsPrint(const NodeDef& node) { return node.op() == "Print"; }
 
 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
 
-bool IsRandomShuffle(const NodeDef& node) {
-  return node.op() == "RandomShuffle";
-}
-
 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
 
 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
@@ -302,7 +298,9 @@ bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
 
 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
 
-bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
+bool IsShuffle(const NodeDef& node) {
+  return node.op() == "Shuffle" || node.op() == "RandomShuffle";
+}
 
 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
 
index 0573b02..79fd05e 100644 (file)
@@ -98,7 +98,6 @@ bool IsPolygamma(const NodeDef& node);
 bool IsPrint(const NodeDef& node);
 bool IsProd(const NodeDef& node);
 bool IsPow(const NodeDef& node);
-bool IsRandomShuffle(const NodeDef& node);
 bool IsReal(const NodeDef& node);
 bool IsRealDiv(const NodeDef& node);
 bool IsRelu6Grad(const NodeDef& node);
index 17d8b74..b2a1ce6 100644 (file)
@@ -1574,99 +1574,24 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph,
       continue;
     }
 
-    // Remove Shuffle or Transpose op over dimensions of size 1.
-    if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
-        !properties->GetInputProperties(node->name()).empty()) {
-      const auto& shape =
-          properties->GetInputProperties(node->name())[0].shape();
-      if (shape.unknown_rank()) {
-        // Not optimizable.
-        continue;
-      }
-      const auto& p = properties->GetInputProperties(node->name())[1];
-      if (TensorShape::IsValid(p.shape()) && p.has_value()) {
-        Tensor perm(p.dtype(), p.shape());
-        if (!perm.FromProto(p.value())) {
-          return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                         p.value().DebugString());
-        }
-        std::vector<int> permutation;
-        for (int j = 0; j < perm.NumElements(); ++j) {
-          if (perm.dtype() == DT_INT64) {
-            permutation.push_back(perm.vec<int64>()(j));
-          } else {
-            permutation.push_back(perm.vec<int>()(j));
-          }
-        }
-        if (permutation.size() != shape.dim_size()) {
-          // Number of elements in perm should be same as dim_size. Skip if not.
-          continue;
-        }
-        // The node is replaceable iff
-        // dim_size == 0 || all dims have size 1 ||
-        // all dims with > 1 size are not permuted.
-        bool replaceable = true;
-        for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
-          replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
-        }
-        if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
-          continue;
-        }
-      }
-    }
-
-    // Remove RandomShuffle op if it is scalar or first dimension is of size 1.
-    if (use_shape_info && IsRandomShuffle(*node) &&
-        !properties->GetInputProperties(node->name()).empty()) {
+    // Remove Shuffle or Reverse op over scalar values.
+    if (use_shape_info &&
+        !properties->GetInputProperties(node->name()).empty() &&
+        (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
       const auto& shape =
           properties->GetInputProperties(node->name())[0].shape();
       // The node is replaceable iff
-      // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
-      if (!shape.unknown_rank() &&
-          (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
+      // unknown_rank == false && (dim_size == 0 || all dims have size 1)
+      bool replaceable = !shape.unknown_rank();
+      for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
+        replaceable &= shape.dim(j).size() == 1;
+      }
+      if (replaceable) {
         ReplaceOperationWithIdentity(0, node, optimized_graph);
         continue;
       }
     }
 
-    // Remove Reverse op over dimensions with size 1.
-    if (use_shape_info && IsReverse(*node) &&
-        !properties->GetInputProperties(node->name()).empty()) {
-      const auto& shape =
-          properties->GetInputProperties(node->name())[0].shape();
-      const auto& a = properties->GetInputProperties(node->name())[1];
-      if (TensorShape::IsValid(a.shape()) && a.has_value()) {
-        Tensor axis(a.dtype(), a.shape());
-        if (!axis.FromProto(a.value())) {
-          return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                         a.value().DebugString());
-        }
-        std::set<int> target_axes;
-        for (int j = 0; j < axis.NumElements(); ++j) {
-          if (axis.dtype() == DT_INT64) {
-            target_axes.insert(axis.vec<int64>()(j));
-          } else {
-            target_axes.insert(axis.vec<int>()(j));
-          }
-        }
-
-        // The node is replaceable iff
-        // unknown_rank == false &&
-        // (dim_size == 0 || all dims have size 1 ||
-        //  all dims with > 1 size are not in target_axes)
-        bool replaceable = !shape.unknown_rank();
-        for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
-          replaceable &= shape.dim(j).size() == 1 ||
-                         target_axes.find(j) == target_axes.end();
-        }
-        if (replaceable) {
-          ReplaceOperationWithIdentity(0, node, optimized_graph);
-          continue;
-        }
-      }
-    }
-
     if (use_shape_info && IsSlice(*node) &&
         properties->GetInputProperties(node->name()).size() == 3) {
       const auto& input = properties->GetInputProperties(node->name())[0];
index 7453fb6..31abe43 100644 (file)
@@ -1389,6 +1389,8 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) {
   ops::SplitV s1(scope.WithOpName("s1"), in1, size_splits1, split_dim, 1);
   ops::SplitV s2(scope.WithOpName("s2"), in2, size_splits2, split_dim, 2);
 
+  LOG(INFO) << s1.output.size();
+  LOG(INFO) << s2.output.size();
   ops::Add out(scope.WithOpName("out"), s1[0], s2[0]);
 
   GrapplerItem item;
@@ -1416,45 +1418,7 @@ TEST_F(ConstantFoldingTest, SplitVRemoval) {
   CompareGraphs(want, got);
 }
 
-TEST_F(ConstantFoldingTest, TransposeOnSize1DimsRemoval) {
-  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-
-  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
-                             DT_FLOAT);
-  Output p1 = ops::Const(scope.WithOpName("p1"), {3, 2, 1, 0}, {4});
-  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 4, 2, 1}),
-                             DT_FLOAT);
-  Output p2 = ops::Const(scope.WithOpName("p2"), {3, 1, 2, 0}, {4});
-  ops::Transpose t1(scope.WithOpName("t1"), in1, p1);
-  ops::Transpose t2(scope.WithOpName("t2").WithControlDependencies({in1}), in2,
-                    p2);
-
-  ops::Add out1(scope.WithOpName("out1"), t1, t2);
-
-  GrapplerItem item;
-  item.fetch = {"out1"};
-  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-
-  ConstantFolding optimizer(nullptr /* cpu_device */);
-  GraphDef got;
-  Status status = optimizer.Optimize(nullptr, item, &got);
-  TF_EXPECT_OK(status);
-
-  GraphDef want;
-  AddNode("in1", "VariableV2", {}, {}, &want);
-  AddNode("in2", "VariableV2", {}, {}, &want);
-  AddNode("p1", "Const", {}, {}, &want);
-  AddNode("p2", "Const", {}, {}, &want);
-  AddNode("t1", "Transpose", {"in1", "p1"}, {}, &want);
-  AddNode("t2", "Identity",
-          {"in2", AsControlDependency("in1"), AsControlDependency("p2")}, {},
-          &want);
-  AddNode("out1", "Add", {"t1", "t2"}, {}, &want);
-
-  CompareGraphs(want, got);
-}
-
-TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
+TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
 
   Output in1 =
@@ -1488,44 +1452,6 @@ TEST_F(ConstantFoldingTest, RandomShuffleOnScalarRemoval) {
   CompareGraphs(want, got);
 }
 
-TEST_F(ConstantFoldingTest, ReverseOnSize1DimsRemoval) {
-  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-
-  Output in1 = ops::Variable(scope.WithOpName("in1"), TensorShape({1, 2, 4, 1}),
-                             DT_FLOAT);
-  Output a1 = ops::Const(scope.WithOpName("a1"), {3, 2, 1, 0}, {4});
-  Output in2 = ops::Variable(scope.WithOpName("in2"), TensorShape({1, 2, 4, 1}),
-                             DT_FLOAT);
-  Output a2 = ops::Const(scope.WithOpName("a2"), {0, 3}, {2});
-  ops::Reverse r1(scope.WithOpName("r1"), in1, a1);
-  ops::Reverse r2(scope.WithOpName("r2").WithControlDependencies({in1}), in2,
-                  a2);
-
-  ops::Add out1(scope.WithOpName("out1"), r1, r2);
-
-  GrapplerItem item;
-  item.fetch = {"out1"};
-  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
-
-  ConstantFolding optimizer(nullptr /* cpu_device */);
-  GraphDef got;
-  Status status = optimizer.Optimize(nullptr, item, &got);
-  TF_EXPECT_OK(status);
-
-  GraphDef want;
-  AddNode("in1", "VariableV2", {}, {}, &want);
-  AddNode("in2", "VariableV2", {}, {}, &want);
-  AddNode("a1", "Const", {}, {}, &want);
-  AddNode("a2", "Const", {}, {}, &want);
-  AddNode("r1", "ReverseV2", {"in1", "a1"}, {}, &want);
-  AddNode("r2", "Identity",
-          {"in2", AsControlDependency("in1"), AsControlDependency("a2")}, {},
-          &want);
-  AddNode("out1", "Add", {"r1", "r2"}, {}, &want);
-
-  CompareGraphs(want, got);
-}
-
 TEST_F(ConstantFoldingTest, SliceWithSameDimensionRemoval) {
   {  // size = {3, 5}
     tensorflow::Scope scope = tensorflow::Scope::NewRootScope();