Make variable scope and scope counts local to current thread so that they work correc...
authorPriya Gupta <priyag@google.com>
Wed, 21 Mar 2018 04:39:16 +0000 (21:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Mar 2018 04:42:03 +0000 (21:42 -0700)
PiperOrigin-RevId: 189860229

tensorflow/contrib/eager/python/network.py
tensorflow/python/kernel_tests/variable_scope_test.py
tensorflow/python/ops/template.py
tensorflow/python/ops/variable_scope.py

index 4c93771..e55a927 100644 (file)
@@ -149,7 +149,7 @@ class Network(base.Layer):
     # check we might have name collisions if the parent scope on init gets
     # closed before build is called.
     self._variable_scope_counts_on_init = (
-        variable_scope._get_default_variable_store().variable_scopes_count)
+        variable_scope.get_variable_scope_store().variable_scopes_count)
 
   def _name_scope_name(self, current_variable_scope):
     """Overrides Layer op naming to match variable naming."""
index 531d0cd..86ab9fb 100644 (file)
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import gc
+import threading
 
 import numpy
 
@@ -1349,5 +1350,91 @@ class PartitionInfoTest(test.TestCase):
     self.assertEqual(0, partition_info.single_slice_dim([2, 3]))
 
 
+class VariableScopeMultithreadedTest(test.TestCase):
+
+  def testTwoThreadsDisjointScopeEntry(self):
+
+    def thread_fn(i, graph):
+      with graph.as_default():
+        with variable_scope.variable_scope("foo"):
+          if i == 0:
+            v = variable_scope.get_variable("v", [])
+            self.assertEquals("foo/v:0", v.name)
+          else:
+            # Any thread after the first one should fail to create variable
+            # with the same name.
+            with self.assertRaises(ValueError):
+              variable_scope.get_variable("v", [])
+
+    graph = ops.get_default_graph()
+    threads = [
+        threading.Thread(target=thread_fn, args=(i, graph,)) for i in range(2)]
+
+    threads[0].start()
+    # Allow thread 0 to finish before starting thread 1.
+    threads[0].join()
+    threads[1].start()
+    threads[1].join()
+
+  def testTwoThreadsNestedScopeEntry(self):
+
+    def thread_fn(i, graph, run_event, pause_event):
+      with graph.as_default():
+        with variable_scope.variable_scope("foo"):
+          if i == 0:
+            v = variable_scope.get_variable("v", [])
+            self.assertEquals("foo/v:0", v.name)
+          else:
+            # Any thread after the first one should fail to create variable
+            # with the same name.
+            with self.assertRaises(ValueError):
+              variable_scope.get_variable("v", [])
+          pause_event.set()
+          run_event.wait()
+
+    graph = ops.get_default_graph()
+    run_events = [threading.Event() for _ in range(2)]
+    pause_events = [threading.Event() for _ in range(2)]
+    threads = [
+        threading.Thread(
+            target=thread_fn, args=(i, graph, run_events[i], pause_events[i]))
+        for i in range(2)
+    ]
+
+    # Start first thread.
+    threads[0].start()
+    pause_events[0].wait()
+    # Start next thread once the first thread has paused.
+    threads[1].start()
+    pause_events[1].wait()
+    # Resume both threads.
+    run_events[0].set()
+    run_events[1].set()
+    threads[0].join()
+    threads[1].join()
+
+  def testReenterMainScope(self):
+
+    def thread_fn(graph, main_thread_scope):
+      with graph.as_default():
+        # Variable created with main scope will have prefix "main".
+        with variable_scope.variable_scope(main_thread_scope):
+          with variable_scope.variable_scope("foo"):
+            v = variable_scope.get_variable("v", [])
+            self.assertEquals("main/foo/v:0", v.name)
+
+        # Variable created outside main scope will not have prefix "main".
+        with variable_scope.variable_scope("bar"):
+          v = variable_scope.get_variable("v", [])
+          self.assertEquals("bar/v:0", v.name)
+
+    graph = ops.get_default_graph()
+    with variable_scope.variable_scope("main") as main_thread_scope:
+      thread = threading.Thread(
+          target=thread_fn, args=(graph, main_thread_scope))
+      thread.start()
+      thread.join()
+
+
 if __name__ == "__main__":
   test.main()
index 0a391d8..0294ece 100644 (file)
@@ -583,7 +583,7 @@ class _EagerTemplateVariableStore(object):
       if self._variable_scope_name is None:
         raise RuntimeError("A variable scope must be set before an "
                            "_EagerTemplateVariableStore object exits.")
-      self._eager_variable_store._store.close_variable_subscopes(  # pylint: disable=protected-access
+      variable_scope.get_variable_scope_store().close_variable_subscopes(
           self._variable_scope_name)
 
   def _variables_in_scope(self, variable_list):
index c1af8ff..c35735c 100644 (file)
@@ -24,6 +24,7 @@ import copy
 import enum  # pylint: disable=g-bad-import-order
 import functools
 import sys
+import threading
 import traceback
 
 import six
@@ -211,23 +212,8 @@ class _VariableStore(object):
     """Create a variable store."""
     self._vars = {}  # A dictionary of the stored TensorFlow variables.
     self._partitioned_vars = {}  # A dict of the stored PartitionedVariables.
-    self.variable_scopes_count = {}  # Count re-used variable scopes.
     self._store_eager_variables = False
 
-  def open_variable_scope(self, scope_name):
-    if scope_name in self.variable_scopes_count:
-      self.variable_scopes_count[scope_name] += 1
-    else:
-      self.variable_scopes_count[scope_name] = 1
-
-  def close_variable_subscopes(self, scope_name):
-    for k in self.variable_scopes_count:
-      if not scope_name or k.startswith(scope_name + "/"):
-        self.variable_scopes_count[k] = 0
-
-  def variable_scope_count(self, scope_name):
-    return self.variable_scopes_count.get(scope_name, 0)
-
   def get_variable(self, name, shape=None, dtype=dtypes.float32,
                    initializer=None, regularizer=None, reuse=None,
                    trainable=True, collections=None, caching_device=None,
@@ -1160,18 +1146,49 @@ class VariableScope(object):
 
 
 _VARSTORE_KEY = ("__variable_store",)
-_VARSCOPE_KEY = ("__varscope",)
+_VARSCOPESTORE_KEY = ("__varscope",)
+
+
+class _VariableScopeStore(threading.local):
+  """A thread local store for the current variable scope and scope counts."""
+
+  def __init__(self):
+    super(_VariableScopeStore, self).__init__()
+    self.current_scope = VariableScope(False)
+    self.variable_scopes_count = {}
+
+  def open_variable_scope(self, scope_name):
+    if scope_name in self.variable_scopes_count:
+      self.variable_scopes_count[scope_name] += 1
+    else:
+      self.variable_scopes_count[scope_name] = 1
+
+  def close_variable_subscopes(self, scope_name):
+    for k in self.variable_scopes_count:
+      if not scope_name or k.startswith(scope_name + "/"):
+        self.variable_scopes_count[k] = 0
+
+  def variable_scope_count(self, scope_name):
+    return self.variable_scopes_count.get(scope_name, 0)
+
+
+def get_variable_scope_store():
+  """Returns the variable scope store for current thread."""
+  scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
+
+  if not scope_store:
+    scope_store = _VariableScopeStore()
+    ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
+  else:
+    scope_store = scope_store[0]
+
+  return scope_store
 
 
 @tf_export("get_variable_scope")
 def get_variable_scope():
   """Returns the current variable scope."""
-  scope = ops.get_collection(_VARSCOPE_KEY)
-  if scope:  # This collection has at most 1 element, the default scope at [0].
-    return scope[0]
-  scope = VariableScope(False)
-  ops.add_to_collection(_VARSCOPE_KEY, scope)
-  return scope
+  return get_variable_scope_store().current_scope
 
 
 def _get_default_variable_store():
@@ -1575,10 +1592,8 @@ class _pure_variable_scope(object):  # pylint: disable=invalid-name
     self._dtype = dtype
     self._use_resource = use_resource
     self._constraint = constraint
-    get_variable_scope()  # Ensure that a default exists, then get a pointer.
-    # Get the reference to the collection as we want to modify it in place.
-    self._default_varscope = ops.get_collection_ref(_VARSCOPE_KEY)
     self._var_store = _get_default_variable_store()
+    self._var_scope_store = get_variable_scope_store()
     if isinstance(self._name_or_scope, VariableScope):
       self._new_name = self._name_or_scope.name
       name_scope = self._name_or_scope._name_scope  # pylint: disable=protected-access
@@ -1626,10 +1641,11 @@ class _pure_variable_scope(object):  # pylint: disable=invalid-name
         a reuse scope, or if reuse is not `None` or `True`.
       TypeError: when the types of some arguments are not appropriate.
     """
-    self._old = self._default_varscope[0]
+    self._old = self._var_scope_store.current_scope
     if isinstance(self._name_or_scope, VariableScope):
-      self._var_store.open_variable_scope(self._new_name)
-      self._old_subscopes = copy.copy(self._var_store.variable_scopes_count)
+      self._var_scope_store.open_variable_scope(self._new_name)
+      self._old_subscopes = copy.copy(
+          self._var_scope_store.variable_scopes_count)
       variable_scope_object = self._cached_variable_scope_object
     else:
       # Handler for the case when we just prolong current variable scope.
@@ -1672,17 +1688,17 @@ class _pure_variable_scope(object):  # pylint: disable=invalid-name
         variable_scope_object.set_dtype(self._dtype)
       if self._use_resource is not None:
         variable_scope_object.set_use_resource(self._use_resource)
-      self._var_store.open_variable_scope(self._new_name)
-    self._default_varscope[0] = variable_scope_object
+      self._var_scope_store.open_variable_scope(self._new_name)
+    self._var_scope_store.current_scope = variable_scope_object
     return variable_scope_object
 
   def __exit__(self, type_arg, value_arg, traceback_arg):
     # If jumping out from a non-prolonged scope, restore counts.
     if isinstance(self._name_or_scope, VariableScope):
-      self._var_store.variable_scopes_count = self._old_subscopes
+      self._var_scope_store.variable_scopes_count = self._old_subscopes
     else:
-      self._var_store.close_variable_subscopes(self._new_name)
-    self._default_varscope[0] = self._old
+      self._var_scope_store.close_variable_subscopes(self._new_name)
+    self._var_scope_store.current_scope = self._old
 
 
 def _maybe_wrap_custom_getter(custom_getter, old_getter):
@@ -1707,13 +1723,13 @@ def _maybe_wrap_custom_getter(custom_getter, old_getter):
 
 def _get_unique_variable_scope(prefix):
   """Get a name with the given prefix unique in the current variable scope."""
-  var_store = _get_default_variable_store()
+  var_scope_store = get_variable_scope_store()
   current_scope = get_variable_scope()
   name = current_scope.name + "/" + prefix if current_scope.name else prefix
-  if var_store.variable_scope_count(name) == 0:
+  if var_scope_store.variable_scope_count(name) == 0:
     return prefix
   idx = 1
-  while var_store.variable_scope_count(name + ("_%d" % idx)) > 0:
+  while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
     idx += 1
   return prefix + ("_%d" % idx)
 
@@ -1729,9 +1745,10 @@ class variable_scope(object):
   graph, ensures that graph is the default graph, and pushes a name scope and a
   variable scope.
 
-  If `name_or_scope` is not None, it is used as is. If `scope` is None, then
-  `default_name` is used.  In that case, if the same name has been previously
-  used in the same scope, it will be made unique by appending `_N` to it.
+  If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
+  then `default_name` is used.  In that case, if the same name has been
+  previously used in the same scope, it will be made unique by appending `_N`
+  to it.
 
   Variable scope allows you to create new variables and to share already created
   ones while providing checks to not create or share by accident. For details,
@@ -1810,6 +1827,32 @@ class variable_scope(object):
   discouraged) to pass False to the reuse argument, yielding undocumented
   behaviour slightly different from None. Starting at 1.1.0 passing None and
   False as reuse has exactly the same effect.
+
+  A note about using variable scopes in multi-threaded environment: Variable
+  scopes are thread local, so one thread will not see another thread's current
+  scope. Also, when using `default_name`, unique scopes names are also generated
+  only on a per thread basis. If the same name was used within a different
+  thread, that doesn't prevent a new thread from creating the same scope.
+  However, the underlying variable store is shared across threads (within the
+  same graph). As such, if another thread tries to create a new variable with
+  the same name as a variable created by a previous thread, it will fail unless
+  reuse is True.
+
+  Further, each thread starts with an empty variable scope. So if you wish to
+  preserve name prefixes from a scope from the main thread, you should capture
+  the main thread's scope and re-enter it in each thread. For e.g.
+
+  ```
+  main_thread_scope = variable_scope.get_variable_scope()
+
+  # Thread's target function:
+  def thread_target_fn(captured_scope):
+    with variable_scope.variable_scope(captured_scope):
+      # .... regular code for this thread
+
+
+  thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
+  ```
   """
 
   def __init__(self,