Fix FreezeSavedModel to handle traversal of operations with multiple outputs.
authorSuharsh Sivakumar <suharshs@google.com>
Thu, 10 May 2018 00:30:30 +0000 (17:30 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 10 May 2018 00:33:41 +0000 (17:33 -0700)
PiperOrigin-RevId: 196055377

tensorflow/cc/tools/freeze_saved_model.cc
tensorflow/cc/tools/freeze_saved_model_test.cc

index 4ddddcb..2a859d6 100644 (file)
@@ -71,6 +71,12 @@ void GetNodeNameToNodeDefMap(
   }
 }
 
+// Strips off the tensor part of the tensor_name to get the node_name.
+const string GetNodeNameFromTensorName(const string& tensor_name) {
+  std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
+  return tensor_name_parts[0];
+}
+
 // Gets the set of node names needed by `outputs` and the corresponding set of
 // variable nodes to convert.
 void GetReachableNodesAndVariables(
@@ -83,10 +89,8 @@ void GetReachableNodesAndVariables(
       new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
 
   std::queue<string> nodes_to_visit;
-  for (const string& tensor_name : outputs) {
-    // We need to strip off the tensor part to get the node name.
-    std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
-    nodes_to_visit.push(tensor_name_parts[0]);
+  for (const string& output_tensor_name : outputs) {
+    nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
   }
   // We do a traversal backwards from the outputs specified in the MetaGraphDef.
   while (!nodes_to_visit.empty()) {
@@ -100,8 +104,8 @@ void GetReachableNodesAndVariables(
     if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
       variable_node_names->insert(node->name());
     }
-    for (const string& input : node->input()) {
-      nodes_to_visit.push(input);
+    for (const string& input_tensor_name : node->input()) {
+      nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
     }
   }
 }
index cd35fd3..e265a68 100644 (file)
@@ -351,6 +351,31 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) {
   GraphDefEqual(frozen_graph_def, graph_def);
 }
 
+TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) {
+  // Tensors from operations with multiple outputs get tensor suffixes when used
+  // in input fields of following nodes, i.e. split:0, split:1.
+  // Test that we traverse those correctly.
+  SavedModelBundle saved_model_bundle;
+  GraphDef graph_def;
+  Scope scope = Scope::NewRootScope();
+  Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2});
+  Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+  OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output;
+  Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+  Output c = ops::Mul(scope.WithOpName("c"), split[1], b);
+  TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+                                                        &saved_model_bundle));
+
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+
+  GraphDefEqual(frozen_graph_def, graph_def);
+}
+
 TEST_F(FreezeTest, GraphDefWithoutDependentVariables) {
   TestFreezeGraphWithoutDependentVariables(false);
 }