From 119795f5341737341b526814c6360b5679cd81d3 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Mon, 5 Mar 2018 12:28:07 -0800 Subject: [PATCH] Make variable creator scope thread local (always). PiperOrigin-RevId: 187904394 --- tensorflow/python/framework/ops.py | 18 +++++++++++++----- tensorflow/python/ops/variable_scope.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 0a85b15..47d0bec 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -2780,7 +2780,6 @@ class Graph(object): c_api.SetRequireShapeInferenceFns(self._c_graph, False) else: self._scoped_c_graph = None - self._variable_creator_stack = [] # TODO(apassos) remove once the C API is used by default. def _use_c_api_hack(self): @@ -2821,17 +2820,26 @@ class Graph(object): # frozen, and this functionality is still not ready for public visibility. @tf_contextlib.contextmanager def _variable_creator_scope(self, creator): + # This step makes a copy of the existing stack, and it also initializes + # self._thread_local._variable_creator_stack if it doesn't exist yet. old = list(self._variable_creator_stack) - self._variable_creator_stack.append(creator) + self._thread_local._variable_creator_stack.append(creator) try: yield finally: - self._variable_creator_stack = old + self._thread_local._variable_creator_stack = old # Note: this method is private because the API of tf.Graph() is public and # frozen, and this functionality is still not ready for public visibility. - def _get_variable_creator_stack(self): - return list(self._variable_creator_stack) + @property + def _variable_creator_stack(self): + if not hasattr(self._thread_local, "_variable_creator_stack"): + self._thread_local._variable_creator_stack = [] + return list(self._thread_local._variable_creator_stack) + + @_variable_creator_stack.setter + def _variable_creator_stack(self, variable_creator_stack): + self._thread_local._variable_creator_stack = variable_creator_stack def _extract_stack(self): """A lightweight, extensible re-implementation of traceback.extract_stack. diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 81565a6..de4e44f 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -2145,7 +2145,7 @@ def variable(initial_value=None, constraint=None, use_resource=None): previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) - for getter in ops.get_default_graph()._get_variable_creator_stack(): # pylint: disable=protected-access + for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access previous_getter = _make_getter(getter, previous_getter) return previous_getter(initial_value=initial_value, trainable=trainable, -- 2.7.4