Tests for running _Send and _Recv in eager.
authorAlexandre Passos <apassos@google.com>
Tue, 6 Mar 2018 00:09:47 +0000 (16:09 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 00:13:37 +0000 (16:13 -0800)
PiperOrigin-RevId: 187940522

tensorflow/c/eager/runtime.cc
tensorflow/python/eager/core_test.py

index 4bf24fe..9b46cf8 100644 (file)
@@ -302,7 +302,18 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
   params.runner = &runner;
 
   OpKernelContext context(&params);
-  device_->Compute(kernel_.get(), &context);
+
+  if (kernel_->def().op() == "_Recv") {
+    // TODO(apassos) do not special-case _Recv. Currently the GPU device's
+    // Compute() explicitly checks for and rejects _Recv kernels. To work
+    // around this we call ComputeAsync instead, mimicking graph-mode behavior.
+    AsyncOpKernel* async = kernel_->AsAsync();
+    Notification done;
+    device_->ComputeAsync(async, &context, [&done]() { done.Notify(); });
+    done.WaitForNotification();
+  } else {
+    device_->Compute(kernel_.get(), &context);
+  }
   if (!context.status().ok()) return context.status();
 
   output_tensors->clear();
index e418be5..f8f1011 100644 (file)
@@ -546,5 +546,61 @@ class TFETest(test_util.TensorFlowTestCase):
       self.assertIsInstance(t, ops.EagerTensor)
 
 
+class SendRecvTest(test_util.TensorFlowTestCase):
+
+  cpu_device = '/job:localhost/replica:0/task:0/device:CPU:0'
+
+  def _send(self, tensor, tensor_name, to_device):
+    return execute(
+        b'_Send', num_outputs=0, inputs=[tensor],
+        attrs=('T', tensor.dtype.as_datatype_enum,
+               'tensor_name', tensor_name,
+               'send_device', tensor.device,
+               'send_device_incarnation', 0,
+               'recv_device', to_device,
+               'client_terminated', True))
+
+  def _recv(self, dtype, tensor_name, from_device):
+    device_name = context.context().device_name
+    if not device_name:
+      device_name = self.cpu_device
+    return execute(
+        b'_Recv', num_outputs=1, inputs=[],
+        attrs=('tensor_type', dtype.as_datatype_enum,
+               'tensor_name', tensor_name,
+               'send_device', from_device,
+               'send_device_incarnation', 0,
+               'recv_device', device_name,
+               'client_terminated', False))[0]
+
+  def testBasic(self):
+    t0 = constant_op.constant(1.0)
+    t1 = constant_op.constant(2.0)
+    self._send(t0, 't0', self.cpu_device)
+    self._send(t1, 't1', self.cpu_device)
+    self.assertAllEqual(
+        self._recv(dtypes.float32, 't0', self.cpu_device),
+        1.0)
+    self.assertAllEqual(
+        self._recv(dtypes.float32, 't1', self.cpu_device),
+        2.0)
+
+  def testLocalCrossDevice(self):
+    if not context.context().num_gpus():
+      self.skipTest('No GPUs found')
+    gpu_device_name = '/job:localhost/replica:0/task:0/device:GPU:0'
+    with ops.device('GPU:0'):
+      t0 = constant_op.constant(1.0)
+      self._send(t0, 't0', self.cpu_device)
+    self.assertAllEqual(
+        self._recv(dtypes.float32, 't0', gpu_device_name),
+        1.0)
+    self._send(constant_op.constant(2.0), 't1', gpu_device_name)
+    with ops.device('GPU:0'):
+      self.assertAllEqual(
+          self._recv(dtypes.float32, 't1', self.cpu_device),
+          2.0)
+
+
 if __name__ == '__main__':
   test.main()