TFE: Modify initialization of `ContextStack` to ensure eager context is kept.
authorAkshay Agrawal <akshayka@google.com>
Thu, 15 Mar 2018 17:26:13 +0000 (10:26 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Mar 2018 17:30:05 +0000 (10:30 -0700)
When eager execution is enabled in the main thread, the fact that it was
enabled is propagated to subsequently created threads. This
change ...

(1) ensures that the fact that eager was enabled is also propagated to the `ContextStack`, which is renamed to `_ContextSwitchStack`, for clarity;
(2) adds a `_ContextSwitchStack` object to `Context` as a member, removing the global `context_stack`.

PiperOrigin-RevId: 189206207

tensorflow/python/eager/context.py
tensorflow/python/eager/core_test.py
tensorflow/python/eager/ops_test.py
tensorflow/python/framework/ops.py

index 7953d10..6c9a147 100644 (file)
@@ -94,22 +94,32 @@ class _EagerContext(threading.local):
     self.execution_mode = None
 
 
-ContextStackEntry = collections.namedtuple(
-    "ContextStackEntry", ["is_building_function", "enter_context_fn"])
+ContextSwitch = collections.namedtuple(
+    "ContextSwitch", ["is_building_function", "enter_context_fn"])
 
 
-class ContextStack(threading.local):
+# `_ContextSwitchStack` is a `threading.local` to match the semantics of
+# ``DefaultGraphStack`, which is also a `threading.local`.
+class _ContextSwitchStack(threading.local):
   """A thread-local stack of context switches."""
 
-  def __init__(self):
-    super(ContextStack, self).__init__()
+  def __init__(self, eager):
+    super(_ContextSwitchStack, self).__init__()
     self.stack = []
+    if eager:
+      # Initialize the stack with a pointer to enter the eager context; this
+      # ensures that the fact that eager execution was enabled is propagated
+      # across threads, since (1) `enable_eager_execution` modifies a
+      # process-level flag (`_default_mode`) and (2) `__init__` is called each
+      # time a threading.local object is used in a separate thread.
+      self.push(is_building_function=False, enter_context_fn=eager_mode)
 
   def push(self, is_building_function, enter_context_fn):
     """Push metadata about a context switch onto the stack.
 
     A context switch can take one of two forms: installing a graph as the
-    default graph, or entering the eager context.
+    default graph, or entering the eager context. For each context switch,
+    we record whether or not the entered context is building a function.
 
     Args:
       is_building_function: (bool.) Whether the context is building a function.
@@ -118,7 +128,7 @@ class ContextStack(threading.local):
     """
 
     self.stack.append(
-        ContextStackEntry(is_building_function, enter_context_fn))
+        ContextSwitch(is_building_function, enter_context_fn))
 
   def pop(self):
     """Pop the stack."""
@@ -126,9 +136,6 @@ class ContextStack(threading.local):
     self.stack.pop()
 
 
-context_stack = ContextStack()
-
-
 # TODO(agarwal): rename to EagerContext / EagerRuntime ?
 # TODO(agarwal): consider keeping the corresponding Graph here.
 class Context(object):
@@ -171,6 +178,7 @@ class Context(object):
      ValueError: If execution_mode is not valid.
     """
     self._eager_context = _EagerContext()
+    self._context_switches = _ContextSwitchStack(self.executing_eagerly())
     self._context_handle = None
     self._context_devices = None
     self._post_execution_callbacks = []
@@ -283,13 +291,16 @@ class Context(object):
     old_mode = ctx.mode
     ctx.mode = mode
     if mode == EAGER_MODE:
-      context_stack.push(False, eager_mode)
+      # Entering graph mode does not provide us with sufficient information to
+      # record a context switch; graph-based context switches are only logged
+      # when a graph is registered as the default graph.
+      self.context_switches.push(False, eager_mode)
     try:
       yield
     finally:
       ctx.mode = old_mode
       if mode == EAGER_MODE:
-        context_stack.pop()
+        self.context_switches.pop()
 
   def executing_eagerly(self):
     """Returns True if current thread has eager executing enabled."""
@@ -545,6 +556,11 @@ class Context(object):
     run_metadata.ParseFromString(compat.as_bytes(proto_data))
     return run_metadata
 
+  @property
+  def context_switches(self):
+    """Returns a stack of context switches."""
+    return self._context_switches
+
 _context = None
 _context_lock = threading.Lock()
 
index 6dfd8d1..6ebf5b2 100644 (file)
@@ -123,19 +123,18 @@ class TFETest(test_util.TensorFlowTestCase):
     # available, when no device is explicitly provided)
     self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0')
 
-  def testContextStackContainsEagerMode(self):
-    # Eager execution has been enabled, and no other context
-    # switch has occurred, so `context_stack` should contain
-    # exactly one entry.
-    self.assertEqual(len(context.context_stack.stack), 1)
-    stack_entry = context.context_stack.stack[0]
+  def testContextSwitchStackContainsEagerMode(self):
+    # Eager execution has been enabled, and no other context switch has
+    # occurred, so `context_switches` should contain exactly one entry.
+    self.assertEqual(len(context.context().context_switches.stack), 1)
+    switch = context.context().context_switches.stack[0]
 
     # The entry should log that eager mode was entered.
-    self.assertIs(stack_entry.enter_context_fn, context.eager_mode)
+    self.assertIs(switch.enter_context_fn, context.eager_mode)
 
     # It is not possible to build a graph function when eager execution
     # is enabled; the stack entry should reflect this fact.
-    self.assertFalse(stack_entry.is_building_function)
+    self.assertFalse(switch.is_building_function)
 
   def testInt32GPU(self):
     if not context.context().num_gpus():
index f70c754..fc76ede 100644 (file)
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import threading
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
@@ -376,6 +377,22 @@ class OpsTest(test_util.TensorFlowTestCase):
   def testNoOpIsNone(self):
     self.assertTrue(control_flow_ops.no_op() is None)
 
+  def testEagerContextPreservedAcrossThreads(self):
+    def init_fn():
+      self.assertTrue(context.executing_eagerly())
+      with ops.init_scope():
+        self.assertTrue(context.executing_eagerly())
+        context_switches = context.context().context_switches
+        self.assertEqual(len(context_switches.stack), 1)
+        self.assertFalse(context_switches.stack[0].is_building_function)
+        self.assertEqual(context_switches.stack[0].enter_context_fn,
+                         context.eager_mode)
+
+    self.assertTrue(context.executing_eagerly())
+    t1 = threading.Thread(target=init_fn)
+    t1.start()
+    t1.join()
+
 
 if __name__ == '__main__':
   test.main()
index b2f4377..01a0e03 100644 (file)
@@ -5095,11 +5095,12 @@ class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
   @tf_contextlib.contextmanager
   def get_controller(self, default):
     try:
-      context.context_stack.push(default.building_function, default.as_default)
+      context.context().context_switches.push(default.building_function,
+                                              default.as_default)
       with super(_DefaultGraphStack, self).get_controller(default) as g:
         yield g
     finally:
-      context.context_stack.pop()
+      context.context().context_switches.pop()
 
 
 _default_graph_stack = _DefaultGraphStack()
@@ -5125,13 +5126,13 @@ def init_scope():
         graph function. Here, a context is defined as either a graph or an eager
         context. Every context switch, i.e., every installation of a graph as
         the default graph and every switch into eager mode, is logged in a
-        thread-local stack called the `context_stack`; the log entry for a
+        thread-local stack called `context_switches`; the log entry for a
         context switch is popped from the stack when the context is exited.
-        Entering an `init_scope` is equivalent to crawling up the
-        `context_stack`, finding the first context that is not building a graph
-        function, and entering it. A caveat is that if graph mode is enabled
-        but the default graph stack is empty, then entering an `init_scope`
-        will simply install a fresh graph as the default one.
+        Entering an `init_scope` is equivalent to crawling up
+        `context_switches`, finding the first context that is not building a
+        graph function, and entering it. A caveat is that if graph mode is
+        enabled but the default graph stack is empty, then entering an
+        `init_scope` will simply install a fresh graph as the default one.
 
     (3) The gradient tape is paused while the scope is active.
   """
@@ -5161,7 +5162,7 @@ def init_scope():
       outer_context = default_graph.as_default
     else:
       # Find a context that is not building a function.
-      for stack_entry in reversed(context.context_stack.stack):
+      for stack_entry in reversed(context.context().context_switches.stack):
         if not stack_entry.is_building_function:
           outer_context = stack_entry.enter_context_fn
           break
@@ -5278,13 +5279,6 @@ def enable_eager_execution(config=None, device_policy=None,
         config=config,
         device_policy=device_policy,
         execution_mode=execution_mode)
-    if context.context_stack.stack:
-      raise AssertionError("Invariant violated: The context stack must "
-                           "be empty when eager execution is enabled.")
-    # Log that eager execution has been enabled by pushing an entry onto the
-    # context stack; this entry won't ever be popped, as it's impossible to
-    # disable eager execution
-    context.context_stack.push(False, context.eager_mode)
   elif ((config is not None and config is not context._context._config) or
         (device_policy is not None and
          device_policy is not context._context._device_policy) or