Supporting new saving op structure.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Feb 2018 03:50:34 +0000 (19:50 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 05:44:13 +0000 (21:44 -0800)
PiperOrigin-RevId: 184233513

tensorflow/tools/graph_transforms/sparsify_gather.cc
tensorflow/tools/graph_transforms/sparsify_gather_test.cc

index 9c583d8..214ec72 100644 (file)
@@ -86,8 +86,17 @@ void CreateConstNode(const Tensor& tensor, const string& name,
   SetNodeTensorAttr<float>("value", tensor, node_def);
 }
 
+string GetMonolithicTensorKey(const string& tensor_slice_name) {
+  std::vector<string> names = Split(tensor_slice_name, "/");
+  if (StringPiece(names[names.size() - 1]).starts_with("part_")) {
+    CHECK_GE(names.size(), 2);
+    names.pop_back();
+  }
+  return Join(names, "/");
+}
+
 Status ObtainTensorSlice(const GraphDef& input_graph_def,
-                         const string& tensor_name,
+                         const string& target_name,
                          string* shape_slice_string) {
   string restore_node_name;
   for (const auto& node : input_graph_def.node()) {
@@ -95,39 +104,53 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def,
     if (node_name_parts.size() == 2 &&
         StringPiece(node_name_parts[0]).starts_with("save") &&
         StringPiece(node_name_parts[1]).starts_with("Assign") &&
-        node.input(0) == tensor_name) {
+        node.input(0) == target_name) {
       restore_node_name = node.input(1);
       break;
     }
   }
+
+  std::vector<string> restore_node_parts = Split(restore_node_name, ":");
+  CHECK_LE(restore_node_parts.size(), 2);
+  string tensor_names_node;
   string shape_and_slices_node;
   for (const auto& node : input_graph_def.node()) {
-    if ((node.name() == restore_node_name) && (node.op() == "RestoreV2")) {
+    if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
+      tensor_names_node = node.input(1);
       shape_and_slices_node = node.input(2);
       break;
     }
   }
+
+  int offset = -1;
+  for (const auto& node : input_graph_def.node()) {
+    if (node.name() == tensor_names_node) {
+      Tensor tensor_names_tensor;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
+      const auto& tensor_names_value = tensor_names_tensor.flat<string>();
+      for (int i = 0; i < tensor_names_value.size(); i++) {
+        if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
+          offset = i;
+          break;
+        }
+      }
+    }
+  }
+  if (offset == -1) {
+    return errors::Internal("Unable to find RestoreV2 entry for variable: ",
+                            target_name);
+  }
   for (const auto& node : input_graph_def.node()) {
     if (node.name() == shape_and_slices_node) {
       Tensor shape_and_slices_tensor;
       TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
       const auto& shape_and_slices_value =
           shape_and_slices_tensor.flat<string>();
-      *shape_slice_string = shape_and_slices_value(0);
+      *shape_slice_string = shape_and_slices_value(offset);
       return Status::OK();
     }
   }
-  return errors::Internal("Unable to find slice for variable: ", tensor_name);
-}
-
-string GetMonolithicTensorKey(const string& tensor_slice_name) {
-  std::vector<string> names = Split(tensor_slice_name, "/");
-  CHECK_GE(names.size(), 2);
-  CHECK(StringPiece(names[names.size() - 1]).starts_with("part_"));
-
-  // Remove the "part_x" suffix
-  names.pop_back();
-  return Join(names, "/");
+  return errors::Internal("Unable to find slice for variable: ", target_name);
 }
 
 Status ReadTensorFromCheckpoint(
index 203ed3e..d41321c 100644 (file)
@@ -106,11 +106,15 @@ class SparsifyGatherTest : public ::testing::Test {
       NodeDef* save_const_node =
           CreateNode("save/Const", "Const", {}, &graph_def);
 
+      Tensor tensor_names_values(DT_STRING, TensorShape({1}));
+      test::FillValues<string>(&tensor_names_values, {"w"});
       NodeDef* tensor_names_node =
           CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
+      SetNodeTensorAttr<string>("value", tensor_names_values,
+                                tensor_names_node);
+
       NodeDef* tensor_shapes_slices_node = CreateNode(
           "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
-
       Tensor shapes_slices_val(DT_STRING, TensorShape({1}));
       shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
       SetNodeTensorAttr<string>("value", shapes_slices_val,
@@ -310,6 +314,29 @@ class SparsifyGatherTest : public ::testing::Test {
       SetNodeTensorAttr<float>("value", weights, w_node1);
       SetNodeTensorAttr<float>("value", weights, w_node2);
     } else {
+      NodeDef* save_const_node =
+          CreateNode("save/Const", "Const", {}, &graph_def);
+
+      NodeDef* tensor_names_node =
+          CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
+      Tensor tensor_names_values(DT_STRING, TensorShape({2}));
+      test::FillValues<string>(&tensor_names_values, {"w1", "w2"});
+      SetNodeTensorAttr<string>("value", tensor_names_values,
+                                tensor_names_node);
+
+      NodeDef* tensor_shapes_slices_node = CreateNode(
+          "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
+      Tensor shapes_slices_val(DT_STRING, TensorShape({2}));
+      shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
+      shapes_slices_val.flat<string>()(1) = "4 1 0,4:0,1";
+      SetNodeTensorAttr<string>("value", shapes_slices_val,
+                                tensor_shapes_slices_node);
+
+      NodeDef* restore_node = CreateNode(
+          "save/RestoreV2", "RestoreV2",
+          {save_const_node, tensor_names_node, tensor_shapes_slices_node},
+          &graph_def);
+
       w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def);
 
       zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor",
@@ -321,23 +348,7 @@ class SparsifyGatherTest : public ::testing::Test {
       assign_node1 = CreateNode("w1/part_1/Assign", "Assign",
                                 {w_node1, zeros_node1}, &graph_def);
 
-      NodeDef* save_const_node =
-          CreateNode("save/Const", "Const", {}, &graph_def);
-      NodeDef* tensor_names_node1 =
-          CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
-      NodeDef* tensor_shapes_slices_node1 = CreateNode(
-          "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
-
-      Tensor shapes_slices_val1(DT_STRING, TensorShape({1}));
-      shapes_slices_val1.flat<string>()(0) = "4 1 0,4:0,1";
-      SetNodeTensorAttr<string>("value", shapes_slices_val1,
-                                tensor_shapes_slices_node1);
-
-      NodeDef* restore_node1 = CreateNode(
-          "save/RestoreV2", "RestoreV2",
-          {save_const_node, tensor_names_node1, tensor_shapes_slices_node1},
-          &graph_def);
-      CreateNode("save/Assign", "Assign", {w_node1, restore_node1}, &graph_def);
+      CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def);
 
       w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def);
       zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor",
@@ -349,21 +360,7 @@ class SparsifyGatherTest : public ::testing::Test {
       assign_node2 = CreateNode("w2/part_1/Assign", "Assign",
                                 {w_node2, zeros_node2}, &graph_def);
 
-      NodeDef* tensor_names_node2 =
-          CreateNode("save/RestoreV2_1/tensor_names", "Const", {}, &graph_def);
-      NodeDef* tensor_shapes_slices_node2 = CreateNode(
-          "save/RestoreV2_1/shape_and_slices", "Const", {}, &graph_def);
-
-      Tensor shapes_slices_val2(DT_STRING, TensorShape({1}));
-      shapes_slices_val2.flat<string>()(0) = "4 1 0,4:0,1";
-      SetNodeTensorAttr<string>("value", shapes_slices_val2,
-                                tensor_shapes_slices_node2);
-
-      NodeDef* restore_node2 = CreateNode(
-          "save/RestoreV2_1", "RestoreV2",
-          {save_const_node, tensor_names_node2, tensor_shapes_slices_node2},
-          &graph_def);
-      CreateNode("save/Assign_1", "Assign", {w_node2, restore_node2},
+      CreateNode("save/Assign_1", "Assign", {w_node2, restore_node},
                  &graph_def);
 
       BundleWriter writer(Env::Default(), checkpoint_path);