Tweak the context stack so init_scope works with eager Graphs
authorAllen Lavoie <allenl@google.com>
Mon, 9 Apr 2018 21:52:53 +0000 (14:52 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 21:54:54 +0000 (14:54 -0700)
Previously breaking out into Graphs created with eager execution enabled would
enter the graph but not re-enable eager execution.

PiperOrigin-RevId: 192192109

tensorflow/python/framework/ops.py
tensorflow/python/framework/ops_test.py

index e3ca5a4..662cda2 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 import copy
+import functools
 import linecache
 import os
 import re
@@ -5244,14 +5245,35 @@ class _DefaultGraphStack(_DefaultStack):  # pylint: disable=protected-access
   @tf_contextlib.contextmanager
   def get_controller(self, default):
     try:
-      context.context().context_switches.push(default.building_function,
-                                              default.as_default)
+      if context.executing_eagerly():
+        # A Graph alone on the context stack would keep init_scope-wrapped
+        # operations graph building when entered (assuming init_scope is called
+        # in a graph building context). Instead, we push a context which first
+        # enables eager execution and then re-enters the Graph.
+        context.context().context_switches.push(
+            default.building_function,
+            functools.partial(
+                _enter_context_and_graph,
+                context.eager_mode,
+                default.as_default))
+      else:
+        # This Graph is being used from a graph building context. A lack of
+        # context switch implies that the context is graph building.
+        context.context().context_switches.push(default.building_function,
+                                                default.as_default)
       with super(_DefaultGraphStack, self).get_controller(default) as g:
         yield g
     finally:
       context.context().context_switches.pop()
 
 
+@tf_contextlib.contextmanager
+def _enter_context_and_graph(context_fn, graph_fn):
+  """Combines two context managers."""
+  with context_fn(), graph_fn():
+    yield
+
+
 _default_graph_stack = _DefaultGraphStack()
 
 
index 58bead9..c9c1a3d 100644 (file)
@@ -2305,6 +2305,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
           self.assertEqual(ops.get_name_scope(), "inner")
       self.assertEqual(ops.get_name_scope(), "")
 
+  def testEagerGraphContextsExecuteEagerly(self):
+    with context.eager_mode():
+      with ops.Graph().as_default():
+        with context.graph_mode():
+          with ops.init_scope():
+            self.assertTrue(context.executing_eagerly())
+
   def testPreservesNameScopeInEagerExecution(self):
     with context.eager_mode():
       def foo():