Make graph's name scope thread local so that two threads opening the same scope don...
authorPriya Gupta <priyag@google.com>
Wed, 21 Mar 2018 06:07:37 +0000 (23:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 06:10:43 +0000 (23:10 -0700)
PiperOrigin-RevId: 189865854

tensorflow/python/framework/ops.py
tensorflow/python/framework/ops_test.py

index 4be2e2c..50a1d3f 100644 (file)
@@ -2727,8 +2727,6 @@ class Graph(object):
     self._next_id_counter = 0  # GUARDED_BY(self._lock)
     self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
     self._version = 0  # GUARDED_BY(self._lock)
-    # Current name stack: uniquified names
-    self._name_stack = ""
     # Maps a name used in the graph to the next id to use for that name.
     self._names_in_use = {}
     self._stack_state_is_thread_local = False
@@ -3907,6 +3905,17 @@ class Graph(object):
     finally:
       self._default_original_op = old_original_op
 
+  @property
+  def _name_stack(self):
+    # This may be called from a thread where name_stack doesn't yet exist.
+    if not hasattr(self._thread_local, "_name_stack"):
+      self._thread_local._name_stack = ""
+    return self._thread_local._name_stack
+
+  @_name_stack.setter
+  def _name_stack(self, name_stack):
+    self._thread_local._name_stack = name_stack
+
   # pylint: disable=g-doc-return-or-yield,line-too-long
   @tf_contextlib.contextmanager
   def name_scope(self, name):
index d96e070..aa51391 100644 (file)
@@ -1556,6 +1556,35 @@ class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
              input: "^ColocateWithMe_2" }
     """, gd)
 
+  def testNameStack(self):
+
+    class NameSettingThread(self.TestThread):
+
+      def run(self):
+        with g.name_scope("foo"):
+          op1 = g.create_op("FloatOutput", [], [dtypes.float32])
+          self.has_mutated_graph.set()
+          self.should_continue.wait()
+          self.should_continue.clear()
+          op2 = g.create_op("FloatOutput", [], [dtypes.float32])
+          self.result = (op1, op2)
+
+    g = ops.Graph()
+    threads = [NameSettingThread(g, i) for i in range(3)]
+    for t in threads:
+      t.start()
+      t.has_mutated_graph.wait()
+      t.has_mutated_graph.clear()
+
+    for t in threads:
+      t.should_continue.set()
+      t.join()
+
+    suffixes = ["", "_1", "_2"]
+    for t, s in zip(threads, suffixes):
+      self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name)
+      self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name)
+
 
 @test_util.with_c_api
 class ObjectWithName(object):