Keep track of eager op device for tensor handles. Force-colocates ops using resources...
author    Alexandre Passos <apassos@google.com>
Thu, 1 Mar 2018 17:27:57 +0000 (09:27 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 1 Mar 2018 17:31:56 +0000 (09:31 -0800)
PiperOrigin-RevId: 187488175

tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api_internal.h
tensorflow/python/eager/core_test.py
tensorflow/python/lib/core/py_func.cc
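
In effect, every eager TFE_TensorHandle now carries two devices: d, the device
holding the tensor's memory, and op_device, the device on which the op that
produced it ran. When an input has dtype DT_RESOURCE, TFE_Execute overrides
the requested device with that input's op device, mirroring graph-mode
colocation. A minimal sketch of the resulting behavior, mirroring the new test
below and assuming an eager context with at least one GPU:

    from tensorflow.python.eager import context
    from tensorflow.python.ops import resource_variable_ops

    # The variable's resource handle is created on gpu:0 ...
    with context.device('gpu:0'):
      v = resource_variable_ops.ResourceVariable(1.0)
    # ... so the read below is force-colocated onto gpu:0, even though
    # cpu:0 is the currently requested device.
    with context.device('cpu:0'):
      print(v.read_value())  # => 1.0, computed on gpu:0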

diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 29c709b06db9f35a908787a60c2597509e524ecd..252ceab54aef6006326ef20525117b38406004fd 100644
@@ -159,7 +159,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
   if (!status->status.ok()) return nullptr;
-  return new TFE_TensorHandle(tensor, nullptr);
+  return new TFE_TensorHandle(tensor, nullptr, nullptr);
 }
 
 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
@@ -222,7 +222,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
   // has device type XLA_CPU, and the other CPU.
   const bool both_on_cpu = src_cpu && dst_cpu;
   if (is_same_device || both_on_cpu) {
-    return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
+    dstd = dst_cpu ? nullptr : dstd;
+    return new TFE_TensorHandle(h->t, dstd, dstd);
   }
   tensorflow::Tensor* src = &(h->t);
   if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
@@ -241,7 +242,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
   }
   tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
   if (src->shape().num_elements() == 0) {
-    return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
+    dstd = dst_cpu ? nullptr : dstd;
+    return new TFE_TensorHandle(dst, dstd, dstd);
   }
   tensorflow::DeviceContext* src_device_context = nullptr;
   if (!src_cpu) {
@@ -269,7 +271,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                  });
   n.WaitForNotification();
   return (TF_GetCode(status) == TF_OK)
-             ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
+             ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd,
+                                    dst_cpu ? nullptr : dstd)
              : nullptr;
 }
 
@@ -325,6 +328,7 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
   if (!status->status.ok()) return;
   op->inputs.push_back(h->t);
   op->input_devices.push_back(h->d);
+  op->input_op_devices.push_back(h->op_device);
   op->attrs.NumInputs(op->inputs.size());
 }
 
@@ -540,7 +544,8 @@ tensorflow::Status ValidateInputTypeAndPlacement(
       }
       // We are only here if the policy is warn or silent copies, so we should
       // trigger a copy.
-      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
+      TFE_TensorHandle original{op->inputs[i], op->input_devices[i],
+                                op->device};
       TF_Status* s = TF_NewStatus();
       TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
           &original, ctx, expected_device->name().c_str(), s);
@@ -744,6 +749,7 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
   // via `op_input_to_func_input`, adjust the actual inputs accordingly.
   launch_op->inputs = op->inputs;
   launch_op->input_devices = op->input_devices;
+  launch_op->input_op_devices = op->input_op_devices;
   if (!op_input_to_func_input.empty()) {
     DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
     if (!op->input_devices.empty()) {
@@ -832,9 +838,24 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     op = xla_launch_op.get();
   }
 #endif  // TENSORFLOW_EAGER_USE_XLA
-
   TFE_Context* ctx = op->ctx;
   tensorflow::Device* device = op->device;
+  // Ensure all resource-touching ops run on the device the resource lives
+  // on, regardless of anything else that has been specified. This is
+  // identical to the graph-mode behavior.
+  for (int i = 0; i < op->inputs.size(); ++i) {
+    if (op->inputs[i].dtype() == tensorflow::DT_RESOURCE &&
+        op->input_op_devices[i] != device) {
+      tensorflow::Device* d = op->input_op_devices[i] == nullptr
+                                  ? ctx->devices()[0]
+                                  : op->input_op_devices[i];
+      VLOG(1) << "Changing device of operation " << op->name << " to "
+              << d->name() << " because input #" << i
+              << " is a resource on that device.";
+      device = d;
+      op->device = d;
+    }
+  }
   if (!ctx->soft_placement && device == nullptr) {
     // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
     device = ctx->devices()[0];
@@ -968,7 +989,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
         (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
       d = nullptr;
     }
-    retvals[i] = new TFE_TensorHandle(outputs[i], d);
+    retvals[i] = new TFE_TensorHandle(outputs[i], d, device);
   }
 }
 
@@ -994,7 +1015,7 @@ void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
 }  // extern "C"
 
 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
-  return new TFE_TensorHandle(t, nullptr);
+  return new TFE_TensorHandle(t, nullptr, nullptr);
 }
 
 const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
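
A consequence of the TFE_TensorHandleCopyToDevice changes above: an explicit
copy records its destination as both d and op_device, with nullptr standing in
for the local CPU. A hedged sketch, using the .cpu() helper already exercised
by the tests and again assuming a GPU is available:

    from tensorflow.python.eager import context
    from tensorflow.python.ops import array_ops

    with context.device('gpu:0'):
      x = array_ops.identity(1.0)  # handle's d and op_device name gpu:0
    x_cpu = x.cpu()  # copied handle: internally d == op_device == nullptr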
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 53c21b64cbcd76c2357af8b7973a8359752a6b16..145e4c95cf07373261f81912fe1c35f8db2f9ebd 100644
@@ -101,8 +101,9 @@ struct TFE_Context {
 };
 
 struct TFE_TensorHandle {
-  TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d)
-      : t(t), d(d) {}
+  TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d,
+                   tensorflow::Device* op_device)
+      : t(t), d(d), op_device(op_device) {}
 
   tensorflow::Tensor t;
   // TODO(ashankar): d == nullptr iff local CPU
@@ -114,6 +115,10 @@ struct TFE_TensorHandle {
   // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a
   // TFE_TensorHandle does not outlive the TFE_Context from which it came?
   tensorflow::Device* d;
+
+  // Device on which the op that produced this tensor was executed. Equal
+  // to d for constant tensors.
+  tensorflow::Device* op_device;
 };
 
 struct TFE_Op {
@@ -130,6 +135,7 @@ struct TFE_Op {
   const tensorflow::AttrTypeMap* attr_types;
   std::vector<tensorflow::Tensor> inputs;
   std::vector<tensorflow::Device*> input_devices;
+  std::vector<tensorflow::Device*> input_op_devices;
   tensorflow::Device* device;
   bool use_xla = false;
 };
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 0e40d8a5c0a582ab27d95735dd917e2a5daabe09..e418be5fae4da46615f7b1467252ae6b26b9e6a3 100644
@@ -34,7 +34,9 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import resource_variable_ops
 
 
 def execute(op_name, num_outputs, inputs, attrs=None):
@@ -181,6 +183,18 @@ class TFETest(test_util.TensorFlowTestCase):
         attrs=('T', x.dtype.as_datatype_enum))[0].cpu().numpy()
     self.assertEqual(3, result)
 
+  def testResourceTensorPlacement(self):
+    if not context.context().num_gpus():
+      self.skipTest('No GPUs found')
+
+    with context.device('gpu:0'):
+      v = resource_variable_ops.ResourceVariable(1.0)
+    with context.device('cpu:0'):
+      # Check that even though we specified the cpu device, the read op
+      # runs on the device where the handle lives.
+      self.assertAllEqual(
+          gen_resource_variable_ops.read_variable_op(v.handle, v.dtype), 1.0)
+
   def testCopyBetweenDevices(self):
     if not context.context().num_gpus():
       self.skipTest('No GPUs found')
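
The new testResourceTensorPlacement skips itself when no GPU is present, so it
needs a CUDA build to exercise the colocation path. One plausible invocation
(target name assumed from the file's path, not confirmed by this change):

    bazel test --config=cuda //tensorflow/python/eager:core_test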
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index e0422ef80add42307268be2743e668eb8c8acb68..343415b2645e00003e51fad18cbb1ec602db472d 100644
@@ -79,10 +79,11 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
     const Tensor& t = call->ins[i];
     if (call->eager) {
       if (call->gpu) {
-        arg = EagerTensorFromHandle(new TFE_TensorHandle(t, call->device));
+        arg = EagerTensorFromHandle(
+            new TFE_TensorHandle(t, call->device, call->device));
       } else {
         // TFE_TensorHandle assumes that CPU is identified by `nullptr`.
-        arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr));
+        arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr, nullptr));
       }
       if (arg == nullptr) {
         return errors::Internal("Unable to procure EagerTensor from Tensor.");