// Add new_input as a control input to node if it does not already depend on it.
// TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
// clean up code that should be using them.
-bool MaybeAddControlInput(const string& new_input, NodeDef* node,
+bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
GraphDef* graph, NodeMap* node_map) {
bool already_exists = false;
for (const string& input : node->input()) {
- if (input == new_input || AsControlDependency(input) == new_input) {
+ if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
already_exists = true;
break;
}
}
if (!already_exists) {
const string ctrl_dep =
- ConstantFolding::AddControlDependency(new_input, graph, node_map);
+ ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
node->add_input(ctrl_dep);
- node_map->AddOutput(NodeName(new_input), node->name());
+ node_map->AddOutput(NodeName(ctrl_input), node->name());
}
return !already_exists;
}
// Remove old_input as a control input to node.
bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
GraphDef* graph, NodeMap* node_map) {
+ bool removed_input = false;
+ bool update_node_map = true;
+ const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i);
- if (IsControlInput(input) && AsControlDependency(old_input) == input) {
- node->mutable_input()->SwapElements(i, node->input_size() - 1);
- node->mutable_input()->RemoveLast();
- node_map->RemoveOutput(NodeName(old_input), node->name());
- return true;
+ if (old_input_ctrl_dep == input) {
+ if (IsControlInput(input)) {
+ node->mutable_input()->SwapElements(i, node->input_size() - 1);
+ node->mutable_input()->RemoveLast();
+ removed_input = true;
+ } else {
+ // There is a non-control input from the same node.
+ // Don't remove the output from the NodeMap.
+ update_node_map = false;
+ }
}
}
- return false;
+ if (update_node_map) {
+ node_map->RemoveOutput(NodeName(old_input), node->name());
+ }
+ return removed_input;
}
} // namespace
node_map_->AddOutput(NodeName(ctrl_dep), node->name());
} else {
auto outputs = node_map_->GetOutputs(node->name());
- for (const auto& output : outputs) {
+ for (NodeDef* output : outputs) {
for (int k = 0; k < output->input_size(); ++k) {
int port;
string node_name = ParseNodeName(output->input(k), &port);
*added_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
}
- node_map_->UpdateInput(output->name(),
- NodeName(output->input(k)), const_name);
*output->mutable_input(k) = const_name;
+ node_map_->AddOutput(const_name, output->name());
}
}
+ bool remove_output = true;
+ for (int k = 0; k < output->input_size(); ++k) {
+ int port;
+ string node_name = ParseNodeName(output->input(k), &port);
+ if (node_name == node->name()) {
+ remove_output = false;
+ break;
+ }
+ }
+ if (remove_output) {
+ node_map_->RemoveOutput(node->name(), output->name());
+ }
}
}
}
node_map_->AddOutput(node->name(), const_index->name());
auto outputs = node_map_->GetOutputs(node->name());
- for (auto& output : outputs) {
+ for (NodeDef* output : outputs) {
for (int i = 0; i < output->input_size(); i++) {
int port;
string node_name = ParseNodeName(output->input(i), &port);
if (const_nodes.size() > 1) {
auto outputs = node_map_->GetOutputs(node->name());
- for (const auto& output : outputs) {
+ for (NodeDef* output : outputs) {
for (int i = 0; i < output->input_size(); i++) {
int port;
string node_name = ParseNodeName(output->input(i), &port);
EXPECT_EQ(9, found);
}
+TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
+ Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
+ auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
+ auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
+ Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
+ Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+ item.fetch.push_back("ia");
+ item.fetch.push_back("ib");
+
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ int found = 0;
+ for (const auto& node : output.node()) {
+ EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
+ node.name());
+ if (node.name() == "s") {
+ ++found;
+ EXPECT_EQ("ShapeN", node.op());
+ EXPECT_EQ("v1", node.input(0));
+ EXPECT_EQ("v2", node.input(1));
+ }
+ if (node.name() == "id_n") {
+ ++found;
+ EXPECT_EQ("IdentityN", node.op());
+ EXPECT_EQ("s", node.input(0));
+ EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
+ node.input(1));
+ }
+ if (node.name() == "ia") {
+ ++found;
+ EXPECT_EQ("id_n", node.input(0));
+ }
+ if (node.name() == "ib") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^s", node.input(0));
+ EXPECT_EQ("^id_n", node.input(1));
+ }
+ }
+ EXPECT_EQ(4, found);
+}
+
TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);