internal change
author     A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 10:23:54 +0000 (03:23 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 10:26:08 +0000 (03:26 -0700)
PiperOrigin-RevId: 191869400

tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/xla_device_context.cc
tensorflow/compiler/jit/xla_device_context.h
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/xla/executable_run_options.cc
tensorflow/compiler/xla/executable_run_options.h

diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 24aa203..a492fc6 100644
@@ -204,14 +204,14 @@ cc_library(
         ":common",
         ":xla_compilation_cache",
         ":xla_tensor",
+        "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
-        "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 2d6511a..f48941f 100644
@@ -155,6 +155,9 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
   options.device_allocator = xla_allocator;
+  // TODO(b/77671268): We don't set variable_representation_shape_fn here. This
+  // is restricted to Variables, but we need something like this to apply to
+  // normal Tensors too.
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@@ -179,8 +182,10 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   run_options.set_stream(stream);
   run_options.set_allocator(xla_allocator);
   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+  run_options.set_rng_seed(ctx->step_id());
   Env* env = Env::Default();
   auto start_time = env->NowMicros();
+
   auto run_result = executable->Run(launch_context.arguments(), run_options);
   OP_REQUIRES(ctx, run_result.ok(), run_result.status());
 
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 6a57831..43eb164 100644
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/platform/mem.h"
 
@@ -53,8 +54,33 @@ XlaTransferManager::XlaTransferManager(se::Stream* stream,
                                        bool transfer_as_literal)
     : stream_(stream),
       client_(client),
+      transfer_manager_(client->backend().transfer_manager()),
       transfer_as_literal_(transfer_as_literal) {}
 
+Status XlaTransferManager::TransferLiteralToDevice(
+    const Tensor& host_tensor, Tensor* device_tensor) const {
+  xla::Literal literal;
+  TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal));
+  VLOG(1) << "Transfer to device as literal: " << literal.ToString();
+
+  const xla::ShapedBuffer& shaped_buffer =
+      XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+  return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal,
+                                                    shaped_buffer);
+}
+
+Status XlaTransferManager::TransferLiteralFromDevice(
+    Tensor* host_tensor, const Tensor& device_tensor) const {
+  const xla::ShapedBuffer& shaped_buffer =
+      XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
+                      transfer_manager_->TransferLiteralFromDevice(
+                          stream_->parent(), shaped_buffer));
+  VLOG(1) << "Transfer from device as literal: " << literal->ToString();
+  return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
+}
+
 void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                                Device* device,
                                                Tensor* device_tensor,
@@ -86,9 +112,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
         XlaTensor::DeviceMemoryFromTensor(*device_tensor);
     Status status;
     if (transfer_as_literal_) {
-      status = xla::Unimplemented(
-          "XlaTransferManager::CopyCPUTensorToDevice not implemented for "
-          "literals");
+      status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
     } else {
       stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
       // TODO(hpucha): Make this asynchronous.
@@ -129,9 +153,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
 
     Status status;
     if (transfer_as_literal_) {
-      status = xla::Unimplemented(
-          "XlaTransferManager::CopyDeviceTensorToCPU not implemented for "
-          "literals");
+      status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
     } else {
       stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
       // TODO(hpucha): Make this asynchronous.
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index a8ad511..ad914a1 100644
@@ -57,11 +57,18 @@ class XlaTransferManager {
   perftools::gputools::Stream* stream() const { return stream_; }
 
  private:
+  Status TransferLiteralToDevice(const Tensor& host_tensor,
+                                 Tensor* device_tensor) const;
+  Status TransferLiteralFromDevice(Tensor* host_tensor,
+                                   const Tensor& device_tensor) const;
+
   // Stream obtained from a Device, used to transfer tensors between
   // CPU and device.
   perftools::gputools::Stream* stream_;
   // For the underlying memory allocator and XLA's TransferManager.
   xla::LocalClient* client_;
+  // Transfer manager, for marshalling data to and from the device.
+  xla::TransferManager* transfer_manager_;
   // True if we must use XLA's TransferManager for correct device transfers.
   bool transfer_as_literal_;
 };
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 354be1e..50b0061 100644
@@ -16,12 +16,14 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_launch_util.h"
 
 #include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
@@ -165,6 +167,8 @@ void XlaComputationLaunchContext::PopulateOutputs(
   // Computation output should always be a tuple.
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
+    VLOG(2) << "Result tuple shape (on device): "
+            << output->on_device_shape().DebugString();
   }
   CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
 
@@ -179,6 +183,10 @@ void XlaComputationLaunchContext::PopulateOutputs(
       const size_t total_bytes = const_tensor.TotalBytes();
       if (stream && total_bytes > 0) {
         // Copy host -> device. (Empty tensors don't have backing buffers.)
+        // Manually allocate memory using an XlaTensorBuffer so we can allocate
+        // as much memory as the device requires (as given by
+        // GetByteSizeRequirement). This avoids XlaTransferManager having to
+        // reallocate the device buffer later.
         VLOG(1) << "Constant output tensor on device";
 
         OP_REQUIRES_OK(
@@ -189,15 +197,23 @@ void XlaComputationLaunchContext::PopulateOutputs(
                                   client_, stream->parent()->device_ordinal()));
         }
 
-        const void* src_ptr = DMAHelper::base(&const_tensor);
-        gpu::DeviceMemoryBase dst_ptr =
-            XlaTensor::DeviceMemoryFromTensor(*output_tensor);
-        // Memcpying asynchronously is safe for the GPU, but the CPU uses a
-        // shared allocator so hold a reference to the copied-to buffer until
-        // complete.
-        TensorReference ref(*output_tensor);
-        stream->ThenMemcpy(&dst_ptr, src_ptr, total_bytes);
-        stream->ThenDoHostCallback([ref] { ref.Unref(); });
+        Device* device = dynamic_cast<Device*>(ctx->device());
+        OP_REQUIRES(ctx, device != nullptr,
+                    errors::Internal("DeviceBase was not a Device."));
+        ctx->op_device_context()->CopyCPUTensorToDevice(
+            &const_tensor, device, output_tensor,
+            [&](Status status) { TF_CHECK_OK(status); });
+
+        if (device->device_type() == DEVICE_GPU) {
+          // The GPUDeviceContext enqueues the host->device transfer in a
+          // separate stream from the main compute stream. We must ensure the
+          // compute stream is synchronized with the host->device transfer
+          // stream now otherwise we will create a race condition.
+          auto* gpu_device_context =
+              static_cast<GPUDeviceContext*>(ctx->op_device_context());
+          gpu_device_context->stream()->ThenWaitFor(
+              gpu_device_context->host_to_device_stream());
+        }
       } else {
         // No copy required.
         ctx->set_output(i, const_tensor);
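The comment in the xla_launch_util.cc hunk above explains why the compute stream must wait on the GPU's host-to-device transfer stream. Below is a minimal sketch of that ordering pattern with StreamExecutor streams, assuming compute_stream and host_to_device_stream are existing se::Stream pointers and device_buffer, host_ptr, and num_bytes are set up elsewhere.

    // Enqueue the host->device copy on the dedicated transfer stream.
    host_to_device_stream->ThenMemcpy(&device_buffer, host_ptr, num_bytes);
    // Make the compute stream wait on the transfer stream. This is a
    // device-side wait (the host is not blocked), so kernels launched later on
    // the compute stream observe the copied data and the race described in the
    // comment above is avoided.
    compute_stream->ThenWaitFor(host_to_device_stream);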
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 392ad90..1700c97 100644
@@ -87,4 +87,11 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
   return device_assignment_;
 }
 
+ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
+  rng_seed_ = rng_seed;
+  return *this;
+}
+
+int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index d4fcbf0..2c1d9ff 100644
@@ -84,6 +84,9 @@ class ExecutableRunOptions {
       DeviceAssignment* device_assignment);
   const DeviceAssignment* device_assignment() const;
 
+  ExecutableRunOptions& set_rng_seed(int rng_seed);
+  int rng_seed() const;
+
  private:
   DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -92,6 +95,7 @@ class ExecutableRunOptions {
   tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
   ExecutionProfile* execution_profile_ = nullptr;
+  int rng_seed_ = 0;
 };
 
 }  // namespace xla
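Taken together with the xla_launch_op.cc change above, the new accessor pair lets callers thread a per-run RNG seed into an executable. A short usage sketch, mirroring the launch op; stream, xla_allocator, executable, launch_context, and ctx are assumed to exist as in that file.

    xla::ExecutableRunOptions run_options;
    run_options.set_stream(stream);
    run_options.set_allocator(xla_allocator);
    run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
    // Any per-step value works; the default seed is 0.
    run_options.set_rng_seed(ctx->step_id());
    auto run_result = executable->Run(launch_context.arguments(), run_options);
    // Backends can read the seed back via run_options.rng_seed().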