From 2acc5b6c465832fc8c1fba2454d3dfd8f3aa2eb5 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Thu, 15 Mar 2018 10:26:13 -0700 Subject: [PATCH] TFE: Modify initialization of `ContextStack` to ensure eager context is kept. When eager execution is enabled in the main thread, the fact that it was enabled is propagated to subsequently created threads. This change ... (1) ensures that the fact that eager was enabled is also propagated to the `ContextStack`, which is renamed to `_ContextSwitchStack`, for clarity; (2) adds a `_ContextSwitchStack` object to `Context` as a member, removing the global `context_stack`. PiperOrigin-RevId: 189206207 --- tensorflow/python/eager/context.py | 40 +++++++++++++++++++++++++----------- tensorflow/python/eager/core_test.py | 15 +++++++------- tensorflow/python/eager/ops_test.py | 17 +++++++++++++++ tensorflow/python/framework/ops.py | 26 +++++++++-------------- 4 files changed, 62 insertions(+), 36 deletions(-) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 7953d10..6c9a147 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -94,22 +94,32 @@ class _EagerContext(threading.local): self.execution_mode = None -ContextStackEntry = collections.namedtuple( - "ContextStackEntry", ["is_building_function", "enter_context_fn"]) +ContextSwitch = collections.namedtuple( + "ContextSwitch", ["is_building_function", "enter_context_fn"]) -class ContextStack(threading.local): +# `_ContextSwitchStack` is a `threading.local` to match the semantics of +# ``DefaultGraphStack`, which is also a `threading.local`. +class _ContextSwitchStack(threading.local): """A thread-local stack of context switches.""" - def __init__(self): - super(ContextStack, self).__init__() + def __init__(self, eager): + super(_ContextSwitchStack, self).__init__() self.stack = [] + if eager: + # Initialize the stack with a pointer to enter the eager context; this + # ensures that the fact that eager execution was enabled is propagated + # across threads, since (1) `enable_eager_execution` modifies a + # process-level flag (`_default_mode`) and (2) `__init__` is called each + # time a threading.local object is used in a separate thread. + self.push(is_building_function=False, enter_context_fn=eager_mode) def push(self, is_building_function, enter_context_fn): """Push metadata about a context switch onto the stack. A context switch can take one of two forms: installing a graph as the - default graph, or entering the eager context. + default graph, or entering the eager context. For each context switch, + we record whether or not the entered context is building a function. Args: is_building_function: (bool.) Whether the context is building a function. @@ -118,7 +128,7 @@ class ContextStack(threading.local): """ self.stack.append( - ContextStackEntry(is_building_function, enter_context_fn)) + ContextSwitch(is_building_function, enter_context_fn)) def pop(self): """Pop the stack.""" @@ -126,9 +136,6 @@ class ContextStack(threading.local): self.stack.pop() -context_stack = ContextStack() - - # TODO(agarwal): rename to EagerContext / EagerRuntime ? # TODO(agarwal): consider keeping the corresponding Graph here. class Context(object): @@ -171,6 +178,7 @@ class Context(object): ValueError: If execution_mode is not valid. """ self._eager_context = _EagerContext() + self._context_switches = _ContextSwitchStack(self.executing_eagerly()) self._context_handle = None self._context_devices = None self._post_execution_callbacks = [] @@ -283,13 +291,16 @@ class Context(object): old_mode = ctx.mode ctx.mode = mode if mode == EAGER_MODE: - context_stack.push(False, eager_mode) + # Entering graph mode does not provide us with sufficient information to + # record a context switch; graph-based context switches are only logged + # when a graph is registered as the default graph. + self.context_switches.push(False, eager_mode) try: yield finally: ctx.mode = old_mode if mode == EAGER_MODE: - context_stack.pop() + self.context_switches.pop() def executing_eagerly(self): """Returns True if current thread has eager executing enabled.""" @@ -545,6 +556,11 @@ class Context(object): run_metadata.ParseFromString(compat.as_bytes(proto_data)) return run_metadata + @property + def context_switches(self): + """Returns a stack of context switches.""" + return self._context_switches + _context = None _context_lock = threading.Lock() diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 6dfd8d1..6ebf5b2 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -123,19 +123,18 @@ class TFETest(test_util.TensorFlowTestCase): # available, when no device is explicitly provided) self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0') - def testContextStackContainsEagerMode(self): - # Eager execution has been enabled, and no other context - # switch has occurred, so `context_stack` should contain - # exactly one entry. - self.assertEqual(len(context.context_stack.stack), 1) - stack_entry = context.context_stack.stack[0] + def testContextSwitchStackContainsEagerMode(self): + # Eager execution has been enabled, and no other context switch has + # occurred, so `context_switches` should contain exactly one entry. + self.assertEqual(len(context.context().context_switches.stack), 1) + switch = context.context().context_switches.stack[0] # The entry should log that eager mode was entered. - self.assertIs(stack_entry.enter_context_fn, context.eager_mode) + self.assertIs(switch.enter_context_fn, context.eager_mode) # It is not possible to build a graph function when eager execution # is enabled; the stack entry should reflect this fact. - self.assertFalse(stack_entry.is_building_function) + self.assertFalse(switch.is_building_function) def testInt32GPU(self): if not context.context().num_gpus(): diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index f70c754..fc76ede 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading import numpy as np from tensorflow.core.protobuf import config_pb2 @@ -376,6 +377,22 @@ class OpsTest(test_util.TensorFlowTestCase): def testNoOpIsNone(self): self.assertTrue(control_flow_ops.no_op() is None) + def testEagerContextPreservedAcrossThreads(self): + def init_fn(): + self.assertTrue(context.executing_eagerly()) + with ops.init_scope(): + self.assertTrue(context.executing_eagerly()) + context_switches = context.context().context_switches + self.assertEqual(len(context_switches.stack), 1) + self.assertFalse(context_switches.stack[0].is_building_function) + self.assertEqual(context_switches.stack[0].enter_context_fn, + context.eager_mode) + + self.assertTrue(context.executing_eagerly()) + t1 = threading.Thread(target=init_fn) + t1.start() + t1.join() + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index b2f4377..01a0e03 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -5095,11 +5095,12 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access @tf_contextlib.contextmanager def get_controller(self, default): try: - context.context_stack.push(default.building_function, default.as_default) + context.context().context_switches.push(default.building_function, + default.as_default) with super(_DefaultGraphStack, self).get_controller(default) as g: yield g finally: - context.context_stack.pop() + context.context().context_switches.pop() _default_graph_stack = _DefaultGraphStack() @@ -5125,13 +5126,13 @@ def init_scope(): graph function. Here, a context is defined as either a graph or an eager context. Every context switch, i.e., every installation of a graph as the default graph and every switch into eager mode, is logged in a - thread-local stack called the `context_stack`; the log entry for a + thread-local stack called `context_switches`; the log entry for a context switch is popped from the stack when the context is exited. - Entering an `init_scope` is equivalent to crawling up the - `context_stack`, finding the first context that is not building a graph - function, and entering it. A caveat is that if graph mode is enabled - but the default graph stack is empty, then entering an `init_scope` - will simply install a fresh graph as the default one. + Entering an `init_scope` is equivalent to crawling up + `context_switches`, finding the first context that is not building a + graph function, and entering it. A caveat is that if graph mode is + enabled but the default graph stack is empty, then entering an + `init_scope` will simply install a fresh graph as the default one. (3) The gradient tape is paused while the scope is active. """ @@ -5161,7 +5162,7 @@ def init_scope(): outer_context = default_graph.as_default else: # Find a context that is not building a function. - for stack_entry in reversed(context.context_stack.stack): + for stack_entry in reversed(context.context().context_switches.stack): if not stack_entry.is_building_function: outer_context = stack_entry.enter_context_fn break @@ -5278,13 +5279,6 @@ def enable_eager_execution(config=None, device_policy=None, config=config, device_policy=device_policy, execution_mode=execution_mode) - if context.context_stack.stack: - raise AssertionError("Invariant violated: The context stack must " - "be empty when eager execution is enabled.") - # Log that eager execution has been enabled by pushing an entry onto the - # context stack; this entry won't ever be popped, as it's impossible to - # disable eager execution - context.context_stack.push(False, context.eager_mode) elif ((config is not None and config is not context._context._config) or (device_policy is not None and device_policy is not context._context._device_policy) or -- 2.7.4