Materialize tensor array sizes whenever possible
authorBenoit Steiner <bsteiner@google.com>
Fri, 6 Apr 2018 16:45:01 +0000 (09:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 16:47:11 +0000 (09:47 -0700)
PiperOrigin-RevId: 191900015

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc

index 2f1b9e4..b2a1ce6 100644 (file)
@@ -298,7 +298,8 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
   for (int node_idx = 0; node_idx < node_count; ++node_idx) {
     NodeDef* node = graph_->mutable_node(node_idx);
     const string op = node->op();
-    if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
+    if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
+        op != "TensorArraySizeV3") {
       continue;
     }
 
@@ -349,6 +350,36 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
       continue;
     }
 
+    if (op == "TensorArraySizeV3") {
+      const NodeDef* array = node_map_->GetNode(node->input(0));
+      if (array->attr().count("dynamic_size") != 0 &&
+          array->attr().at("dynamic_size").b()) {
+        continue;
+      }
+      const NodeDef* array_size = node_map_->GetNode(array->input(0));
+      if (IsReallyConstant(*array_size)) {
+        // Don't materialize 0 sizes to avoid triggering incorrect static
+        // checks. A 0 sized array that can't grow isn't useful anyway.
+        const TensorProto& raw_val = array_size->attr().at("value").tensor();
+        if (raw_val.dtype() != DT_INT32) {
+          continue;
+        }
+        Tensor value(raw_val.dtype(), raw_val.tensor_shape());
+        if (!value.FromProto(raw_val)) {
+          continue;
+        }
+        if (value.flat<int32>()(0) == 0) {
+          continue;
+        }
+        node->set_op("Const");
+        *node->mutable_attr() = array_size->attr();
+        node->set_input(0, AsControlDependency(NodeName(node->input(0))));
+        node->set_input(1, AddControlDependency(NodeName(node->input(1)),
+                                                graph_, node_map_.get()));
+      }
+      continue;
+    }
+
     // Handle ShapeN materialization case.
     // It's possible that not all input tensors have known shapes.
     CHECK_EQ(op, "ShapeN");
index 71ee81d..08c9268 100644 (file)
@@ -2402,6 +2402,48 @@ TEST_F(ConstantFoldingTest, Enter) {
   }
 }
 
+TEST_F(ConstantFoldingTest, TensorArraySize) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
+  auto dynamic_array =
+      ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
+                       ops::TensorArray::DynamicSize(true));
+  auto static_array =
+      ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
+                       ops::TensorArray::DynamicSize(false));
+  auto dynamic_sz = ops::TensorArraySize(
+      scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
+  auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
+                                        static_array.handle, static_array.flow);
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  auto tensors_expected =
+      EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});
+
+  ConstantFolding optimizer(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(5, output.node_size());
+  EXPECT_EQ("dynamic_sz", output.node(3).name());
+  EXPECT_EQ("TensorArraySizeV3", output.node(3).op());
+  EXPECT_EQ("static_sz", output.node(4).name());
+  EXPECT_EQ("Const", output.node(4).op());
+
+  auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
+  EXPECT_EQ(2, tensors_expected.size());
+  EXPECT_EQ(2, tensors_actual.size());
+  test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
+  test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow