Fix bug in updating NodeMap when materializing shapes from ShapeN.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Mar 2018 22:57:22 +0000 (14:57 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Mar 2018 23:04:29 +0000 (15:04 -0800)
Fix a similar bug in MaybeRemoveControlInput.
Improve error message in dependency optimizer, so we can tell if the problem is in dependency optimizer itself or upstream of it.

PiperOrigin-RevId: 188394863

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding_test.cc
tensorflow/core/grappler/optimizers/dependency_optimizer.cc

index 7780414..31dc1b7 100644 (file)
@@ -140,20 +140,20 @@ bool AllValuesAre(const TensorProto& tensor, const T& value) {
 // Add new_input as a control input to node if it does not already depend on it.
 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
 // clean up code that should be using them.
-bool MaybeAddControlInput(const string& new_input, NodeDef* node,
+bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
                           GraphDef* graph, NodeMap* node_map) {
   bool already_exists = false;
   for (const string& input : node->input()) {
-    if (input == new_input || AsControlDependency(input) == new_input) {
+    if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
       already_exists = true;
       break;
     }
   }
   if (!already_exists) {
     const string ctrl_dep =
-        ConstantFolding::AddControlDependency(new_input, graph, node_map);
+        ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
     node->add_input(ctrl_dep);
-    node_map->AddOutput(NodeName(new_input), node->name());
+    node_map->AddOutput(NodeName(ctrl_input), node->name());
   }
   return !already_exists;
 }
@@ -161,16 +161,27 @@ bool MaybeAddControlInput(const string& new_input, NodeDef* node,
 // Remove old_input as a control input to node.
 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
                              GraphDef* graph, NodeMap* node_map) {
+  bool removed_input = false;
+  bool update_node_map = true;
+  const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
   for (int i = 0; i < node->input_size(); ++i) {
     const string& input = node->input(i);
-    if (IsControlInput(input) && AsControlDependency(old_input) == input) {
-      node->mutable_input()->SwapElements(i, node->input_size() - 1);
-      node->mutable_input()->RemoveLast();
-      node_map->RemoveOutput(NodeName(old_input), node->name());
-      return true;
+    if (old_input_ctrl_dep == input) {
+      if (IsControlInput(input)) {
+        node->mutable_input()->SwapElements(i, node->input_size() - 1);
+        node->mutable_input()->RemoveLast();
+        removed_input = true;
+      } else {
+        // There is a non-control input from the same node.
+        // Don't remove the output from the NodeMap.
+        update_node_map = false;
+      }
     }
   }
-  return false;
+  if (update_node_map) {
+    node_map->RemoveOutput(NodeName(old_input), node->name());
+  }
+  return removed_input;
 }
 
 }  // namespace
@@ -353,7 +364,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
           node_map_->AddOutput(NodeName(ctrl_dep), node->name());
         } else {
           auto outputs = node_map_->GetOutputs(node->name());
-          for (const auto& output : outputs) {
+          for (NodeDef* output : outputs) {
             for (int k = 0; k < output->input_size(); ++k) {
               int port;
               string node_name = ParseNodeName(output->input(k), &port);
@@ -378,11 +389,22 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
                   *added_node->add_input() = ctrl_dep;
                   node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
                 }
-                node_map_->UpdateInput(output->name(),
-                                       NodeName(output->input(k)), const_name);
                 *output->mutable_input(k) = const_name;
+                node_map_->AddOutput(const_name, output->name());
               }
             }
+            bool remove_output = true;
+            for (int k = 0; k < output->input_size(); ++k) {
+              int port;
+              string node_name = ParseNodeName(output->input(k), &port);
+              if (node_name == node->name()) {
+                remove_output = false;
+                break;
+              }
+            }
+            if (remove_output) {
+              node_map_->RemoveOutput(node->name(), output->name());
+            }
           }
         }
       }
@@ -1051,7 +1073,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
       node_map_->AddOutput(node->name(), const_index->name());
 
       auto outputs = node_map_->GetOutputs(node->name());
-      for (auto& output : outputs) {
+      for (NodeDef* output : outputs) {
         for (int i = 0; i < output->input_size(); i++) {
           int port;
           string node_name = ParseNodeName(output->input(i), &port);
@@ -1142,7 +1164,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
 
   if (const_nodes.size() > 1) {
     auto outputs = node_map_->GetOutputs(node->name());
-    for (const auto& output : outputs) {
+    for (NodeDef* output : outputs) {
       for (int i = 0; i < output->input_size(); i++) {
         int port;
         string node_name = ParseNodeName(output->input(i), &port);
index 29dc93c..4b97708 100644 (file)
@@ -947,6 +947,56 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
   EXPECT_EQ(9, found);
 }
 
+TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
+  Output v2 = ops::Variable(scope.WithOpName("v2"), {4, 6}, DT_FLOAT);
+  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2});
+  auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {s[0], s[1]});
+  Output ia = ops::Identity(scope.WithOpName("ia"), id_n[0]);
+  Output ib = ops::Identity(scope.WithOpName("ib"), id_n[1]);
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+  item.fetch.push_back("ia");
+  item.fetch.push_back("ib");
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  int found = 0;
+  for (const auto& node : output.node()) {
+    EXPECT_NE(AddPrefixToNodeName("s-matshapes-0", kConstantFoldingConst),
+              node.name());
+    if (node.name() == "s") {
+      ++found;
+      EXPECT_EQ("ShapeN", node.op());
+      EXPECT_EQ("v1", node.input(0));
+      EXPECT_EQ("v2", node.input(1));
+    }
+    if (node.name() == "id_n") {
+      ++found;
+      EXPECT_EQ("IdentityN", node.op());
+      EXPECT_EQ("s", node.input(0));
+      EXPECT_EQ(AddPrefixToNodeName("s-matshapes-1", kConstantFoldingConst),
+                node.input(1));
+    }
+    if (node.name() == "ia") {
+      ++found;
+      EXPECT_EQ("id_n", node.input(0));
+    }
+    if (node.name() == "ib") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ("^s", node.input(0));
+      EXPECT_EQ("^id_n", node.input(1));
+    }
+  }
+  EXPECT_EQ(4, found);
+}
+
 TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
index bb4b916..a5b2572 100644 (file)
@@ -576,7 +576,9 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
       // Remove redundant control dependencies.
       TF_RETURN_IF_ERROR(TransitiveReduction());
     } else {
-      LOG(ERROR) << topo_sort_status.error_message();
+      LOG(ERROR) << "Iteration = " << iteration
+                 << ", topological sort failed with message: "
+                 << topo_sort_status.error_message();
     }
     // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
     // nodes.