Turn trivial Pack ops with a single input into ExpandDims ops to avoid copying the...
author    A. Unique TensorFlower <gardener@tensorflow.org>
          Mon, 12 Mar 2018 17:37:20 +0000 (10:37 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Mon, 12 Mar 2018 17:42:39 +0000 (10:42 -0700)
PiperOrigin-RevId: 188742516
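
The equivalence this rewrite relies on: Pack (tf.stack) with a single input
computes the same result as ExpandDims, but materializes a copy of the tensor.
A minimal TF 1.x-style Python sketch of that equivalence (illustrative only,
not part of this change):

    import numpy as np
    import tensorflow as tf

    x = tf.constant(np.arange(4.0).reshape(2, 2))
    stacked = tf.stack([x], axis=1)        # a Pack op with N == 1
    expanded = tf.expand_dims(x, axis=1)   # what the pass rewrites it to

    with tf.Session() as sess:
        a, b = sess.run([stacked, expanded])
    assert a.shape == b.shape == (2, 1, 2)
    assert (a == b).all()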

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h
tensorflow/core/grappler/optimizers/constant_folding_test.cc

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index ca56833..53c177b 100644
@@ -217,6 +217,8 @@ bool IsNextIteration(const NodeDef& node) {
   return op == "NextIteration" || op == "RefNextIteration";
 }
 
+bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
+
 bool IsPad(const NodeDef& node) {
   const auto& op = node.op();
   return op == "Pad" || op == "PadV2";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index a0946ee..cd5b464 100644
@@ -86,6 +86,7 @@ bool IsMod(const NodeDef& node);
 bool IsMul(const NodeDef& node);
 bool IsMatMul(const NodeDef& node);
 bool IsNextIteration(const NodeDef& node);
+bool IsPack(const NodeDef& node);
 bool IsPad(const NodeDef& node);
 bool IsNoOp(const NodeDef& node);
 bool IsNotEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 39cc4a9..6cb0447 100644
@@ -1510,7 +1510,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
 }
 
 Status ConstantFolding::SimplifyGraph(GraphDef* output,
-                                      const GraphProperties& properties,
+                                      GraphProperties* properties,
                                       bool use_shape_info) {
   const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   for (int i = 0; i < output->node_size(); ++i) {
@@ -1520,7 +1520,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     if (use_shape_info &&
         (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
       const auto& shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       // The node is replaceable iff
       // unknown_rank == false && (dim_size == 0 || all dims have size 1)
       bool replaceable = !shape.unknown_rank();
@@ -1533,10 +1533,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (use_shape_info && IsSlice(*node) &&
-        properties.GetInputProperties(node->name()).size() == 3) {
-      const auto& input = properties.GetInputProperties(node->name())[0];
-      const auto& b = properties.GetInputProperties(node->name())[1];
-      const auto& s = properties.GetInputProperties(node->name())[2];
+        properties->GetInputProperties(node->name()).size() == 3) {
+      const auto& input = properties->GetInputProperties(node->name())[0];
+      const auto& b = properties->GetInputProperties(node->name())[1];
+      const auto& s = properties->GetInputProperties(node->name())[2];
       if (TensorShape::IsValid(b.shape()) && b.has_value() &&
           TensorShape::IsValid(s.shape()) && s.has_value()) {
         Tensor begin(b.dtype(), b.shape());
@@ -1574,8 +1574,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (IsTile(*node) &&
-        properties.GetInputProperties(node->name()).size() == 2) {
-      const auto& m = properties.GetInputProperties(node->name())[1];
+        properties->GetInputProperties(node->name()).size() == 2) {
+      const auto& m = properties->GetInputProperties(node->name())[1];
       if (TensorShape::IsValid(m.shape()) && m.has_value()) {
         Tensor multiplies(m.dtype(), m.shape());
         if (!multiplies.FromProto(m.value())) {
@@ -1602,8 +1602,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (IsPad(*node) &&
-        properties.GetInputProperties(node->name()).size() >= 2) {
-      const auto& p = properties.GetInputProperties(node->name())[1];
+        properties->GetInputProperties(node->name()).size() >= 2) {
+      const auto& p = properties->GetInputProperties(node->name())[1];
       if (TensorShape::IsValid(p.shape()) && p.has_value()) {
         Tensor paddings(p.dtype(), p.shape());
         if (!paddings.FromProto(p.value())) {
@@ -1625,12 +1625,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (use_shape_info && IsSqueeze(*node) &&
-        !properties.GetInputProperties(node->name()).empty()) {
+        !properties->GetInputProperties(node->name()).empty()) {
       // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
       // error to squeeze a dimension that is not 1, so we only need to check
       // whether the input has > 1 size for each dimension.
       const auto& shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       // The node is replaceable iff
       // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
       bool replaceable = !shape.unknown_rank();
@@ -1642,6 +1642,38 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
     }
 
+    if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
+        !OptimizedNodeExists(*node, "_const_axis")) {
+      // Create constant axis node.
+      Tensor axis_t(DT_INT32, TensorShape({}));
+      NodeDef* axis_node = output->add_node();
+      axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
+      const int axis = node->attr().at("axis").i();
+      if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
+          !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
+               .ok()) {
+        continue;
+      }
+      VLOG(1) << "*** Rewriting trivial Pack node: " << node->DebugString();
+      // Add a control dependency to make sure axis_node is in the right frame.
+      const string ctrl_dep = ConstantFolding::AddControlDependency(
+          node->input(0), graph_, node_map_.get());
+      axis_node->add_input(ctrl_dep);
+      axis_node->set_device(node->device());
+      node->set_op("ExpandDims");
+      if (node->attr().count("axis") != 0) {
+        node->mutable_attr()->erase("axis");
+      }
+      if (node->attr().count("N") != 0) {
+        node->mutable_attr()->erase("N");
+      }
+      (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
+      node->add_input(axis_node->name());
+      if (node->input_size() > 2) {
+        node->mutable_input()->SwapElements(1, node->input_size() - 1);
+      }
+    }
+
     // Switch(x, x) will always feed false to its false branch and true to
     // its true branch. By rewriting the graph a bit, we can propagate these
     // constants down the two output branches, and just use control dependencies
@@ -1759,7 +1791,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       graph_modified_ = true;
       continue;
     }
-    if (use_shape_info && IsSimplifiableReshape(*node, properties)) {
+    if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
       DataType output_type = node->attr().at("T").type();
       node->set_op("Identity");
       node->clear_attr();
@@ -1777,8 +1809,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     // Simplify arithmetic operations with ones or zeros.
     if (use_shape_info &&
         (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
-        properties.HasInputProperties(node->name()) &&
-        properties.HasOutputProperties(node->name())) {
+        properties->HasInputProperties(node->name()) &&
+        properties->HasOutputProperties(node->name())) {
       const NodeDef* x = node_map_->GetNode(node->input(0));
       const NodeDef* y = node_map_->GetNode(node->input(1));
       if (x == nullptr || y == nullptr) {
@@ -1786,12 +1818,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
                                        node->DebugString());
       }
       const TensorShapeProto& output_shape =
-          properties.GetOutputProperties(node->name())[0].shape();
+          properties->GetOutputProperties(node->name())[0].shape();
 
       // Simplify element-wise multiplication by ones or addition/subtraction
       // of zeros.
       const TensorShapeProto& y_shape =
-          properties.GetInputProperties(node->name())[1].shape();
+          properties->GetInputProperties(node->name())[1].shape();
       const bool x_is_zero = IsZeros(*x);
       const bool x_is_one = IsOnes(*x);
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
@@ -1818,7 +1850,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
 
       const TensorShapeProto& x_shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       const bool y_is_zero = IsZeros(*y);
       const bool y_is_one = IsOnes(*y);
       const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
@@ -2139,7 +2171,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   }
   TF_RETURN_IF_ERROR(FoldGraph(output));
   node_map_.reset(new NodeMap(output));
-  TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+  TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info));
 
   return Status::OK();
 }
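
The new Pack branch above, mirrored as a hedged Python sketch that applies the
same rewrite directly to a GraphDef proto. The "/_const_axis" name suffix here
is an illustrative assumption; the C++ pass uses its own OptimizedNodeName
scheme ("ConstantFolding/<name>_const_axis") and AddControlDependency helper:

    from tensorflow.core.framework import types_pb2

    def rewrite_trivial_pack(graph_def):
        for node in list(graph_def.node):
            data_inputs = [i for i in node.input if not i.startswith("^")]
            if node.op != "Pack" or len(data_inputs) != 1:
                continue
            axis = node.attr["axis"].i
            # Materialize the axis as a scalar int32 Const node.
            axis_node = graph_def.node.add()
            axis_node.name = node.name + "/_const_axis"  # hypothetical naming
            axis_node.op = "Const"
            axis_node.device = node.device
            axis_node.attr["dtype"].type = types_pb2.DT_INT32
            tensor = axis_node.attr["value"].tensor
            tensor.dtype = types_pb2.DT_INT32
            tensor.tensor_shape.SetInParent()  # scalar shape
            tensor.int_val.append(axis)
            # A control dependency keeps the constant in the Pack node's frame.
            axis_node.input.append("^" + data_inputs[0].split(":")[0])
            # Rewrite Pack -> ExpandDims and patch the attrs accordingly.
            node.op = "ExpandDims"
            for attr in ("axis", "N"):
                if attr in node.attr:
                    del node.attr[attr]
            node.attr["Tdim"].type = types_pb2.DT_INT32
            node.input.append(axis_node.name)
            if len(node.input) > 2:
                # Keep the axis as the second input, ahead of control deps,
                # mirroring the SwapElements call in the C++ pass.
                last = len(node.input) - 1
                node.input[1], node.input[last] = node.input[last], node.input[1]
        return graph_def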
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 2fd59c7..13ecfcd 100644
@@ -92,7 +92,7 @@ class ConstantFolding : public GraphOptimizer {
   bool IsSimplifiableReduction(const NodeDef& node) const;
   bool IsSimplifiableReshape(const NodeDef& node,
                              const GraphProperties& properties) const;
-  Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+  Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
                        bool use_shape_info);
 
   Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index f421a59..724fb84 100644
@@ -1930,6 +1930,48 @@ TEST_F(ConstantFoldingTest, IdenticalN) {
   EXPECT_EQ("^id_n", output.node(7).input(2));
 }
 
+TEST_F(ConstantFoldingTest, TrivialPack) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output x =
+      ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
+  Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
+  auto stack =
+      ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
+                 ops::Stack::Axis(1));
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+  item.fetch.push_back("stack");
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  LOG(INFO) << output.DebugString();
+  EXPECT_EQ(5, output.node_size());
+  for (const auto& node : output.node()) {
+    if (node.name() == "stack") {
+      EXPECT_EQ("stack", node.name());
+      EXPECT_EQ("ExpandDims", node.op());
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
+      EXPECT_EQ("^y", node.input(2));
+    } else if (node.name() == "ConstantFolding/stack_const_axis") {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^x", node.input(0));
+    }
+  }
+
+  std::vector<string> fetch = {"stack"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow