From 13e7f92d120de8f6f548493eb49b74810888ffd4 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Tue, 22 May 2018 10:26:00 -0700 Subject: [PATCH] Make init_scope preserve the inner device stack when lifting into a graph. Eager execution doesn't implement device stacks and in particular it doesn't support device functions (which determine the device on a per-op basis), so in general it's not possible to do the same when lifting into the eager context. PiperOrigin-RevId: 197583446 --- tensorflow/python/framework/ops.py | 21 ++++++++++++++++++--- tensorflow/python/framework/ops_test.py | 15 +++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 7b3acc4..80140e4 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -5347,6 +5347,7 @@ def init_scope(): # Names that end with trailing slashes are treated by `name_scope` as # absolute. scope = scope + '/' + inner_device_stack = default_graph._device_function_stack # pylint: disable=protected-access outer_context = None if not _default_graph_stack.stack: @@ -5375,9 +5376,23 @@ def init_scope(): raise RuntimeError("All graphs are building functions, and no " "eager context was previously active.") - with outer_context(), name_scope(scope), control_dependencies( - None), tape.stop_recording(): - yield + outer_graph = None + outer_device_stack = None + try: + with outer_context(), name_scope(scope), control_dependencies( + None), tape.stop_recording(): + if not context.executing_eagerly(): + # The device stack is preserved when lifting into a graph. Eager + # execution doesn't implement device stacks and in particular it + # doesn't support device functions, so in general it's not possible + # to do the same when lifting into the eager context. + outer_graph = get_default_graph() + outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access + outer_graph._device_function_stack = inner_device_stack # pylint: disable=protected-access + yield + finally: + if outer_graph is not None: + outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access @tf_export("enable_eager_execution") diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index a896601..87317db 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -2033,6 +2033,21 @@ class InitScopeTest(test_util.TensorFlowTestCase): self.assertEqual(len(g1.get_operations()), 0) self.assertEqual(len(g0.get_operations()), 1) + def testPreservesDevices(self): + g0 = ops.Graph() + with g0.as_default(), ops.device("CPU:0"): + g1 = ops.Graph() + g1._building_function = True # pylint: disable=protected-access + with g1.as_default(), ops.device("GPU:0"): + with ops.init_scope(): + # init_scope should preserve device set under `g1`. + on_gpu = constant_op.constant(1.0) + self.assertEqual(on_gpu.device, "/device:GPU:0") + still_on_gpu = constant_op.constant(1.0) + self.assertEqual(still_on_gpu.device, "/device:GPU:0") + on_cpu = constant_op.constant(1.0) + self.assertEqual(on_cpu.device, "/device:CPU:0") + def testComposes(self): g0 = ops.Graph() g1 = ops.Graph() -- 2.7.4