import_graph_def: support "absolute" names with the C API enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Sat, 10 Feb 2018 01:14:30 +0000 (17:14 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 10 Feb 2018 01:18:33 +0000 (17:18 -0800)
Passing a name with a trailing '/' to import_graph_def causes that
name to be used as-is (i.e. it is not appended to the existing name
scope and not de-duped with any existing name scopes. This is in order
to re-use an existing name scope). This didn't work with the C API
enabled because it was set to always have the C API uniquify the
prefix.

The fix is to not uniquify the prefix, since calling name_scope in
import_graph_def already has the logic to uniquify the prefix if
necessary. I'm not sure why I thought we needed the C API to do this
to being with.

In addition, this changes the graph_constructor.cc logic to uniquify
names if the prefix cannot be guaranteed unique (see the new test case
in graph_constructor_test.cc for why/when this is necessary).

PiperOrigin-RevId: 185215326

tensorflow/core/graph/graph_constructor.cc
tensorflow/core/graph/graph_constructor_test.cc
tensorflow/python/framework/importer.py
tensorflow/python/framework/importer_test.py

index 2a52c75..0629ff3 100644 (file)
@@ -374,15 +374,8 @@ Status GraphConstructor::EnsureNoNameCollisions() {
       return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                      "' would lead to invalid node names");
     }
-    if (NameExistsInGraph(prefix_no_slash)) {
-      if (opts_.uniquify_prefix) {
-        prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
-      } else {
-        return errors::InvalidArgument("Import node name prefix '",
-                                       prefix_no_slash,
-                                       "' conflicts with "
-                                       "name already used in the graph");
-      }
+    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
+      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
     }
   }
   return Status::OK();
@@ -990,7 +983,10 @@ Status GraphConstructor::Convert() {
     if (opts_.importing) {
       if (!prefix_.empty()) {
         AddPrefixToNodeDef(input_already_exists, &imported_node_def);
-      } else if (opts_.uniquify_names) {
+      }
+      // Note: no need to uniquify names if the prefix already guarantees
+      // uniqueness
+      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
         UniquifyNames(input_already_exists, &imported_node_def);
       }
       TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
index c59e478..963c1dc 100644 (file)
@@ -1834,7 +1834,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
   EXPECT_EQ(results.return_nodes[1]->name(), "B_2");
   EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_2:0");
 
-  // Import with an already-used prefix
+  // Import with an already-used prefix and uniquify_prefix = true
   opts.prefix = "A";
   opts.uniquify_prefix = true;
   results = ImportGraphDefResults();
@@ -1846,9 +1846,27 @@ TEST_F(GraphConstructorTest, ImportGraphDef_UniquifyNames) {
   EXPECT_EQ(results.return_nodes[1]->def().input(0), "A_3/A");
 
   // Create B_3 node to keep the A/B numbering in sync
-  opts = ImportGraphDefOptions();
   ExpectOK("node { name: 'B_3' op: 'TestInput' }");
 
+  // Import with an already-used prefix and uniquify_prefix = false
+  opts.uniquify_prefix = false;
+  results = ImportGraphDefResults();
+  ExpectOK(graph_def_str, opts, &refiner, &results);
+
+  ASSERT_EQ(results.return_nodes.size(), 2);
+  EXPECT_EQ(results.return_nodes[0]->name(), "A/A");
+  EXPECT_EQ(results.return_nodes[1]->name(), "A/B");
+  EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A");
+
+  // Repeat the same import
+  results = ImportGraphDefResults();
+  ExpectOK(graph_def_str, opts, &refiner, &results);
+
+  ASSERT_EQ(results.return_nodes.size(), 2);
+  EXPECT_EQ(results.return_nodes[0]->name(), "A/A_1");
+  EXPECT_EQ(results.return_nodes[1]->name(), "A/B_1");
+  EXPECT_EQ(results.return_nodes[1]->def().input(0), "A/A_1:0");
+
   // Import with existing de-duped node names
   opts = ImportGraphDefOptions();
   opts.uniquify_names = true;
index cc8f239..6ecc1a4 100644 (file)
@@ -270,7 +270,6 @@ def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
   """Populates the TF_ImportGraphDefOptions `options`."""
   c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
   c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
-  c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)
 
   for input_src, input_dst in input_map.items():
     input_src = compat.as_str(input_src)
index acaec37..bf5d9fe 100644 (file)
@@ -154,6 +154,25 @@ class ImportGraphDefTest(test.TestCase):
       self.assertEqual(b3.name, "A_3/B")
       self.assertEqual(list(b3.inputs), [a3.outputs[0]])
 
+      # Import with an already-used name but with a '/' to indicate an
+      # "absolute" name scope (see the Graph.name_scope docstring).
+      a_a, a_b = importer.import_graph_def(
+          graph_def,
+          return_elements=["A", "B"],
+          name="A/")
+      self.assertEqual(a_a.name, "A/A")
+      self.assertEqual(a_b.name, "A/B")
+      self.assertEqual(list(a_b.inputs), [a_a.outputs[0]])
+
+      # Repeat the same import.
+      a_a1, a_b1 = importer.import_graph_def(
+          graph_def,
+          return_elements=["A", "B"],
+          name="A/")
+      self.assertEqual(a_a1.name, "A/A_1")
+      self.assertEqual(a_b1.name, "A/B_1")
+      self.assertEqual(list(a_b1.inputs), [a_a1.outputs[0]])
+
       # Import with existing de-duped node names
       a1_1, b1_1 = importer.import_graph_def(
           self._MakeGraphDef("""