from __future__ import print_function
import gc
+import threading
import numpy
self.assertEqual(0, partition_info.single_slice_dim([2, 3]))
+class VariableScopeMultithreadedTest(test.TestCase):
+
+ def testTwoThreadsDisjointScopeEntry(self):
+
+ def thread_fn(i, graph):
+ with graph.as_default():
+ with variable_scope.variable_scope("foo"):
+ if i == 0:
+ v = variable_scope.get_variable("v", [])
+ self.assertEquals("foo/v:0", v.name)
+ else:
+ # Any thread after the first one should fail to create variable
+ # with the same name.
+ with self.assertRaises(ValueError):
+ variable_scope.get_variable("v", [])
+
+ graph = ops.get_default_graph()
+ threads = [
+ threading.Thread(target=thread_fn, args=(i, graph,)) for i in range(2)]
+
+ threads[0].start()
+ # Allow thread 0 to finish before starting thread 1.
+ threads[0].join()
+ threads[1].start()
+ threads[1].join()
+
+ def testTwoThreadsNestedScopeEntry(self):
+
+ def thread_fn(i, graph, run_event, pause_event):
+ with graph.as_default():
+ with variable_scope.variable_scope("foo"):
+ if i == 0:
+ v = variable_scope.get_variable("v", [])
+ self.assertEquals("foo/v:0", v.name)
+ else:
+ # Any thread after the first one should fail to create variable
+ # with the same name.
+ with self.assertRaises(ValueError):
+ variable_scope.get_variable("v", [])
+ pause_event.set()
+ run_event.wait()
+
+ graph = ops.get_default_graph()
+ run_events = [threading.Event() for _ in range(2)]
+ pause_events = [threading.Event() for _ in range(2)]
+ threads = [
+ threading.Thread(
+ target=thread_fn, args=(i, graph, run_events[i], pause_events[i]))
+ for i in range(2)
+ ]
+
+ # Start first thread.
+ threads[0].start()
+ pause_events[0].wait()
+ # Start next thread once the first thread has paused.
+ threads[1].start()
+ pause_events[1].wait()
+ # Resume both threads.
+ run_events[0].set()
+ run_events[1].set()
+ threads[0].join()
+ threads[1].join()
+
+ def testReenterMainScope(self):
+
+ def thread_fn(graph, main_thread_scope):
+ with graph.as_default():
+ # Variable created with main scope will have prefix "main".
+ with variable_scope.variable_scope(main_thread_scope):
+ with variable_scope.variable_scope("foo"):
+ v = variable_scope.get_variable("v", [])
+ self.assertEquals("main/foo/v:0", v.name)
+
+ # Variable created outside main scope will not have prefix "main".
+ with variable_scope.variable_scope("bar"):
+ v = variable_scope.get_variable("v", [])
+ self.assertEquals("bar/v:0", v.name)
+
+ graph = ops.get_default_graph()
+ with variable_scope.variable_scope("main") as main_thread_scope:
+ thread = threading.Thread(
+ target=thread_fn, args=(graph, main_thread_scope))
+ thread.start()
+ thread.join()
+
+
if __name__ == "__main__":
test.main()
import enum # pylint: disable=g-bad-import-order
import functools
import sys
+import threading
import traceback
import six
"""Create a variable store."""
self._vars = {} # A dictionary of the stored TensorFlow variables.
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
- self.variable_scopes_count = {} # Count re-used variable scopes.
self._store_eager_variables = False
- def open_variable_scope(self, scope_name):
- if scope_name in self.variable_scopes_count:
- self.variable_scopes_count[scope_name] += 1
- else:
- self.variable_scopes_count[scope_name] = 1
-
- def close_variable_subscopes(self, scope_name):
- for k in self.variable_scopes_count:
- if not scope_name or k.startswith(scope_name + "/"):
- self.variable_scopes_count[k] = 0
-
- def variable_scope_count(self, scope_name):
- return self.variable_scopes_count.get(scope_name, 0)
-
def get_variable(self, name, shape=None, dtype=dtypes.float32,
initializer=None, regularizer=None, reuse=None,
trainable=True, collections=None, caching_device=None,
_VARSTORE_KEY = ("__variable_store",)
-_VARSCOPE_KEY = ("__varscope",)
+_VARSCOPESTORE_KEY = ("__varscope",)
+
+
+class _VariableScopeStore(threading.local):
+ """A thread local store for the current variable scope and scope counts."""
+
+ def __init__(self):
+ super(_VariableScopeStore, self).__init__()
+ self.current_scope = VariableScope(False)
+ self.variable_scopes_count = {}
+
+ def open_variable_scope(self, scope_name):
+ if scope_name in self.variable_scopes_count:
+ self.variable_scopes_count[scope_name] += 1
+ else:
+ self.variable_scopes_count[scope_name] = 1
+
+ def close_variable_subscopes(self, scope_name):
+ for k in self.variable_scopes_count:
+ if not scope_name or k.startswith(scope_name + "/"):
+ self.variable_scopes_count[k] = 0
+
+ def variable_scope_count(self, scope_name):
+ return self.variable_scopes_count.get(scope_name, 0)
+
+
+def get_variable_scope_store():
+ """Returns the variable scope store for current thread."""
+ scope_store = ops.get_collection(_VARSCOPESTORE_KEY)
+
+ if not scope_store:
+ scope_store = _VariableScopeStore()
+ ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store)
+ else:
+ scope_store = scope_store[0]
+
+ return scope_store
@tf_export("get_variable_scope")
def get_variable_scope():
"""Returns the current variable scope."""
- scope = ops.get_collection(_VARSCOPE_KEY)
- if scope: # This collection has at most 1 element, the default scope at [0].
- return scope[0]
- scope = VariableScope(False)
- ops.add_to_collection(_VARSCOPE_KEY, scope)
- return scope
+ return get_variable_scope_store().current_scope
def _get_default_variable_store():
self._dtype = dtype
self._use_resource = use_resource
self._constraint = constraint
- get_variable_scope() # Ensure that a default exists, then get a pointer.
- # Get the reference to the collection as we want to modify it in place.
- self._default_varscope = ops.get_collection_ref(_VARSCOPE_KEY)
self._var_store = _get_default_variable_store()
+ self._var_scope_store = get_variable_scope_store()
if isinstance(self._name_or_scope, VariableScope):
self._new_name = self._name_or_scope.name
name_scope = self._name_or_scope._name_scope # pylint: disable=protected-access
a reuse scope, or if reuse is not `None` or `True`.
TypeError: when the types of some arguments are not appropriate.
"""
- self._old = self._default_varscope[0]
+ self._old = self._var_scope_store.current_scope
if isinstance(self._name_or_scope, VariableScope):
- self._var_store.open_variable_scope(self._new_name)
- self._old_subscopes = copy.copy(self._var_store.variable_scopes_count)
+ self._var_scope_store.open_variable_scope(self._new_name)
+ self._old_subscopes = copy.copy(
+ self._var_scope_store.variable_scopes_count)
variable_scope_object = self._cached_variable_scope_object
else:
# Handler for the case when we just prolong current variable scope.
variable_scope_object.set_dtype(self._dtype)
if self._use_resource is not None:
variable_scope_object.set_use_resource(self._use_resource)
- self._var_store.open_variable_scope(self._new_name)
- self._default_varscope[0] = variable_scope_object
+ self._var_scope_store.open_variable_scope(self._new_name)
+ self._var_scope_store.current_scope = variable_scope_object
return variable_scope_object
def __exit__(self, type_arg, value_arg, traceback_arg):
# If jumping out from a non-prolonged scope, restore counts.
if isinstance(self._name_or_scope, VariableScope):
- self._var_store.variable_scopes_count = self._old_subscopes
+ self._var_scope_store.variable_scopes_count = self._old_subscopes
else:
- self._var_store.close_variable_subscopes(self._new_name)
- self._default_varscope[0] = self._old
+ self._var_scope_store.close_variable_subscopes(self._new_name)
+ self._var_scope_store.current_scope = self._old
def _maybe_wrap_custom_getter(custom_getter, old_getter):
def _get_unique_variable_scope(prefix):
"""Get a name with the given prefix unique in the current variable scope."""
- var_store = _get_default_variable_store()
+ var_scope_store = get_variable_scope_store()
current_scope = get_variable_scope()
name = current_scope.name + "/" + prefix if current_scope.name else prefix
- if var_store.variable_scope_count(name) == 0:
+ if var_scope_store.variable_scope_count(name) == 0:
return prefix
idx = 1
- while var_store.variable_scope_count(name + ("_%d" % idx)) > 0:
+ while var_scope_store.variable_scope_count(name + ("_%d" % idx)) > 0:
idx += 1
return prefix + ("_%d" % idx)
graph, ensures that graph is the default graph, and pushes a name scope and a
variable scope.
- If `name_or_scope` is not None, it is used as is. If `scope` is None, then
- `default_name` is used. In that case, if the same name has been previously
- used in the same scope, it will be made unique by appending `_N` to it.
+ If `name_or_scope` is not None, it is used as is. If `name_or_scope` is None,
+ then `default_name` is used. In that case, if the same name has been
+ previously used in the same scope, it will be made unique by appending `_N`
+ to it.
Variable scope allows you to create new variables and to share already created
ones while providing checks to not create or share by accident. For details,
discouraged) to pass False to the reuse argument, yielding undocumented
behaviour slightly different from None. Starting at 1.1.0 passing None and
False as reuse has exactly the same effect.
+
+ A note about using variable scopes in multi-threaded environment: Variable
+ scopes are thread local, so one thread will not see another thread's current
+ scope. Also, when using `default_name`, unique scopes names are also generated
+ only on a per thread basis. If the same name was used within a different
+ thread, that doesn't prevent a new thread from creating the same scope.
+ However, the underlying variable store is shared across threads (within the
+ same graph). As such, if another thread tries to create a new variable with
+ the same name as a variable created by a previous thread, it will fail unless
+ reuse is True.
+
+ Further, each thread starts with an empty variable scope. So if you wish to
+ preserve name prefixes from a scope from the main thread, you should capture
+ the main thread's scope and re-enter it in each thread. For e.g.
+
+ ```
+ main_thread_scope = variable_scope.get_variable_scope()
+
+ # Thread's target function:
+ def thread_target_fn(captured_scope):
+ with variable_scope.variable_scope(captured_scope):
+ # .... regular code for this thread
+
+
+ thread = threading.Thread(target=thread_target_fn, args=(main_thread_scope,))
+ ```
"""
def __init__(self,