self.execution_mode = None
-ContextStackEntry = collections.namedtuple(
- "ContextStackEntry", ["is_building_function", "enter_context_fn"])
+ContextSwitch = collections.namedtuple(
+ "ContextSwitch", ["is_building_function", "enter_context_fn"])
-class ContextStack(threading.local):
+# `_ContextSwitchStack` is a `threading.local` to match the semantics of
+# ``DefaultGraphStack`, which is also a `threading.local`.
+class _ContextSwitchStack(threading.local):
"""A thread-local stack of context switches."""
- def __init__(self):
- super(ContextStack, self).__init__()
+ def __init__(self, eager):
+ super(_ContextSwitchStack, self).__init__()
self.stack = []
+ if eager:
+ # Initialize the stack with a pointer to enter the eager context; this
+ # ensures that the fact that eager execution was enabled is propagated
+ # across threads, since (1) `enable_eager_execution` modifies a
+ # process-level flag (`_default_mode`) and (2) `__init__` is called each
+ # time a threading.local object is used in a separate thread.
+ self.push(is_building_function=False, enter_context_fn=eager_mode)
def push(self, is_building_function, enter_context_fn):
"""Push metadata about a context switch onto the stack.
A context switch can take one of two forms: installing a graph as the
- default graph, or entering the eager context.
+ default graph, or entering the eager context. For each context switch,
+ we record whether or not the entered context is building a function.
Args:
is_building_function: (bool.) Whether the context is building a function.
"""
self.stack.append(
- ContextStackEntry(is_building_function, enter_context_fn))
+ ContextSwitch(is_building_function, enter_context_fn))
def pop(self):
"""Pop the stack."""
self.stack.pop()
-context_stack = ContextStack()
-
-
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
ValueError: If execution_mode is not valid.
"""
self._eager_context = _EagerContext()
+ self._context_switches = _ContextSwitchStack(self.executing_eagerly())
self._context_handle = None
self._context_devices = None
self._post_execution_callbacks = []
old_mode = ctx.mode
ctx.mode = mode
if mode == EAGER_MODE:
- context_stack.push(False, eager_mode)
+ # Entering graph mode does not provide us with sufficient information to
+ # record a context switch; graph-based context switches are only logged
+ # when a graph is registered as the default graph.
+ self.context_switches.push(False, eager_mode)
try:
yield
finally:
ctx.mode = old_mode
if mode == EAGER_MODE:
- context_stack.pop()
+ self.context_switches.pop()
def executing_eagerly(self):
"""Returns True if current thread has eager executing enabled."""
run_metadata.ParseFromString(compat.as_bytes(proto_data))
return run_metadata
+ @property
+ def context_switches(self):
+ """Returns a stack of context switches."""
+ return self._context_switches
+
_context = None
_context_lock = threading.Lock()
# available, when no device is explicitly provided)
self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0')
- def testContextStackContainsEagerMode(self):
- # Eager execution has been enabled, and no other context
- # switch has occurred, so `context_stack` should contain
- # exactly one entry.
- self.assertEqual(len(context.context_stack.stack), 1)
- stack_entry = context.context_stack.stack[0]
+ def testContextSwitchStackContainsEagerMode(self):
+ # Eager execution has been enabled, and no other context switch has
+ # occurred, so `context_switches` should contain exactly one entry.
+ self.assertEqual(len(context.context().context_switches.stack), 1)
+ switch = context.context().context_switches.stack[0]
# The entry should log that eager mode was entered.
- self.assertIs(stack_entry.enter_context_fn, context.eager_mode)
+ self.assertIs(switch.enter_context_fn, context.eager_mode)
# It is not possible to build a graph function when eager execution
# is enabled; the stack entry should reflect this fact.
- self.assertFalse(stack_entry.is_building_function)
+ self.assertFalse(switch.is_building_function)
def testInt32GPU(self):
if not context.context().num_gpus():
from __future__ import division
from __future__ import print_function
+import threading
import numpy as np
from tensorflow.core.protobuf import config_pb2
def testNoOpIsNone(self):
self.assertTrue(control_flow_ops.no_op() is None)
+ def testEagerContextPreservedAcrossThreads(self):
+ def init_fn():
+ self.assertTrue(context.executing_eagerly())
+ with ops.init_scope():
+ self.assertTrue(context.executing_eagerly())
+ context_switches = context.context().context_switches
+ self.assertEqual(len(context_switches.stack), 1)
+ self.assertFalse(context_switches.stack[0].is_building_function)
+ self.assertEqual(context_switches.stack[0].enter_context_fn,
+ context.eager_mode)
+
+ self.assertTrue(context.executing_eagerly())
+ t1 = threading.Thread(target=init_fn)
+ t1.start()
+ t1.join()
+
if __name__ == '__main__':
test.main()
@tf_contextlib.contextmanager
def get_controller(self, default):
try:
- context.context_stack.push(default.building_function, default.as_default)
+ context.context().context_switches.push(default.building_function,
+ default.as_default)
with super(_DefaultGraphStack, self).get_controller(default) as g:
yield g
finally:
- context.context_stack.pop()
+ context.context().context_switches.pop()
_default_graph_stack = _DefaultGraphStack()
graph function. Here, a context is defined as either a graph or an eager
context. Every context switch, i.e., every installation of a graph as
the default graph and every switch into eager mode, is logged in a
- thread-local stack called the `context_stack`; the log entry for a
+ thread-local stack called `context_switches`; the log entry for a
context switch is popped from the stack when the context is exited.
- Entering an `init_scope` is equivalent to crawling up the
- `context_stack`, finding the first context that is not building a graph
- function, and entering it. A caveat is that if graph mode is enabled
- but the default graph stack is empty, then entering an `init_scope`
- will simply install a fresh graph as the default one.
+ Entering an `init_scope` is equivalent to crawling up
+ `context_switches`, finding the first context that is not building a
+ graph function, and entering it. A caveat is that if graph mode is
+ enabled but the default graph stack is empty, then entering an
+ `init_scope` will simply install a fresh graph as the default one.
(3) The gradient tape is paused while the scope is active.
"""
outer_context = default_graph.as_default
else:
# Find a context that is not building a function.
- for stack_entry in reversed(context.context_stack.stack):
+ for stack_entry in reversed(context.context().context_switches.stack):
if not stack_entry.is_building_function:
outer_context = stack_entry.enter_context_fn
break
config=config,
device_policy=device_policy,
execution_mode=execution_mode)
- if context.context_stack.stack:
- raise AssertionError("Invariant violated: The context stack must "
- "be empty when eager execution is enabled.")
- # Log that eager execution has been enabled by pushing an entry onto the
- # context stack; this entry won't ever be popped, as it's impossible to
- # disable eager execution
- context.context_stack.push(False, context.eager_mode)
elif ((config is not None and config is not context._context._config) or
(device_policy is not None and
device_policy is not context._context._device_policy) or