From fedca2059d52d4cb753c46d4e65884877b5b4f38 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Fri, 23 Feb 2018 15:35:35 -0800 Subject: [PATCH] Improvement to the eager device placement heuristic. PiperOrigin-RevId: 186833677 --- tensorflow/python/eager/context.py | 3 +-- tensorflow/python/eager/core_test.py | 16 ++++++++++++++-- tensorflow/python/ops/array_ops.py | 5 ++++- tensorflow/python/training/saver.py | 4 ++-- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 07652d3..0e9c21b 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -60,8 +60,7 @@ class _EagerContext(threading.local): def __init__(self): super(_EagerContext, self).__init__() - self.device_spec = pydev.DeviceSpec.from_string( - "/job:localhost/replica:0/task:0/device:CPU:0") + self.device_spec = pydev.DeviceSpec.from_string("") self.device_name = self.device_spec.to_string() self.mode = _default_mode self.scope_name = "" diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index c68e2f4..0e40d8a 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops @@ -65,8 +66,7 @@ class TFETest(test_util.TensorFlowTestCase): ctx.summary_writer_resource = 'mock' self.assertEqual('mock', ctx.summary_writer_resource) - self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', - ctx.device_name) + self.assertEqual('', ctx.device_name) self.assertEqual(ctx.device_name, ctx.device_spec.to_string()) with ctx.device('GPU:0'): self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0', @@ -100,6 +100,18 @@ class TFETest(test_util.TensorFlowTestCase): self.assertEqual(len(cpu_stats.node_stats), 1) self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add') + def testShouldCopy(self): + if not context.context().num_gpus(): + self.skipTest('No devices other than CPUs found') + with ops.device('gpu:0'): + x = constant_op.constant(1.0) + y = array_ops.identity(x) + # The value we're testing y.device against will depend on what the behavior + # of not explicitly specifying a device in the context is. This behavior is + # subject to change (for example, in the future we may want to use GPUs, if + # available, when no device is explicitly provided) + self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0') + def testContextStackContainsEagerMode(self): # Eager execution has been enabled, and no other context # switch has occurred, so `context_stack` should contain diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index b3020ef..cdfb955 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -134,7 +134,10 @@ def identity(input, name=None): # pylint: disable=redefined-builtin input = ops.convert_to_tensor(input) in_device = input.device # TODO(ashankar): Does 'identity' need to invoke execution callbacks? - if context.context().device_name != in_device: + context_device = context.context().device_name + if not context_device: + context_device = "/job:localhost/replica:0/task:0/device:CPU:0" + if context_device != in_device: return input._copy() # pylint: disable=protected-access return input diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 3888e9b..83e848d 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -196,8 +196,8 @@ class BaseSaverBuilder(object): # Copy the restored tensor to the variable's device. with ops.device(self._var_device): restored_tensor = array_ops.identity(restored_tensor) - return resource_variable_ops.shape_safe_assign_variable_handle( - self.handle_op, self._var_shape, restored_tensor) + return resource_variable_ops.shape_safe_assign_variable_handle( + self.handle_op, self._var_shape, restored_tensor) def __init__(self, write_version=saver_pb2.SaverDef.V2): self._write_version = write_version -- 2.7.4