Preserving order when removing nodes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Feb 2018 19:20:39 +0000 (11:20 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Feb 2018 19:28:11 +0000 (11:28 -0800)
PiperOrigin-RevId: 185023366

tensorflow/tools/graph_transforms/sparsify_gather.cc

index 214ec721e2c9b8bdd761a1cb7a92a74f4a2a42a0..701e350fc39d083665f5420e6b73510c182e12ce 100644 (file)
@@ -212,6 +212,14 @@ Status RemoveInputAtIndex(NodeDef* n, int index) {
   return Status::OK();
 }
 
+Status RemoveNodeAtIndex(GraphDef* g, int index) {
+  for (int i = index; i < g->node_size() - 1; i++) {
+    g->mutable_node()->SwapElements(i, i + 1);
+  }
+  g->mutable_node()->RemoveLast();
+  return Status::OK();
+}
+
 Status SparsifyGatherInternal(
     const GraphDef& input_graph_def,
     const std::unique_ptr<std::unordered_map<string, string> >&
@@ -493,9 +501,7 @@ Status SparsifyGatherInternal(
               removed_node_names.push_back(parsed_input);
             }
           }
-          replaced_graph_def.mutable_node()->SwapElements(
-              i, replaced_graph_def.node_size() - 1);
-          replaced_graph_def.mutable_node()->RemoveLast();
+          TF_RETURN_IF_ERROR(RemoveNodeAtIndex(&replaced_graph_def, i));
           continue;
         }
         int j = 0;