From 89177f289e9467e04b205a1a3e705ad67d9854d2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 12 Mar 2018 10:37:20 -0700
Subject: [PATCH] Turn trivial Pack ops with a single input into ExpandDims ops
 to avoid copying the tensor.

PiperOrigin-RevId: 188742516
---
 tensorflow/core/grappler/op_types.cc               |  2 +
 tensorflow/core/grappler/op_types.h                |  1 +
 .../core/grappler/optimizers/constant_folding.cc   | 70 ++++++++++++++++------
 .../core/grappler/optimizers/constant_folding.h    |  2 +-
 .../grappler/optimizers/constant_folding_test.cc   | 42 +++++++++++++
 5 files changed, 97 insertions(+), 20 deletions(-)

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index ca56833..53c177b 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -217,6 +217,8 @@ bool IsNextIteration(const NodeDef& node) {
   return op == "NextIteration" || op == "RefNextIteration";
 }
 
+bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
+
 bool IsPad(const NodeDef& node) {
   const auto& op = node.op();
   return op == "Pad" || op == "PadV2";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index a0946ee..cd5b464 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -86,6 +86,7 @@ bool IsMod(const NodeDef& node);
 bool IsMul(const NodeDef& node);
 bool IsMatMul(const NodeDef& node);
 bool IsNextIteration(const NodeDef& node);
+bool IsPack(const NodeDef& node);
 bool IsPad(const NodeDef& node);
 bool IsNoOp(const NodeDef& node);
 bool IsNotEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 39cc4a9..6cb0447 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1510,7 +1510,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
 }
 
 Status ConstantFolding::SimplifyGraph(GraphDef* output,
-                                      const GraphProperties& properties,
+                                      GraphProperties* properties,
                                       bool use_shape_info) {
   const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
   for (int i = 0; i < output->node_size(); ++i) {
@@ -1520,7 +1520,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     if (use_shape_info &&
         (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) {
       const auto& shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       // The node is replaceable iff
       // unknown_rank == false && (dim_size == 0 || all dims have size 1)
       bool replaceable = !shape.unknown_rank();
@@ -1533,10 +1533,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (use_shape_info && IsSlice(*node) &&
-        properties.GetInputProperties(node->name()).size() == 3) {
-      const auto& input = properties.GetInputProperties(node->name())[0];
-      const auto& b = properties.GetInputProperties(node->name())[1];
-      const auto& s = properties.GetInputProperties(node->name())[2];
+        properties->GetInputProperties(node->name()).size() == 3) {
+      const auto& input = properties->GetInputProperties(node->name())[0];
+      const auto& b = properties->GetInputProperties(node->name())[1];
+      const auto& s = properties->GetInputProperties(node->name())[2];
       if (TensorShape::IsValid(b.shape()) && b.has_value() &&
           TensorShape::IsValid(s.shape()) && s.has_value()) {
         Tensor begin(b.dtype(), b.shape());
@@ -1574,8 +1574,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (IsTile(*node) &&
-        properties.GetInputProperties(node->name()).size() == 2) {
-      const auto& m = properties.GetInputProperties(node->name())[1];
+        properties->GetInputProperties(node->name()).size() == 2) {
+      const auto& m = properties->GetInputProperties(node->name())[1];
       if (TensorShape::IsValid(m.shape()) && m.has_value()) {
         Tensor multiplies(m.dtype(), m.shape());
         if (!multiplies.FromProto(m.value())) {
@@ -1602,8 +1602,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (IsPad(*node) &&
-        properties.GetInputProperties(node->name()).size() >= 2) {
-      const auto& p = properties.GetInputProperties(node->name())[1];
+        properties->GetInputProperties(node->name()).size() >= 2) {
+      const auto& p = properties->GetInputProperties(node->name())[1];
       if (TensorShape::IsValid(p.shape()) && p.has_value()) {
         Tensor paddings(p.dtype(), p.shape());
         if (!paddings.FromProto(p.value())) {
@@ -1625,12 +1625,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     }
 
     if (use_shape_info && IsSqueeze(*node) &&
-        !properties.GetInputProperties(node->name()).empty()) {
+        !properties->GetInputProperties(node->name()).empty()) {
       // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
       // error to squeeze a dimension that is not 1, so we only need to check
       // whether the input has > 1 size for each dimension.
       const auto& shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       // The node is replaceable iff
       // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
       bool replaceable = !shape.unknown_rank();
@@ -1642,6 +1642,38 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
     }
 
+    if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
+        !OptimizedNodeExists(*node, "_const_axis")) {
+      // Create constant axis node.
+      Tensor axis_t(DT_INT32, TensorShape({}));
+      NodeDef* axis_node = output->add_node();
+      axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
+      const int axis = node->attr().at("axis").i();
+      if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
+          !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
+               .ok()) {
+        continue;
+      }
+      VLOG(1) << "*** Rewriting trivial Pack node: " << node->DebugString();
+      // Add a control dependency to make sure axis_node is in the right frame.
+      const string ctrl_dep = ConstantFolding::AddControlDependency(
+          node->input(0), graph_, node_map_.get());
+      axis_node->add_input(ctrl_dep);
+      axis_node->set_device(node->device());
+      node->set_op("ExpandDims");
+      if (node->attr().count("axis") != 0) {
+        node->mutable_attr()->erase("axis");
+      }
+      if (node->attr().count("N") != 0) {
+        node->mutable_attr()->erase("N");
+      }
+      (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
+      node->add_input(axis_node->name());
+      if (node->input_size() > 2) {
+        node->mutable_input()->SwapElements(1, node->input_size() - 1);
+      }
+    }
+
     // Switch(x, x) will always feed false to its false branch and true to
     // its true branch. By rewriting the graph a bit, we can propagate these
     // constants down the two output branches, and just use control dependencies
@@ -1759,7 +1791,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       graph_modified_ = true;
       continue;
     }
-    if (use_shape_info && IsSimplifiableReshape(*node, properties)) {
+    if (use_shape_info && IsSimplifiableReshape(*node, *properties)) {
       DataType output_type = node->attr().at("T").type();
       node->set_op("Identity");
       node->clear_attr();
@@ -1777,8 +1809,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
     // Simplify arithmetic operations with ones or zeros.
     if (use_shape_info &&
         (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
-        properties.HasInputProperties(node->name()) &&
-        properties.HasOutputProperties(node->name())) {
+        properties->HasInputProperties(node->name()) &&
+        properties->HasOutputProperties(node->name())) {
       const NodeDef* x = node_map_->GetNode(node->input(0));
       const NodeDef* y = node_map_->GetNode(node->input(1));
       if (x == nullptr || y == nullptr) {
@@ -1786,12 +1818,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
                                        node->DebugString());
       }
       const TensorShapeProto& output_shape =
-          properties.GetOutputProperties(node->name())[0].shape();
+          properties->GetOutputProperties(node->name())[0].shape();
 
       // Simplify element-wise multiplication by ones or addition/subtraction
       // of zeros.
       const TensorShapeProto& y_shape =
-          properties.GetInputProperties(node->name())[1].shape();
+          properties->GetInputProperties(node->name())[1].shape();
       const bool x_is_zero = IsZeros(*x);
       const bool x_is_one = IsOnes(*x);
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
@@ -1818,7 +1850,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
 
       const TensorShapeProto& x_shape =
-          properties.GetInputProperties(node->name())[0].shape();
+          properties->GetInputProperties(node->name())[0].shape();
       const bool y_is_zero = IsZeros(*y);
       const bool y_is_one = IsOnes(*y);
       const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
@@ -2139,7 +2171,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   }
   TF_RETURN_IF_ERROR(FoldGraph(output));
   node_map_.reset(new NodeMap(output));
-  TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+  TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info));
   return Status::OK();
 }
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 2fd59c7..13ecfcd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -92,7 +92,7 @@ class ConstantFolding : public GraphOptimizer {
   bool IsSimplifiableReduction(const NodeDef& node) const;
   bool IsSimplifiableReshape(const NodeDef& node,
                              const GraphProperties& properties) const;
-  Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+  Status SimplifyGraph(GraphDef* output, GraphProperties* properties,
                        bool use_shape_info);
 
   Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index f421a59..724fb84 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -1930,6 +1930,48 @@ TEST_F(ConstantFoldingTest, IdenticalN) {
   EXPECT_EQ("^id_n", output.node(7).input(2));
 }
 
+TEST_F(ConstantFoldingTest, TrivialPack) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output x =
+      ops::RandomNormal(scope.WithOpName("x"), {2, 2}, DataType::DT_FLOAT);
+  Output y = ops::Const(scope.WithOpName("y"), {2.0f}, {});
+  auto stack =
+      ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
+                 ops::Stack::Axis(1));
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+  item.fetch.push_back("stack");
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  LOG(INFO) << output.DebugString();
+  EXPECT_EQ(5, output.node_size());
+  for (const auto& node : output.node()) {
+    if (node.name() == "stack") {
+      EXPECT_EQ("stack", node.name());
+      EXPECT_EQ("ExpandDims", node.op());
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
+      EXPECT_EQ("^y", node.input(2));
+    } else if (node.name() == "ConstantFolding/stack_const_axis") {
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^x", node.input(0));
+    }
+  }
+
+  std::vector<string> fetch = {"stack"};
+  auto tensors_expected = EvaluateNodes(item.graph, fetch);
+  auto tensors = EvaluateNodes(output, fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
-- 
2.7.4
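
Why the rewrite is valid: packing a single tensor along `axis` yields exactly the
tensor that ExpandDims produces at that axis, with one unit dimension inserted.
Pack's kernel copies its input into a freshly allocated output, whereas ExpandDims
only rewrites shape metadata. The standalone sketch below is illustrative and not
part of the patch; it mirrors the C++ ops API used by the new test above and
assumes a standard TensorFlow C++ build.

// Standalone sketch (not part of the patch): check that Pack ("Stack" in the
// C++ ops API) of a single input equals ExpandDims at the same axis.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
  // A [2, 2] constant input.
  auto x = tensorflow::ops::Const(scope, {{1.0f, 2.0f}, {3.0f, 4.0f}});
  // Pack/Stack with a single input along axis 1: shape becomes [2, 1, 2].
  auto packed =
      tensorflow::ops::Stack(scope, {x}, tensorflow::ops::Stack::Axis(1));
  // ExpandDims at the same axis: also [2, 1, 2], but no data copy is needed.
  auto expanded = tensorflow::ops::ExpandDims(scope, x, 1);

  tensorflow::ClientSession session(scope);
  std::vector<tensorflow::Tensor> out;
  TF_CHECK_OK(session.Run({packed, expanded}, &out));
  // Both ops yield identical shapes and values.
  CHECK(out[0].shape() == out[1].shape());
  CHECK_EQ(out[0].DebugString(), out[1].DebugString());
  return 0;
}

Note that the optimizer keeps the rewritten node's name ("stack"), so downstream
consumers and fetch lists remain valid, and the new constant axis node carries a
control dependency on the Pack input to keep it in the same execution frame
inside loops.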