#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
-#if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#endif // GOOGLE_CUDA
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
- [opts, response, done, src_dev](const Status& status,
- const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args,
- const Tensor& val, const bool is_dead) {
+ [opts, response, done, src_dev, request](
+ const Status& status, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& val,
+ const bool is_dead) {
opts->ClearCancelCallback();
if (status.ok()) {
// DMA can only be used for Tensors that do not fall into
{
// Non-DMA cases.
if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
-#if GOOGLE_CUDA
- const DeviceContext* send_dev_context = send_args.device_context;
+ DeviceContext* send_dev_context = send_args.device_context;
AllocatorAttributes alloc_attrs;
alloc_attrs.set_gpu_compatible(true);
alloc_attrs.set_on_host(true);
CHECK(send_dev_context)
<< "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
- // "val" is on a GPU. Uses GPUUtil to fill the copy on host.
+ // "val" is on an accelerator device. Use its device context to
+ // copy the tensor into the host-side tensor "copy".
StatusCallback copy_ready = [response, done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
delete copy;
};
- GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
- copy_ready);
-#else
- done(errors::Internal("No GPU device in process"));
-#endif // GOOGLE_CUDA
+ send_dev_context->CopyDeviceTensorToCPU(
+ &val, request->rendezvous_key(), src_dev, copy, copy_ready);
} else {
grpc::EncodeTensorToByteBuffer(is_dead, val, response);
done(Status::OK());