From 265099d262225a4b54619ee591d261e8146051e4 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 9 Apr 2018 14:52:53 -0700 Subject: [PATCH] Tweak the context stack so init_scope works with eager Graphs Previously breaking out into Graphs created with eager execution enabled would enter the graph but not re-enable eager execution. PiperOrigin-RevId: 192192109 --- tensorflow/python/framework/ops.py | 26 ++++++++++++++++++++++++-- tensorflow/python/framework/ops_test.py | 7 +++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e3ca5a4..662cda2 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import copy +import functools import linecache import os import re @@ -5244,14 +5245,35 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access @tf_contextlib.contextmanager def get_controller(self, default): try: - context.context().context_switches.push(default.building_function, - default.as_default) + if context.executing_eagerly(): + # A Graph alone on the context stack would keep init_scope-wrapped + # operations graph building when entered (assuming init_scope is called + # in a graph building context). Instead, we push a context which first + # enables eager execution and then re-enters the Graph. + context.context().context_switches.push( + default.building_function, + functools.partial( + _enter_context_and_graph, + context.eager_mode, + default.as_default)) + else: + # This Graph is being used from a graph building context. A lack of + # context switch implies that the context is graph building. + context.context().context_switches.push(default.building_function, + default.as_default) with super(_DefaultGraphStack, self).get_controller(default) as g: yield g finally: context.context().context_switches.pop() +@tf_contextlib.contextmanager +def _enter_context_and_graph(context_fn, graph_fn): + """Combines two context managers.""" + with context_fn(), graph_fn(): + yield + + _default_graph_stack = _DefaultGraphStack() diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 58bead9..c9c1a3d 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -2305,6 +2305,13 @@ class InitScopeTest(test_util.TensorFlowTestCase): self.assertEqual(ops.get_name_scope(), "inner") self.assertEqual(ops.get_name_scope(), "") + def testEagerGraphContextsExecuteEagerly(self): + with context.eager_mode(): + with ops.Graph().as_default(): + with context.graph_mode(): + with ops.init_scope(): + self.assertTrue(context.executing_eagerly()) + def testPreservesNameScopeInEagerExecution(self): with context.eager_mode(): def foo(): -- 2.7.4