Properly parse input strings in the dependency optimizer
authorBenoit Steiner <bsteiner@google.com>
Wed, 7 Mar 2018 18:58:44 +0000 (10:58 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 7 Mar 2018 19:03:17 +0000 (11:03 -0800)
PiperOrigin-RevId: 188201284

tensorflow/core/grappler/optimizers/dependency_optimizer.cc
tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc

index b47cba5..bb4b916 100644 (file)
@@ -346,16 +346,23 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
           CHECK(!IsControlInput(input_to_forward));
           for (int j = 0; j < consumer->input_size(); ++j) {
             const string& old_input = consumer->input(j);
-            if (old_input == node_name) {
-              new_input = input_to_forward;
-              node_map_->UpdateInput(consumer->name(), old_input, new_input);
-              consumer->set_input(j, new_input);
-              found_input = true;
-            } else if (old_input == AsControlDependency(NodeName(node_name))) {
-              new_input = AsControlDependency(NodeName(input_to_forward));
-              node_map_->UpdateInput(consumer->name(), old_input, new_input);
-              consumer->set_input(j, new_input);
-              found_input = true;
+            int old_input_pos;
+            string old_input_node_name =
+                ParseNodeName(old_input, &old_input_pos);
+            if (old_input_node_name == node_name) {
+              if (old_input_pos >= 0) {
+                // Regular input
+                new_input = input_to_forward;
+                node_map_->UpdateInput(consumer->name(), old_input, new_input);
+                consumer->set_input(j, new_input);
+                found_input = true;
+              } else {
+                // Control dependency
+                new_input = AsControlDependency(NodeName(input_to_forward));
+                node_map_->UpdateInput(consumer->name(), old_input, new_input);
+                consumer->set_input(j, new_input);
+                found_input = true;
+              }
             }
           }
           CHECK(found_input);
index 33d6b99..08659cb 100644 (file)
@@ -515,6 +515,39 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
   }
 }
 
+TEST_F(DependencyOptimizerTest, IdentityInputs) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
+  Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
+  auto s = ops::Switch(scope.WithOpName("s"), x, b);
+
+  // Identity nodes to be removed.
+  auto id_f = ops::Identity(scope.WithOpName("id_f"), s.output_false);
+  auto id_t = ops::Identity(scope.WithOpName("id_t"), s.output_true);
+
+  // Output
+  Output out1 = ops::Identity(scope.WithOpName("out1"), id_f);
+  Output out2 = ops::Identity(scope.WithOpName("out2"), id_t);
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+  item.fetch = {"out1", "out2"};
+
+  DependencyOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(6, output.node_size());
+  EXPECT_EQ("out1", output.node(4).name());
+  EXPECT_EQ(1, output.node(4).input_size());
+  EXPECT_EQ("s", output.node(4).input(0));
+
+  EXPECT_EQ("out2", output.node(5).name());
+  EXPECT_EQ(1, output.node(5).input_size());
+  EXPECT_EQ("s:1", output.node(5).input(0));
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow