for (int node_idx = 0; node_idx < node_count; ++node_idx) {
NodeDef* node = graph_->mutable_node(node_idx);
const string op = node->op();
+ // This hunk widens the op filter so the materialization loop also visits
+ // TensorArraySizeV3 nodes; all other ops are still skipped.
- if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
+ if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
+ op != "TensorArraySizeV3") {
continue;
}
continue;
}
+ // Fold TensorArraySizeV3 into a Const when the backing TensorArray has a
+ // statically known size and cannot grow at runtime.
+ if (op == "TensorArraySizeV3") {
+ const NodeDef* array = node_map_->GetNode(node->input(0));
+ // A dynamically sized array may grow after construction, so its current
+ // size is not a graph-time constant; leave the node alone.
+ if (array->attr().count("dynamic_size") != 0 &&
+ array->attr().at("dynamic_size").b()) {
+ continue;
+ }
+ // NOTE(review): assumes input(0) of the TensorArray node is its size
+ // tensor, and that both GetNode() lookups return non-null -- confirm
+ // against the TensorArrayV3 op definition; there are no null checks here.
+ const NodeDef* array_size = node_map_->GetNode(array->input(0));
+ if (IsReallyConstant(*array_size)) {
+ // Don't materialize 0 sizes to avoid triggering incorrect static
+ // checks. A 0 sized array that can't grow isn't useful anyway.
+ const TensorProto& raw_val = array_size->attr().at("value").tensor();
+ if (raw_val.dtype() != DT_INT32) {
+ continue;
+ }
+ Tensor value(raw_val.dtype(), raw_val.tensor_shape());
+ if (!value.FromProto(raw_val)) {
+ continue;
+ }
+ // NOTE(review): flat<int32>()(0) assumes the size tensor holds at least
+ // one element -- an empty DT_INT32 tensor would read out of bounds;
+ // verify the size input cannot be empty.
+ if (value.flat<int32>()(0) == 0) {
+ continue;
+ }
+ // Rewrite in place: adopt the size Const's attrs (including "value")
+ // and demote both former data inputs to control dependencies so the
+ // original execution ordering is preserved.
+ node->set_op("Const");
+ *node->mutable_attr() = array_size->attr();
+ node->set_input(0, AsControlDependency(NodeName(node->input(0))));
+ // NOTE(review): input 1 uses AddControlDependency (which also updates
+ // node_map_) while input 0 uses plain AsControlDependency -- confirm
+ // the asymmetry is intentional and input 0's edge needs no re-mapping.
+ node->set_input(1, AddControlDependency(NodeName(node->input(1)),
+ graph_, node_map_.get()));
+ }
+ continue;
+ }
+
// Handle ShapeN materialization case.
// It's possible that not all input tensors have known shapes.
CHECK_EQ(op, "ShapeN");
}
}
+// Verifies the TensorArraySizeV3 materialization: the size op of a
+// fixed-size TensorArray is folded to a Const, while the size op of a
+// dynamically sized array is left untouched, and both still evaluate to
+// the same values as before optimization.
+TEST_F(ConstantFoldingTest, TensorArraySize) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
+ // One array that may grow at runtime (dynamic_size=true)...
+ auto dynamic_array =
+ ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
+ ops::TensorArray::DynamicSize(true));
+ // ...and one whose size is fixed at construction.
+ auto static_array =
+ ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
+ ops::TensorArray::DynamicSize(false));
+ auto dynamic_sz = ops::TensorArraySize(
+ scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
+ auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
+ static_array.handle, static_array.flow);
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ // Baseline values computed on the unoptimized graph, compared against the
+ // optimized graph at the end of the test.
+ auto tensors_expected =
+ EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // NOTE(review): the index-based checks below assume the optimizer emits
+ // nodes in a stable order (dynamic_sz at 3, static_sz at 4) -- confirm
+ // this ordering guarantee; name-based lookup would be more robust.
+ EXPECT_EQ(5, output.node_size());
+ EXPECT_EQ("dynamic_sz", output.node(3).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(3).op());
+ EXPECT_EQ("static_sz", output.node(4).name());
+ EXPECT_EQ("Const", output.node(4).op());
+
+ auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
+ EXPECT_EQ(2, tensors_expected.size());
+ EXPECT_EQ(2, tensors_actual.size());
+ test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
+ test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow