Add a regression test for virtual_scheduler.
authorMax Galkin <maxgalkin@google.com>
Fri, 23 Feb 2018 00:13:00 +0000 (16:13 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 00:16:06 +0000 (16:16 -0800)
PiperOrigin-RevId: 186691392

tensorflow/core/grappler/costs/virtual_scheduler_test.cc

index d44b83d..f9154e4 100644 (file)
@@ -205,6 +205,25 @@ class VirtualSchedulerTest : public ::testing::Test {
     dependency_["out"] = {"x", "y", "z", "w"};
   }
 
+  // Graph with some placeholder feed nodes that are not in the fetch fan-in.
+  void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() {
+    Scope s = Scope::NewRootScope().WithDevice(kCPU0);
+    auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT);
+    auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
+
+    GraphDef def;
+    TF_CHECK_OK(s.ToGraphDef(&def));
+
+    grappler_item_.reset(new GrapplerItem);
+    grappler_item_->id = "test_extra_placeholders";
+    grappler_item_->graph = def;
+    grappler_item_->fetch = {"x"};
+
+    // Grappler Item Builder puts all placeholder nodes into the feed
+    // list by default.
+    grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}};
+  }
+
   // NoOp that takes 7 NoOps as control dependency.
   void CreateGrapplerItemWithControlDependency() {
     Scope s = Scope::NewRootScope().WithDevice(kCPU0);
@@ -1757,6 +1776,16 @@ TEST_F(VirtualSchedulerTest, MemoryUsage) {
                               cpu_state.mem_usage_snapshot_at_peak);
 }
 
+TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) {
+  CreateGrapplerItemWithUnnecessaryPlaceholderNodes();
+  InitScheduler();
+
+  // Test that scheduler can run graphs with extra unnecessary feed nodes.
+  auto ops_executed = RunScheduler("");
+  ASSERT_EQ(1, ops_executed.size());
+  ASSERT_EQ(ops_executed.count("x"), 1);
+}
+
 TEST_F(VirtualSchedulerTest, ControlDependency) {
   // Init.
   CreateGrapplerItemWithControlDependency();