Change RecvTensor RPC implementation to use DeviceContext::CopyDeviceTensorToCPU...
authorPeter Hawkins <phawkins@google.com>
Fri, 4 May 2018 17:18:46 +0000 (10:18 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 17:54:38 +0000 (10:54 -0700)
PiperOrigin-RevId: 195433287

tensorflow/core/distributed_runtime/rpc/BUILD
tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc

index e973a22..c2719f5 100644 (file)
@@ -169,7 +169,6 @@ tf_cuda_library(
         ":grpc_worker_service_impl",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
-        "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:worker_proto_cc",
index bbf7391..26fad1f 100644 (file)
@@ -23,9 +23,6 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
-#if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#endif  // GOOGLE_CUDA
 #include "tensorflow/core/common_runtime/local_device.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/step_stats_collector.h"
@@ -439,10 +436,10 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
   opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
   env_->rendezvous_mgr->RecvLocalAsync(
       step_id, parsed,
-      [opts, response, done, src_dev](const Status& status,
-                                      const Rendezvous::Args& send_args,
-                                      const Rendezvous::Args& recv_args,
-                                      const Tensor& val, const bool is_dead) {
+      [opts, response, done, src_dev, request](
+          const Status& status, const Rendezvous::Args& send_args,
+          const Rendezvous::Args& recv_args, const Tensor& val,
+          const bool is_dead) {
         opts->ClearCancelCallback();
         if (status.ok()) {
           // DMA can only be used for Tensors that do not fall into
@@ -455,8 +452,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
           {
             // Non-DMA cases.
             if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
-#if GOOGLE_CUDA
-              const DeviceContext* send_dev_context = send_args.device_context;
+              DeviceContext* send_dev_context = send_args.device_context;
               AllocatorAttributes alloc_attrs;
               alloc_attrs.set_gpu_compatible(true);
               alloc_attrs.set_on_host(true);
@@ -465,7 +461,8 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
               CHECK(send_dev_context)
                   << "send dev name: " << src_dev->name()
                   << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
-              // "val" is on a GPU. Uses GPUUtil to fill the copy on host.
+              // "val" is on an accelerator device. Uses the device_context to
+              // fill the copy on host.
               StatusCallback copy_ready = [response, done, copy,
                                            is_dead](const Status& s) {
                 // The value is now ready to be returned on the wire.
@@ -474,11 +471,8 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                 delete copy;
               };
 
-              GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
-                                          copy_ready);
-#else
-              done(errors::Internal("No GPU device in process"));
-#endif  // GOOGLE_CUDA
+              send_dev_context->CopyDeviceTensorToCPU(
+                  &val, request->rendezvous_key(), src_dev, copy, copy_ready);
             } else {
               grpc::EncodeTensorToByteBuffer(is_dead, val, response);
               done(Status::OK());