From: A. Unique TensorFlower Date: Fri, 2 Feb 2018 03:50:34 +0000 (-0800) Subject: Supporting new saving op structure. X-Git-Tag: upstream/v1.7.0~31^2~1072 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c39d141948174b94213848d9b95541ee09af5e53;p=platform%2Fupstream%2Ftensorflow.git Supporting new saving op structure. PiperOrigin-RevId: 184233513 --- diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc index 9c583d8..214ec72 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -86,8 +86,17 @@ void CreateConstNode(const Tensor& tensor, const string& name, SetNodeTensorAttr<float>("value", tensor, node_def); } +string GetMonolithicTensorKey(const string& tensor_slice_name) { + std::vector<string> names = Split(tensor_slice_name, "/"); + if (StringPiece(names[names.size() - 1]).starts_with("part_")) { + CHECK_GE(names.size(), 2); + names.pop_back(); + } + return Join(names, "/"); +} + Status ObtainTensorSlice(const GraphDef& input_graph_def, - const string& tensor_name, + const string& target_name, string* shape_slice_string) { string restore_node_name; for (const auto& node : input_graph_def.node()) { @@ -95,39 +104,53 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def, if (node_name_parts.size() == 2 && StringPiece(node_name_parts[0]).starts_with("save") && StringPiece(node_name_parts[1]).starts_with("Assign") && - node.input(0) == tensor_name) { + node.input(0) == target_name) { restore_node_name = node.input(1); break; } } + + std::vector<string> restore_node_parts = Split(restore_node_name, ":"); + CHECK_LE(restore_node_parts.size(), 2); + string tensor_names_node; string shape_and_slices_node; for (const auto& node : input_graph_def.node()) { - if ((node.name() == restore_node_name) && (node.op() == "RestoreV2")) { + if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) { + tensor_names_node = 
node.input(1); shape_and_slices_node = node.input(2); break; } } + + int offset = -1; + for (const auto& node : input_graph_def.node()) { + if (node.name() == tensor_names_node) { + Tensor tensor_names_tensor; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor)); + const auto& tensor_names_value = tensor_names_tensor.flat<string>(); + for (int i = 0; i < tensor_names_value.size(); i++) { + if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) { + offset = i; + break; + } + } + } + } + if (offset == -1) { + return errors::Internal("Unable to find RestoreV2 entry for variable: ", + target_name); + } for (const auto& node : input_graph_def.node()) { if (node.name() == shape_and_slices_node) { Tensor shape_and_slices_tensor; TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor)); const auto& shape_and_slices_value = shape_and_slices_tensor.flat<string>(); - *shape_slice_string = shape_and_slices_value(0); + *shape_slice_string = shape_and_slices_value(offset); return Status::OK(); } } - return errors::Internal("Unable to find slice for variable: ", tensor_name); -} - -string GetMonolithicTensorKey(const string& tensor_slice_name) { - std::vector<string> names = Split(tensor_slice_name, "/"); - CHECK_GE(names.size(), 2); - CHECK(StringPiece(names[names.size() - 1]).starts_with("part_")); - - // Remove the "part_x" suffix - names.pop_back(); - return Join(names, "/"); + return errors::Internal("Unable to find slice for variable: ", target_name); } Status ReadTensorFromCheckpoint( diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index 203ed3e..d41321c 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -106,11 +106,15 @@ class SparsifyGatherTest : public ::testing::Test { NodeDef* save_const_node = CreateNode("save/Const", "Const", {}, &graph_def); + Tensor 
tensor_names_values(DT_STRING, TensorShape({1})); + test::FillValues<string>(&tensor_names_values, {"w"}); NodeDef* tensor_names_node = CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); + SetNodeTensorAttr<string>("value", tensor_names_values, + tensor_names_node); + NodeDef* tensor_shapes_slices_node = CreateNode( "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); - Tensor shapes_slices_val(DT_STRING, TensorShape({1})); shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1"; SetNodeTensorAttr<string>("value", shapes_slices_val, @@ -310,6 +314,29 @@ class SparsifyGatherTest : public ::testing::Test { SetNodeTensorAttr<float>("value", weights, w_node1); SetNodeTensorAttr<float>("value", weights, w_node2); } else { + NodeDef* save_const_node = + CreateNode("save/Const", "Const", {}, &graph_def); + + NodeDef* tensor_names_node = + CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); + Tensor tensor_names_values(DT_STRING, TensorShape({2})); + test::FillValues<string>(&tensor_names_values, {"w1", "w2"}); + SetNodeTensorAttr<string>("value", tensor_names_values, + tensor_names_node); + + NodeDef* tensor_shapes_slices_node = CreateNode( + "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); + Tensor shapes_slices_val(DT_STRING, TensorShape({2})); + shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1"; + shapes_slices_val.flat<string>()(1) = "4 1 0,4:0,1"; + SetNodeTensorAttr<string>("value", shapes_slices_val, + tensor_shapes_slices_node); + + NodeDef* restore_node = CreateNode( + "save/RestoreV2", "RestoreV2", + {save_const_node, tensor_names_node, tensor_shapes_slices_node}, + &graph_def); + w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def); zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor", @@ -321,23 +348,7 @@ class SparsifyGatherTest : public ::testing::Test { assign_node1 = CreateNode("w1/part_1/Assign", "Assign", {w_node1, zeros_node1}, &graph_def); - NodeDef* save_const_node = - CreateNode("save/Const", "Const", {}, &graph_def); - NodeDef* tensor_names_node1 = - 
CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def); - NodeDef* tensor_shapes_slices_node1 = CreateNode( - "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def); - - Tensor shapes_slices_val1(DT_STRING, TensorShape({1})); - shapes_slices_val1.flat<string>()(0) = "4 1 0,4:0,1"; - SetNodeTensorAttr<string>("value", shapes_slices_val1, - tensor_shapes_slices_node1); - - NodeDef* restore_node1 = CreateNode( - "save/RestoreV2", "RestoreV2", - {save_const_node, tensor_names_node1, tensor_shapes_slices_node1}, - &graph_def); - CreateNode("save/Assign", "Assign", {w_node1, restore_node1}, &graph_def); + CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def); w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def); zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor", @@ -349,21 +360,7 @@ class SparsifyGatherTest : public ::testing::Test { assign_node2 = CreateNode("w2/part_1/Assign", "Assign", {w_node2, zeros_node2}, &graph_def); - NodeDef* tensor_names_node2 = - CreateNode("save/RestoreV2_1/tensor_names", "Const", {}, &graph_def); - NodeDef* tensor_shapes_slices_node2 = CreateNode( - "save/RestoreV2_1/shape_and_slices", "Const", {}, &graph_def); - - Tensor shapes_slices_val2(DT_STRING, TensorShape({1})); - shapes_slices_val2.flat<string>()(0) = "4 1 0,4:0,1"; - SetNodeTensorAttr<string>("value", shapes_slices_val2, - tensor_shapes_slices_node2); - - NodeDef* restore_node2 = CreateNode( - "save/RestoreV2_1", "RestoreV2", - {save_const_node, tensor_names_node2, tensor_shapes_slices_node2}, - &graph_def); - CreateNode("save/Assign_1", "Assign", {w_node2, restore_node2}, + CreateNode("save/Assign_1", "Assign", {w_node2, restore_node}, &graph_def); BundleWriter writer(Env::Default(), checkpoint_path);