From 66788c60d65564775bcbcf4dc1734157228dbdba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 8 Mar 2018 14:57:22 -0800 Subject: [PATCH] Fix bug in updating NodeMap when materializing shapes from ShapeN. Fix a similar bug in MaybeRemoveControlInput. Improve error message in dependency optimizer, so we can tell if the problem is in dependency optimizer itself or upstream of it. PiperOrigin-RevId: 188394863 --- .../core/grappler/optimizers/constant_folding.cc | 52 +++++++++++++++------- .../grappler/optimizers/constant_folding_test.cc | 50 +++++++++++++++++++++ .../grappler/optimizers/dependency_optimizer.cc | 4 +- 3 files changed, 90 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 7780414..31dc1b7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -140,20 +140,20 @@ bool AllValuesAre(const TensorProto& tensor, const T& value) { // Add new_input as a control input to node if it does not already depend on it. // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and // clean up code that should be using them. -bool MaybeAddControlInput(const string& new_input, NodeDef* node, +bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node, GraphDef* graph, NodeMap* node_map) { bool already_exists = false; for (const string& input : node->input()) { - if (input == new_input || AsControlDependency(input) == new_input) { + if (input == ctrl_input || AsControlDependency(input) == ctrl_input) { already_exists = true; break; } } if (!already_exists) { const string ctrl_dep = - ConstantFolding::AddControlDependency(new_input, graph, node_map); + ConstantFolding::AddControlDependency(ctrl_input, graph, node_map); node->add_input(ctrl_dep); - node_map->AddOutput(NodeName(new_input), node->name()); + node_map->AddOutput(NodeName(ctrl_input), node->name()); } return !already_exists; } @@ -161,16 +161,27 @@ bool MaybeAddControlInput(const string& new_input, NodeDef* node, // Remove old_input as a control input to node. bool MaybeRemoveControlInput(const string& old_input, NodeDef* node, GraphDef* graph, NodeMap* node_map) { + bool removed_input = false; + bool update_node_map = true; + const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input)); for (int i = 0; i < node->input_size(); ++i) { const string& input = node->input(i); - if (IsControlInput(input) && AsControlDependency(old_input) == input) { - node->mutable_input()->SwapElements(i, node->input_size() - 1); - node->mutable_input()->RemoveLast(); - node_map->RemoveOutput(NodeName(old_input), node->name()); - return true; + if (old_input_ctrl_dep == input) { + if (IsControlInput(input)) { + node->mutable_input()->SwapElements(i, node->input_size() - 1); + node->mutable_input()->RemoveLast(); + removed_input = true; + } else { + // There is a non-control input from the same node. + // Don't remove the output from the NodeMap. + update_node_map = false; + } } } - return false; + if (update_node_map) { + node_map->RemoveOutput(NodeName(old_input), node->name()); + } + return removed_input; } } // namespace @@ -353,7 +364,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { node_map_->AddOutput(NodeName(ctrl_dep), node->name()); } else { auto outputs = node_map_->GetOutputs(node->name()); - for (const auto& output : outputs) { + for (NodeDef* output : outputs) { for (int k = 0; k < output->input_size(); ++k) { int port; string node_name = ParseNodeName(output->input(k), &port); @@ -378,11 +389,22 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { *added_node->add_input() = ctrl_dep; node_map_->AddOutput(NodeName(ctrl_dep), added_node->name()); } - node_map_->UpdateInput(output->name(), - NodeName(output->input(k)), const_name); *output->mutable_input(k) = const_name; + node_map_->AddOutput(const_name, output->name()); } } + bool remove_output = true; + for (int k = 0; k < output->input_size(); ++k) { + int port; + string node_name = ParseNodeName(output->input(k), &port); + if (node_name == node->name()) { + remove_output = false; + break; + } + } + if (remove_output) { + node_map_->RemoveOutput(node->name(), output->name()); + } } } } @@ -1051,7 +1073,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) { node_map_->AddOutput(node->name(), const_index->name()); auto outputs = node_map_->GetOutputs(node->name()); - for (auto& output : outputs) { + for (NodeDef* output : outputs) { for (int i = 0; i < output->input_size(); i++) { int port; string node_name = ParseNodeName(output->input(i), &port); @@ -1142,7 +1164,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) { if (const_nodes.size() > 1) { auto outputs = node_map_->GetOutputs(node->name()); - for (const auto& output : outputs) { + for (NodeDef* output : outputs) { for (int i = 0; i < output->input_size(); i++) { int port; string node_name = ParseNodeName(output->input(i), &port); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 29dc93c..4b97708 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -947,6 +947,56 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) { EXPECT_EQ(9, found); } +TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT); + Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT); + auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2}); + auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]}); + Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]); + Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch.push_back("ia"); + item.fetch.push_back("ib"); + + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + int found = 0; + for (const auto& node : output.node()) { + EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst), + node.name()); + if (node.name() == "s") { + ++found; + EXPECT_EQ("ShapeN", node.op()); + EXPECT_EQ("v1", node.input(0)); + EXPECT_EQ("v2", node.input(1)); + } + if (node.name() == "id_n") { + ++found; + EXPECT_EQ("IdentityN", node.op()); + EXPECT_EQ("s", node.input(0)); + EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst), + node.input(1)); + } + if (node.name() == "ia") { + ++found; + EXPECT_EQ("id_n", node.input(0)); + } + if (node.name() == "ib") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^s", node.input(0)); + EXPECT_EQ("^id_n", node.input(1)); + } + } + EXPECT_EQ(4, found); +} + TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index bb4b916..a5b2572 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -576,7 +576,9 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Remove redundant control dependencies. TF_RETURN_IF_ERROR(TransitiveReduction()); } else { - LOG(ERROR) << topo_sort_status.error_message(); + LOG(ERROR) << "Iteration = " << iteration + << ", topological sort failed with message: " + << topo_sort_status.error_message(); } // Turn nodes with only control outputs into NoOps, prune NoOp and Identity // nodes. -- 2.7.4