SetNodeTensorAttr<float>("value", tensor, node_def);
}
+string GetMonolithicTensorKey(const string& tensor_slice_name) {
+ std::vector<string> names = Split(tensor_slice_name, "/");
+ if (StringPiece(names[names.size() - 1]).starts_with("part_")) {
+ CHECK_GE(names.size(), 2);
+ names.pop_back();
+ }
+ return Join(names, "/");
+}
+
Status ObtainTensorSlice(const GraphDef& input_graph_def,
- const string& tensor_name,
+ const string& target_name,
string* shape_slice_string) {
string restore_node_name;
for (const auto& node : input_graph_def.node()) {
if (node_name_parts.size() == 2 &&
StringPiece(node_name_parts[0]).starts_with("save") &&
StringPiece(node_name_parts[1]).starts_with("Assign") &&
- node.input(0) == tensor_name) {
+ node.input(0) == target_name) {
restore_node_name = node.input(1);
break;
}
}
+
+ std::vector<string> restore_node_parts = Split(restore_node_name, ":");
+ CHECK_LE(restore_node_parts.size(), 2);
+ string tensor_names_node;
string shape_and_slices_node;
for (const auto& node : input_graph_def.node()) {
- if ((node.name() == restore_node_name) && (node.op() == "RestoreV2")) {
+ if ((node.name() == restore_node_parts[0]) && (node.op() == "RestoreV2")) {
+ tensor_names_node = node.input(1);
shape_and_slices_node = node.input(2);
break;
}
}
+
+ int offset = -1;
+ for (const auto& node : input_graph_def.node()) {
+ if (node.name() == tensor_names_node) {
+ Tensor tensor_names_tensor;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor_names_tensor));
+ const auto& tensor_names_value = tensor_names_tensor.flat<string>();
+ for (int i = 0; i < tensor_names_value.size(); i++) {
+ if (tensor_names_value(i) == GetMonolithicTensorKey(target_name)) {
+ offset = i;
+ break;
+ }
+ }
+ }
+ }
+ if (offset == -1) {
+ return errors::Internal("Unable to find RestoreV2 entry for variable: ",
+ target_name);
+ }
for (const auto& node : input_graph_def.node()) {
if (node.name() == shape_and_slices_node) {
Tensor shape_and_slices_tensor;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &shape_and_slices_tensor));
const auto& shape_and_slices_value =
shape_and_slices_tensor.flat<string>();
- *shape_slice_string = shape_and_slices_value(0);
+ *shape_slice_string = shape_and_slices_value(offset);
return Status::OK();
}
}
- return errors::Internal("Unable to find slice for variable: ", tensor_name);
-}
-
-string GetMonolithicTensorKey(const string& tensor_slice_name) {
- std::vector<string> names = Split(tensor_slice_name, "/");
- CHECK_GE(names.size(), 2);
- CHECK(StringPiece(names[names.size() - 1]).starts_with("part_"));
-
- // Remove the "part_x" suffix
- names.pop_back();
- return Join(names, "/");
+ return errors::Internal("Unable to find slice for variable: ", target_name);
}
Status ReadTensorFromCheckpoint(
NodeDef* save_const_node =
CreateNode("save/Const", "Const", {}, &graph_def);
+ Tensor tensor_names_values(DT_STRING, TensorShape({1}));
+ test::FillValues<string>(&tensor_names_values, {"w"});
NodeDef* tensor_names_node =
CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
+ SetNodeTensorAttr<string>("value", tensor_names_values,
+ tensor_names_node);
+
NodeDef* tensor_shapes_slices_node = CreateNode(
"save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
-
Tensor shapes_slices_val(DT_STRING, TensorShape({1}));
shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
SetNodeTensorAttr<string>("value", shapes_slices_val,
SetNodeTensorAttr<float>("value", weights, w_node1);
SetNodeTensorAttr<float>("value", weights, w_node2);
} else {
+ NodeDef* save_const_node =
+ CreateNode("save/Const", "Const", {}, &graph_def);
+
+ NodeDef* tensor_names_node =
+ CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
+ Tensor tensor_names_values(DT_STRING, TensorShape({2}));
+ test::FillValues<string>(&tensor_names_values, {"w1", "w2"});
+ SetNodeTensorAttr<string>("value", tensor_names_values,
+ tensor_names_node);
+
+ NodeDef* tensor_shapes_slices_node = CreateNode(
+ "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
+ Tensor shapes_slices_val(DT_STRING, TensorShape({2}));
+ shapes_slices_val.flat<string>()(0) = "4 1 0,4:0,1";
+ shapes_slices_val.flat<string>()(1) = "4 1 0,4:0,1";
+ SetNodeTensorAttr<string>("value", shapes_slices_val,
+ tensor_shapes_slices_node);
+
+ NodeDef* restore_node = CreateNode(
+ "save/RestoreV2", "RestoreV2",
+ {save_const_node, tensor_names_node, tensor_shapes_slices_node},
+ &graph_def);
+
w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def);
zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor",
assign_node1 = CreateNode("w1/part_1/Assign", "Assign",
{w_node1, zeros_node1}, &graph_def);
- NodeDef* save_const_node =
- CreateNode("save/Const", "Const", {}, &graph_def);
- NodeDef* tensor_names_node1 =
- CreateNode("save/RestoreV2/tensor_names", "Const", {}, &graph_def);
- NodeDef* tensor_shapes_slices_node1 = CreateNode(
- "save/RestoreV2/shape_and_slices", "Const", {}, &graph_def);
-
- Tensor shapes_slices_val1(DT_STRING, TensorShape({1}));
- shapes_slices_val1.flat<string>()(0) = "4 1 0,4:0,1";
- SetNodeTensorAttr<string>("value", shapes_slices_val1,
- tensor_shapes_slices_node1);
-
- NodeDef* restore_node1 = CreateNode(
- "save/RestoreV2", "RestoreV2",
- {save_const_node, tensor_names_node1, tensor_shapes_slices_node1},
- &graph_def);
- CreateNode("save/Assign", "Assign", {w_node1, restore_node1}, &graph_def);
+ CreateNode("save/Assign", "Assign", {w_node1, restore_node}, &graph_def);
w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def);
zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor",
assign_node2 = CreateNode("w2/part_1/Assign", "Assign",
{w_node2, zeros_node2}, &graph_def);
- NodeDef* tensor_names_node2 =
- CreateNode("save/RestoreV2_1/tensor_names", "Const", {}, &graph_def);
- NodeDef* tensor_shapes_slices_node2 = CreateNode(
- "save/RestoreV2_1/shape_and_slices", "Const", {}, &graph_def);
-
- Tensor shapes_slices_val2(DT_STRING, TensorShape({1}));
- shapes_slices_val2.flat<string>()(0) = "4 1 0,4:0,1";
- SetNodeTensorAttr<string>("value", shapes_slices_val2,
- tensor_shapes_slices_node2);
-
- NodeDef* restore_node2 = CreateNode(
- "save/RestoreV2_1", "RestoreV2",
- {save_const_node, tensor_names_node2, tensor_shapes_slices_node2},
- &graph_def);
- CreateNode("save/Assign_1", "Assign", {w_node2, restore_node2},
+ CreateNode("save/Assign_1", "Assign", {w_node2, restore_node},
&graph_def);
BundleWriter writer(Env::Default(), checkpoint_path);