Passing a name with a trailing '/' to import_graph_def causes that
name to be used as-is (i.e. it is not appended to the existing name
scope and not de-duped with any existing name scopes. This is in order
to re-use an existing name scope). This didn't work with the C API
enabled because it was set to always have the C API uniquify the
prefix.
The fix is to not uniquify the prefix, since calling name_scope in
import_graph_def already has the logic to uniquify the prefix if
necessary. I'm not sure why I thought we needed the C API to do this
to being with.
In addition, this changes the graph_constructor.cc logic to uniquify
names if the prefix cannot be guaranteed unique (see the new test case
in graph_constructor_test.cc for why/when this is necessary).
PiperOrigin-RevId:
185215326
return errors::InvalidArgument("Imported node name prefix '", prefix_,
"' would lead to invalid node names");
}
- if (NameExistsInGraph(prefix_no_slash)) {
- if (opts_.uniquify_prefix) {
- prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
- } else {
- return errors::InvalidArgument("Import node name prefix '",
- prefix_no_slash,
- "' conflicts with "
- "name already used in the graph");
- }
+ if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
+ prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
}
}
return Status::OK();
if (opts_.importing) {
if (!prefix_.empty()) {
AddPrefixToNodeDef(input_already_exists, &imported_node_def);
- } else if (opts_.uniquify_names) {
+ }
+ // Note: no need to uniquify names if the prefix already guarantees
+ // uniqueness
+ if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
UniquifyNames(input_already_exists, &imported_node_def);
}
TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
EXPECT_EQ(results.return_nodes[1]->name(), "B_2");
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0");
- // Import with an already-used prefix
+ // Import with an already-used prefix and uniquify_prefix = true
opts.prefix = "A";
opts.uniquify_prefix = true;
results = ImportGraphDefResults();
EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A");
// Create B_3 node to keep the A/B numbering in sync
- opts = ImportGraphDefOptions();
ExpectOK("node { name: 'B_3' op: 'TestInput' }");
+ // Import with an already-used prefix and uniquify_prefix = false
+ opts.uniquify_prefix = false;
+ results = ImportGraphDefResults();
+ ExpectOK(graph_def_str, opts, &refiner, &results);
+
+ ASSERT_EQ(results.return_nodes.size(), 2);
+ EXPECT_EQ(results.return_nodes[0]->name(), "A/A");
+ EXPECT_EQ(results.return_nodes[1]->name(), "A/B");
+ EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A");
+
+ // Repeat the same import
+ results = ImportGraphDefResults();
+ ExpectOK(graph_def_str, opts, &refiner, &results);
+
+ ASSERT_EQ(results.return_nodes.size(), 2);
+ EXPECT_EQ(results.return_nodes[0]->name(), "A/A_1");
+ EXPECT_EQ(results.return_nodes[1]->name(), "A/B_1");
+ EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A_1:0");
+
// Import with existing de-duped node names
opts = ImportGraphDefOptions();
opts.uniquify_names = true;
"""Populates the TF_ImportGraphDefOptions `options`."""
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
- c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)
for input_src, input_dst in input_map.items():
input_src = compat.as_str(input_src)
self.assertEqual(b3.name, "A_3/B")
self.assertEqual(list(b3.inputs), [a3.outputs[0]])
+ # Import with an already-used name but with a '/' to indicate an
+ # "absolute" name scope (see the Graph.name_scope docstring).
+ a_a, a_b = importer.import_graph_def(
+ graph_def,
+ return_elements=["A", "B"],
+ name="A/")
+ self.assertEqual(a_a.name, "A/A")
+ self.assertEqual(a_b.name, "A/B")
+ self.assertEqual(list(a_b.inputs), [a_a.outputs[0]])
+
+ # Repeat the same import.
+ a_a1, a_b1 = importer.import_graph_def(
+ graph_def,
+ return_elements=["A", "B"],
+ name="A/")
+ self.assertEqual(a_a1.name, "A/A_1")
+ self.assertEqual(a_b1.name, "A/B_1")
+ self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]])
+
# Import with existing de-duped node names
a1_1, b1_1 = importer.import_graph_def(
self._MakeGraphDef("""