EXPECT_EQ(3, found);
}
+TEST_F(ConstantFoldingTest, LargeConstant) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ // Generate a 4k by 4k constant matrix.
+ Output mat_diag =
+ ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
+ Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
+ Output out = ops::Identity(scope.WithOpName("out"), mat);
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+ item.fetch.push_back("out");
+
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // Make sure the diag node hasn't been folded, since it would use too much
+ // memory to encode the corresponding constant.
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "out") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("mat", node.input(0));
+ ++found;
+ } else if (node.name() == "mat") {
+ EXPECT_EQ("Diag", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("mat_diag", node.input(0));
+ ++found;
+ }
+ }
+ EXPECT_EQ(2, found);
+
+ EXPECT_GT(1024 * 1024, output.ByteSizeLong());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow