Added a regression test to make sure we deal with large constants properly
authorBenoit Steiner <bsteiner@google.com>
Thu, 22 Feb 2018 18:44:19 +0000 (10:44 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 18:48:14 +0000 (10:48 -0800)
PiperOrigin-RevId: 186639709

tensorflow/core/grappler/optimizers/constant_folding_test.cc

index 3afc176..2048692 100644 (file)
@@ -1456,6 +1456,44 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
   EXPECT_EQ(3, found);
 }
 
+TEST_F(ConstantFoldingTest, LargeConstant) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  // Generate a 4k by 4k constant matrix.
+  Output mat_diag =
+      ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4}));
+  Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag);
+  Output out = ops::Identity(scope.WithOpName("out"), mat);
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+  item.fetch.push_back("out");
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  // Make sure the diag node hasn't been folded, since it would use too much
+  // memory to encode the corresponding constant.
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "out") {
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("mat", node.input(0));
+      ++found;
+    } else if (node.name() == "mat") {
+      EXPECT_EQ("Diag", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("mat_diag", node.input(0));
+      ++found;
+    }
+  }
+  EXPECT_EQ(2, found);
+
+  EXPECT_GT(1024 * 1024, output.ByteSizeLong());
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow