CHECK(!IsControlInput(input_to_forward));
for (int j = 0; j < consumer->input_size(); ++j) {
const string& old_input = consumer->input(j);
- if (old_input == node_name) {
- new_input = input_to_forward;
- node_map_->UpdateInput(consumer->name(), old_input, new_input);
- consumer->set_input(j, new_input);
- found_input = true;
- } else if (old_input == AsControlDependency(NodeName(node_name))) {
- new_input = AsControlDependency(NodeName(input_to_forward));
- node_map_->UpdateInput(consumer->name(), old_input, new_input);
- consumer->set_input(j, new_input);
- found_input = true;
+ int old_input_pos;
+ string old_input_node_name =
+ ParseNodeName(old_input, &old_input_pos);
+ if (old_input_node_name == node_name) {
+ if (old_input_pos >= 0) {
+ // Regular input
+ new_input = input_to_forward;
+ node_map_->UpdateInput(consumer->name(), old_input, new_input);
+ consumer->set_input(j, new_input);
+ found_input = true;
+ } else {
+ // Control dependency
+ new_input = AsControlDependency(NodeName(input_to_forward));
+ node_map_->UpdateInput(consumer->name(), old_input, new_input);
+ consumer->set_input(j, new_input);
+ found_input = true;
+ }
}
}
CHECK(found_input);
}
}
+TEST_F(DependencyOptimizerTest, IdentityInputs) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
+ Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
+ auto s = ops::Switch(scope.WithOpName("s"), x, b);
+
+ // Identity nodes to be removed.
+ auto id_f = ops::Identity(scope.WithOpName("id_f"), s.output_false);
+ auto id_t = ops::Identity(scope.WithOpName("id_t"), s.output_true);
+
+ // Output
+ Output out1 = ops::Identity(scope.WithOpName("out1"), id_f);
+ Output out2 = ops::Identity(scope.WithOpName("out2"), id_t);
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+ item.fetch = {"out1", "out2"};
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(6, output.node_size());
+ EXPECT_EQ("out1", output.node(4).name());
+ EXPECT_EQ(1, output.node(4).input_size());
+ EXPECT_EQ("s", output.node(4).input(0));
+
+ EXPECT_EQ("out2", output.node(5).name());
+ EXPECT_EQ(1, output.node(5).input_size());
+ EXPECT_EQ("s:1", output.node(5).input(0));
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow