From 53868bfd9705da3fc15b59ab02db39b652686b13 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 6 Apr 2018 09:45:01 -0700 Subject: [PATCH] Materialize tensor array sizes whenever possible PiperOrigin-RevId: 191900015 --- .../core/grappler/optimizers/constant_folding.cc | 33 ++++++++++++++++- .../grappler/optimizers/constant_folding_test.cc | 42 ++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 2f1b9e4..b2a1ce6 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -298,7 +298,8 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { for (int node_idx = 0; node_idx < node_count; ++node_idx) { NodeDef* node = graph_->mutable_node(node_idx); const string op = node->op(); - if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") { + if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" && + op != "TensorArraySizeV3") { continue; } @@ -349,6 +350,36 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { continue; } + if (op == "TensorArraySizeV3") { + const NodeDef* array = node_map_->GetNode(node->input(0)); + if (array->attr().count("dynamic_size") != 0 && + array->attr().at("dynamic_size").b()) { + continue; + } + const NodeDef* array_size = node_map_->GetNode(array->input(0)); + if (IsReallyConstant(*array_size)) { + // Don't materialize 0 sizes to avoid triggering incorrect static + // checks. A 0 sized array that can't grow isn't useful anyway. + const TensorProto& raw_val = array_size->attr().at("value").tensor(); + if (raw_val.dtype() != DT_INT32) { + continue; + } + Tensor value(raw_val.dtype(), raw_val.tensor_shape()); + if (!value.FromProto(raw_val)) { + continue; + } + if (value.flat()(0) == 0) { + continue; + } + node->set_op("Const"); + *node->mutable_attr() = array_size->attr(); + node->set_input(0, AsControlDependency(NodeName(node->input(0)))); + node->set_input(1, AddControlDependency(NodeName(node->input(1)), + graph_, node_map_.get())); + } + continue; + } + // Handle ShapeN materialization case. // It's possible that not all input tensors have known shapes. CHECK_EQ(op, "ShapeN"); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 71ee81d..08c9268 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -2402,6 +2402,48 @@ TEST_F(ConstantFoldingTest, Enter) { } } +TEST_F(ConstantFoldingTest, TensorArraySize) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({})); + auto dynamic_array = + ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT, + ops::TensorArray::DynamicSize(true)); + auto static_array = + ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT, + ops::TensorArray::DynamicSize(false)); + auto dynamic_sz = ops::TensorArraySize( + scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow); + auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"), + static_array.handle, static_array.flow); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + auto tensors_expected = + EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"}); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + // Run the optimizer twice to make sure the rewrite is idempotent. + item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(5, output.node_size()); + EXPECT_EQ("dynamic_sz", output.node(3).name()); + EXPECT_EQ("TensorArraySizeV3", output.node(3).op()); + EXPECT_EQ("static_sz", output.node(4).name()); + EXPECT_EQ("Const", output.node(4).op()); + + auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"}); + EXPECT_EQ(2, tensors_expected.size()); + EXPECT_EQ(2, tensors_actual.size()); + test::ExpectTensorEqual(tensors_expected[0], tensors_actual[0]); + test::ExpectTensorEqual(tensors_expected[1], tensors_actual[1]); +} + } // namespace } // namespace grappler } // namespace tensorflow -- 2.7.4