From: A. Unique TensorFlower <gardener@tensorflow.org>
Date: Fri, 6 Apr 2018 10:23:54 +0000 (-0700)
Subject: internal change
X-Git-Tag: tflite-v0.1.7~16^2^2~116
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=58df8c97a7dc2ed2159e8137312fa29c0d7bcf67;p=platform%2Fupstream%2Ftensorflow.git

internal change

PiperOrigin-RevId: 191869400
---

diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 24aa203..a492fc6 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -204,14 +204,14 @@ cc_library(
         ":common",
         ":xla_compilation_cache",
         ":xla_tensor",
+        "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
-        "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 2d6511a..f48941f 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -155,6 +155,9 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
   options.device_allocator = xla_allocator;
+  // TODO(b/77671268): We don't set variable_representation_shape_fn here. This
+  // is restricted to Variables, but we need something like this to apply to
+  // normal Tensors too.
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@@ -179,8 +182,10 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   run_options.set_stream(stream);
   run_options.set_allocator(xla_allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+  run_options.set_rng_seed(ctx->step_id());
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
+
   auto run_result = executable->Run(launch_context.arguments(), run_options);
   OP_REQUIRES(ctx, run_result.ok(), run_result.status());
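The second hunk above plumbs a per-step seed into XLA: tying rng_seed to
ctx->step_id() makes random ops draw fresh values on every step while keeping
a given step reproducible. A minimal sketch of the resulting call pattern,
using only the ExecutableRunOptions accessors added later in this patch
(stream, xla_allocator, launch_context, and executable are assumed to be set
up as in XlaLocalLaunchOp::Compute):

    // Sketch: configure per-run options, then execute the compiled kernel.
    xla::ExecutableRunOptions run_options;
    run_options.set_stream(stream);            // device compute stream
    run_options.set_allocator(xla_allocator);  // device memory allocator
    run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
    run_options.set_rng_seed(ctx->step_id());  // new: per-step RNG seed
    auto run_result = executable->Run(launch_context.arguments(), run_options);
    OP_REQUIRES(ctx, run_result.ok(), run_result.status());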
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/platform/mem.h" @@ -53,8 +54,33 @@ XlaTransferManager::XlaTransferManager(se::Stream* stream, bool transfer_as_literal) : stream_(stream), client_(client), + transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal) {} +Status XlaTransferManager::TransferLiteralToDevice( + const Tensor& host_tensor, Tensor* device_tensor) const { + xla::Literal literal; + TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal)); + VLOG(1) << "Transfer to device as literal: " << literal.ToString(); + + const xla::ShapedBuffer& shaped_buffer = + XlaTensor::FromTensor(device_tensor)->shaped_buffer(); + return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, + shaped_buffer); +} + +Status XlaTransferManager::TransferLiteralFromDevice( + Tensor* host_tensor, const Tensor& device_tensor) const { + const xla::ShapedBuffer& shaped_buffer = + XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + transfer_manager_->TransferLiteralFromDevice( + stream_->parent(), shaped_buffer)); + VLOG(1) << "Transfer from device as literal: " << literal->ToString(); + return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); +} + void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -86,9 +112,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; if (transfer_as_literal_) { - status = xla::Unimplemented( - "XlaTransferManager::CopyCPUTensorToDevice not implemented for " - "literals"); + status = TransferLiteralToDevice(*cpu_tensor, device_tensor); } else { stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. @@ -129,9 +153,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status status; if (transfer_as_literal_) { - status = xla::Unimplemented( - "XlaTransferManager::CopyDeviceTensorToCPU not implemented for " - "literals"); + status = TransferLiteralFromDevice(cpu_tensor, *device_tensor); } else { stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index a8ad511..ad914a1 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -57,11 +57,18 @@ class XlaTransferManager { perftools::gputools::Stream* stream() const { return stream_; } private: + Status TransferLiteralToDevice(const Tensor& host_tensor, + Tensor* device_tensor) const; + Status TransferLiteralFromDevice(Tensor* host_tensor, + const Tensor& device_tensor) const; + // Stream obtained from a Device, used to transfer tensors between // CPU and device. perftools::gputools::Stream* stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; + // Transfer manager, for marshalling data to and from the device. + xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. 
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index a8ad511..ad914a1 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -57,11 +57,18 @@ class XlaTransferManager {
   perftools::gputools::Stream* stream() const { return stream_; }
 
  private:
+  Status TransferLiteralToDevice(const Tensor& host_tensor,
+                                 Tensor* device_tensor) const;
+  Status TransferLiteralFromDevice(Tensor* host_tensor,
+                                   const Tensor& device_tensor) const;
+
   // Stream obtained from a Device, used to transfer tensors between
   // CPU and device.
   perftools::gputools::Stream* stream_;
   // For the underlying memory allocator and XLA's TransferManager.
   xla::LocalClient* client_;
+  // Transfer manager, for marshalling data to and from the device.
+  xla::TransferManager* transfer_manager_;
   // True if we must use XLA's TransferManager for correct device transfers.
   bool transfer_as_literal_;
 };
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 354be1e..50b0061 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -16,12 +16,14 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_launch_util.h"
 
 #include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
@@ -165,6 +167,8 @@ void XlaComputationLaunchContext::PopulateOutputs(
   // Computation output should always be a tuple.
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
+    VLOG(2) << "Result tuple shape (on device): "
+            << output->on_device_shape().DebugString();
   }
   CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
 
@@ -179,6 +183,10 @@ void XlaComputationLaunchContext::PopulateOutputs(
       const size_t total_bytes = const_tensor.TotalBytes();
       if (stream && total_bytes > 0) {
         // Copy host -> device. (Empty tensors don't have backing buffers.)
+        // Manually allocate memory using an XlaTensorBuffer so we can allocate
+        // as much memory as the device requires (as given by
+        // GetByteSizeRequirement). This avoids XlaTransferManager having to
+        // reallocate the device buffer later.
         VLOG(1) << "Constant output tensor on device";
 
         OP_REQUIRES_OK(
@@ -189,15 +197,23 @@ void XlaComputationLaunchContext::PopulateOutputs(
             client_, stream->parent()->device_ordinal()));
         }
 
-        const void* src_ptr = DMAHelper::base(&const_tensor);
-        gpu::DeviceMemoryBase dst_ptr =
-            XlaTensor::DeviceMemoryFromTensor(*output_tensor);
-        // Memcpying asynchronously is safe for the GPU, but the CPU uses a
-        // shared allocator so hold a reference to the copied-to buffer until
-        // complete.
-        TensorReference ref(*output_tensor);
-        stream->ThenMemcpy(&dst_ptr, src_ptr, total_bytes);
-        stream->ThenDoHostCallback([ref] { ref.Unref(); });
+        Device* device = dynamic_cast<Device*>(ctx->device());
+        OP_REQUIRES(ctx, device != nullptr,
+                    errors::Internal("DeviceBase was not a Device."));
+        ctx->op_device_context()->CopyCPUTensorToDevice(
+            &const_tensor, device, output_tensor,
+            [&](Status status) { TF_CHECK_OK(status); });
+
+        if (device->device_type() == DEVICE_GPU) {
+          // The GPUDeviceContext enqueues the host->device transfer in a
+          // separate stream from the main compute stream. We must ensure the
+          // compute stream is synchronized with the host->device transfer
+          // stream now otherwise we will create a race condition.
+          auto* gpu_device_context =
+              static_cast<GPUDeviceContext*>(ctx->op_device_context());
+          gpu_device_context->stream()->ThenWaitFor(
+              gpu_device_context->host_to_device_stream());
+        }
       } else {
         // No copy required.
         ctx->set_output(i, const_tensor);
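The DEVICE_GPU branch above is needed because
GPUDeviceContext::CopyCPUTensorToDevice enqueues the copy on a dedicated
host-to-device stream, while the XLA computation's consumers of that buffer
run on the main compute stream. The ordering constraint, reduced to a sketch
(both stream pointers are assumed to come from the GPUDeviceContext, as in
the hunk above):

    // After the copy is enqueued on the transfer stream, make the compute
    // stream wait for it; otherwise a kernel on the compute stream could
    // read the constant's device buffer before the copy has finished.
    compute_stream->ThenWaitFor(host_to_device_stream);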
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 392ad90..1700c97 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -87,4 +87,11 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
   return device_assignment_;
 }
 
+ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
+  rng_seed_ = rng_seed;
+  return *this;
+}
+
+int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index d4fcbf0..2c1d9ff 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -84,6 +84,9 @@ class ExecutableRunOptions {
                                              DeviceAssignment* device_assignment);
   const DeviceAssignment* device_assignment() const;
 
+  ExecutableRunOptions& set_rng_seed(int rng_seed);
+  int rng_seed() const;
+
  private:
   DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -92,6 +95,7 @@ class ExecutableRunOptions {
   tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
   ExecutionProfile* execution_profile_ = nullptr;
+  int rng_seed_ = 0;
 };
 
 }  // namespace xla
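Like the other ExecutableRunOptions setters, set_rng_seed returns *this, so
it composes with existing option-building code, and rng_seed_ defaults to 0
when no seed is set. A small usage sketch (the values are illustrative, and
set_device_ordinal is assumed to follow the same chaining pattern):

    xla::ExecutableRunOptions options;
    options.set_device_ordinal(0).set_rng_seed(1234);  // setters chain via *this
    int seed = options.rng_seed();                     // 1234; 0 if never set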