Improvement to the eager device placement heuristic.
authorAlexandre Passos <apassos@google.com>
Fri, 23 Feb 2018 23:35:35 +0000 (15:35 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 23 Feb 2018 23:41:50 +0000 (15:41 -0800)
PiperOrigin-RevId: 186833677

tensorflow/python/eager/context.py
tensorflow/python/eager/core_test.py
tensorflow/python/ops/array_ops.py
tensorflow/python/training/saver.py

index 07652d3..0e9c21b 100644 (file)
@@ -60,8 +60,7 @@ class _EagerContext(threading.local):
 
   def __init__(self):
     super(_EagerContext, self).__init__()
-    self.device_spec = pydev.DeviceSpec.from_string(
-        "/job:localhost/replica:0/task:0/device:CPU:0")
+    self.device_spec = pydev.DeviceSpec.from_string("")
     self.device_name = self.device_spec.to_string()
     self.mode = _default_mode
     self.scope_name = ""
index c68e2f4..0e40d8a 100644 (file)
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn_ops
 
 
@@ -65,8 +66,7 @@ class TFETest(test_util.TensorFlowTestCase):
     ctx.summary_writer_resource = 'mock'
     self.assertEqual('mock', ctx.summary_writer_resource)
 
-    self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
-                     ctx.device_name)
+    self.assertEqual('', ctx.device_name)
     self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
     with ctx.device('GPU:0'):
       self.assertEqual('/job:localhost/replica:0/task:0/device:GPU:0',
@@ -100,6 +100,18 @@ class TFETest(test_util.TensorFlowTestCase):
     self.assertEqual(len(cpu_stats.node_stats), 1)
     self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add')
 
+  def testShouldCopy(self):
+    if not context.context().num_gpus():
+      self.skipTest('No devices other than CPUs found')
+    with ops.device('gpu:0'):
+      x = constant_op.constant(1.0)
+    y = array_ops.identity(x)
+    # The value we're testing y.device against will depend on what the behavior
+    # of not explicitly specifying a device in the context is.  This behavior is
+    # subject to change (for example, in the future we may want to use GPUs, if
+    # available, when no device is explicitly provided)
+    self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0')
+
   def testContextStackContainsEagerMode(self):
     # Eager execution has been enabled, and no other context
     # switch has occurred, so `context_stack` should contain
index b3020ef..cdfb955 100644 (file)
@@ -134,7 +134,10 @@ def identity(input, name=None):  # pylint: disable=redefined-builtin
     input = ops.convert_to_tensor(input)
     in_device = input.device
     # TODO(ashankar): Does 'identity' need to invoke execution callbacks?
-    if context.context().device_name != in_device:
+    context_device = context.context().device_name
+    if not context_device:
+      context_device = "/job:localhost/replica:0/task:0/device:CPU:0"
+    if context_device != in_device:
       return input._copy()  # pylint: disable=protected-access
     return input
 
index 3888e9b..83e848d 100644 (file)
@@ -196,8 +196,8 @@ class BaseSaverBuilder(object):
       # Copy the restored tensor to the variable's device.
       with ops.device(self._var_device):
         restored_tensor = array_ops.identity(restored_tensor)
-      return resource_variable_ops.shape_safe_assign_variable_handle(
-          self.handle_op, self._var_shape, restored_tensor)
+        return resource_variable_ops.shape_safe_assign_variable_handle(
+            self.handle_op, self._var_shape, restored_tensor)
 
   def __init__(self, write_version=saver_pb2.SaverDef.V2):
     self._write_version = write_version