Make variable creator scope thread local (always).
authorPriya Gupta <priyag@google.com>
Mon, 5 Mar 2018 20:28:07 +0000 (12:28 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Mar 2018 20:32:32 +0000 (12:32 -0800)
PiperOrigin-RevId: 187904394

tensorflow/python/framework/ops.py
tensorflow/python/ops/variable_scope.py

index 0a85b15..47d0bec 100644 (file)
@@ -2780,7 +2780,6 @@ class Graph(object):
       c_api.SetRequireShapeInferenceFns(self._c_graph, False)
     else:
       self._scoped_c_graph = None
-    self._variable_creator_stack = []
 
   # TODO(apassos) remove once the C API is used by default.
   def _use_c_api_hack(self):
@@ -2821,17 +2820,26 @@ class Graph(object):
   # frozen, and this functionality is still not ready for public visibility.
   @tf_contextlib.contextmanager
   def _variable_creator_scope(self, creator):
+    # This step makes a copy of the existing stack, and it also initializes
+    # self._thread_local._variable_creator_stack if it doesn't exist yet.
     old = list(self._variable_creator_stack)
-    self._variable_creator_stack.append(creator)
+    self._thread_local._variable_creator_stack.append(creator)
     try:
       yield
     finally:
-      self._variable_creator_stack = old
+      self._thread_local._variable_creator_stack = old
 
   # Note: this method is private because the API of tf.Graph() is public and
   # frozen, and this functionality is still not ready for public visibility.
-  def _get_variable_creator_stack(self):
-    return list(self._variable_creator_stack)
+  @property
+  def _variable_creator_stack(self):
+    if not hasattr(self._thread_local, "_variable_creator_stack"):
+      self._thread_local._variable_creator_stack = []
+    return list(self._thread_local._variable_creator_stack)
+
+  @_variable_creator_stack.setter
+  def _variable_creator_stack(self, variable_creator_stack):
+    self._thread_local._variable_creator_stack = variable_creator_stack
 
   def _extract_stack(self):
     """A lightweight, extensible re-implementation of traceback.extract_stack.
index 81565a6..de4e44f 100644 (file)
@@ -2145,7 +2145,7 @@ def variable(initial_value=None,
              constraint=None,
              use_resource=None):
   previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
-  for getter in ops.get_default_graph()._get_variable_creator_stack():  # pylint: disable=protected-access
+  for getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
     previous_getter = _make_getter(getter, previous_getter)
   return previous_getter(initial_value=initial_value,
                          trainable=trainable,