[tf2xla] Introduce XlaTensorInfo
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 16 Mar 2018 10:45:32 +0000 (03:45 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 16 Mar 2018 10:49:44 +0000 (03:49 -0700)
XlaTensorInfo is side-band data for Tensors. It can be used to store
information about Tensors that is not possible to store in the Tensor
itself.  The XlaTensorInfos are managed by XlaTensorInfoManager, which
is an Allocator, which allows it to release the TensorInfos when the
underlying Tensor is released.  Looking up an XlaTensorInfo for a
Tensor requires a hash table lookup. This implementation keeps this
off the fast path and only looks the tensorinfos up when they are
required.

PiperOrigin-RevId: 189319553

13 files changed:
tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/xla_cpu_device.cc
tensorflow/compiler/jit/xla_device.cc
tensorflow/compiler/jit/xla_device.h
tensorflow/compiler/jit/xla_device_context.cc
tensorflow/compiler/jit/xla_device_context.h
tensorflow/compiler/jit/xla_gpu_device.cc
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/jit/xla_launch_util.h
tensorflow/compiler/jit/xla_tensor_info.cc [new file with mode: 0644]
tensorflow/compiler/jit/xla_tensor_info.h [new file with mode: 0644]
tensorflow/core/framework/tensor.h

index c4a2d4a..39eb390 100644 (file)
@@ -119,6 +119,21 @@ cc_library(
 )
 
 cc_library(
+    name = "xla_tensor_info",
+    srcs = ["xla_tensor_info.cc"],
+    hdrs = ["xla_tensor_info.h"],
+    deps = [
+        ":common",
+        "//tensorflow/compiler/xla/service:shaped_buffer",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+cc_library(
     name = "xla_device",
     srcs = [
         "xla_device.cc",
@@ -136,6 +151,7 @@ cc_library(
         ":common",
         ":jit_compilation_passes",
         ":xla_launch_util",
+        ":xla_tensor_info",
         "//tensorflow/compiler/jit/ops:xla_ops",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:dump_graph",
@@ -182,6 +198,7 @@ cc_library(
     deps = [
         ":common",
         ":xla_compilation_cache",
+        ":xla_tensor_info",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
index cd7f8dd..e24a9a0 100644 (file)
@@ -114,10 +114,16 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   // this is more obviously correct.)
   core::ScopedUnref cache_ref(cache);
 
+  const XlaDevice::Metadata* metadata;
+  Status s = XlaDevice::GetMetadata(ctx, &metadata);
+
+  XlaTensorInfoManager* tensor_info_manager = nullptr;
+  if (s.ok()) {
+    tensor_info_manager = &metadata->tensor_info_manager();
+  }
+
   // Get the platform_id_ for XLA_* devices.
   if (platform_id_ == nullptr) {
-    const XlaDevice::Metadata* metadata;
-    Status s = XlaDevice::GetMetadata(ctx, &metadata);
     if (s.ok()) {
       platform_id_ = metadata->platform()->id();
     }
@@ -148,8 +154,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
   VLOG(1) << "Executing XLA Computation...";
 
-  XlaComputationLaunchContext launch_context(num_resource_args_, client,
-                                             &xla_allocator);
+  XlaComputationLaunchContext launch_context(
+      num_resource_args_, client, &xla_allocator, tensor_info_manager);
   launch_context.PopulateInputs(ctx, kernel, variables);
 
   // Execute the computation.
@@ -166,8 +172,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   auto elapsed = env->NowMicros() - start_time;
   VLOG(2) << "Elapsed time: " << elapsed << "us";
 
-  launch_context.PopulateOutputs(ctx, kernel,
-                                 run_result.ConsumeValueOrDie()->release());
+  launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
   VLOG(1) << "Done";
 }
 
index e238252..db3bf3e 100644 (file)
@@ -39,9 +39,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
   (void)registrations;
 
   std::unique_ptr<XlaDevice> device;
-  TF_RETURN_IF_ERROR(XlaDevice::Create(
-      "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix,
-      /*register_device_for_compilation=*/true, &device));
+  TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+                                       DEVICE_CPU_XLA_JIT, options, name_prefix,
+                                       /*register_device_for_compilation=*/true,
+                                       /*transfer_as_literal=*/false, &device));
   devices->push_back(device.release());
   return Status::OK();
 }
index d4d8fe1..e4e11d4 100644 (file)
@@ -109,7 +109,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
     const string& platform_name, const string& device_name, int device_ordinal,
     const string& jit_device_name, const SessionOptions& options,
     const string& name_prefix, bool register_device_for_compilation,
-    std::unique_ptr<XlaDevice>* device) {
+    bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
           << device_ordinal;
 
@@ -137,15 +137,17 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
 
   device->reset(new XlaDevice(options, attrs, device_ordinal,
                               DeviceType(jit_device_name),
-                              platform.ValueOrDie()));
+                              platform.ValueOrDie(), transfer_as_literal));
   return Status::OK();
 }
 
-XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
-                              const DeviceType& device_type)
+XlaDevice::Metadata::Metadata(
+    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
+    std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
-      platform_(platform) {}
+      platform_(platform),
+      tensor_info_manager_(*tensor_info_manager) {}
 
 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
 
@@ -160,6 +162,10 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
   return device_type_;
 }
 
+XlaTensorInfoManager& XlaDevice::Metadata::tensor_info_manager() const {
+  return *tensor_info_manager_;
+}
+
 /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                            const Metadata** metadata) {
   XlaDevice* xla_device =
@@ -177,13 +183,19 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
 
 XlaDevice::XlaDevice(const SessionOptions& options,
                      const DeviceAttributes& attrs, int device_ordinal,
-                     const DeviceType& jit_device_name, se::Platform* platform)
+                     const DeviceType& jit_device_name, se::Platform* platform,
+                     bool transfer_as_literal)
     : LocalDevice(options, attrs),
-      xla_metadata_(device_ordinal, platform, jit_device_name),
+      xla_metadata_(
+          device_ordinal, platform, jit_device_name,
+          // Pass tensor_info_manager_ by reference as it is initialized lazily.
+          &tensor_info_manager_),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
-      platform_(platform) {}
+      platform_(platform),
+      tensor_info_manager_(nullptr),
+      transfer_as_literal_(transfer_as_literal) {}
 
 XlaDevice::~XlaDevice() {}
 
@@ -208,6 +220,7 @@ Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
     xla::Backend* backend = client()->mutable_backend();
     xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
         backend, device_ordinal_);
+    tensor_info_manager_.reset(new XlaTensorInfoManager(xla_allocator_));
   }
   return xla_allocator_;
 }
@@ -225,7 +238,11 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   VLOG(1) << "XlaDevice::FillContextMap";
   device_context_map->resize(graph->num_node_ids());
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-  auto ctx = new XlaDeviceContext(stream);
+  // Call GetAllocator for the side-effect of ensuring the allocator and
+  // XlaTensorInfoManager is created.
+  (void)GetAllocator({});
+  auto ctx = new XlaDeviceContext(stream, tensor_info_manager_.get(),
+                                  transfer_as_literal_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -273,7 +290,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
     Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
     Notification n;
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-    XlaTransferManager manager(stream);
+    XlaTransferManager manager(stream, tensor_info_manager_.get(),
+                               transfer_as_literal_);
     manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                   [&n, &status](const Status& s) {
                                     status = s;
index d2ec382..0f44762 100644 (file)
@@ -26,6 +26,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/local_device.h"
@@ -48,7 +49,8 @@ class XlaDevice : public LocalDevice {
   class Metadata {
    public:
     Metadata(int device_ordinal, perftools::gputools::Platform* platform,
-             const DeviceType& device_type);
+             const DeviceType& device_type,
+             std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager);
 
     // The index of the device on this host.
     int device_ordinal() const;
@@ -56,11 +58,13 @@ class XlaDevice : public LocalDevice {
     perftools::gputools::Platform* platform() const;
     xla::LocalClient* client() const;
     const DeviceType& jit_device_type() const;
+    XlaTensorInfoManager& tensor_info_manager() const;
 
    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     perftools::gputools::Platform* platform_;  // Not owned.
+    std::unique_ptr<XlaTensorInfoManager>& tensor_info_manager_;
 
     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -71,15 +75,20 @@ class XlaDevice : public LocalDevice {
   // Factory function. 'platform_name' is the name of the XLA platform.
   // 'device_name' is the name of the Tensorflow device to create.
   // 'jit_device_name' is the name of the corresponding JIT device.
+  // 'transfer_as_literal' is true if device<->host transfers must be done using
+  // XLA's TransferLiteral{To,From}Device interface. If false, we can use
+  // ThenMemcpy instead.
   static Status Create(const string& platform_name, const string& device_name,
                        int device_ordinal, const string& jit_device_name,
                        const SessionOptions& options, const string& name_prefix,
                        bool register_device_for_compilation,
+                       bool transfer_as_literal,
                        std::unique_ptr<XlaDevice>* device);
 
   XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
             int device_ordinal, const DeviceType& jit_device_name,
-            ::perftools::gputools::Platform* platform);
+            ::perftools::gputools::Platform* platform,
+            bool transfer_as_literal);
   ~XlaDevice() override;
 
   Allocator* GetAllocator(AllocatorAttributes attr) override;
@@ -113,6 +122,16 @@ class XlaDevice : public LocalDevice {
   // copying back and forth between CPU and the device, and
   // computations enqueued by XLA.
   xla::Backend::StreamPtr stream_;
+  // Manages sideband data about tensors, in particular the on-device shape tree
+  // if the tensor requires multiple device buffers to represent (for example,
+  // tuple shapes).
+  // This is a unique_ptr because XlaTensorInfoManager is non-copy-constructible
+  // and we need to initialize this lazily (as we also lazily initialize the
+  // underlying allocator).
+  std::unique_ptr<XlaTensorInfoManager> tensor_info_manager_;
+  // Must we use XLA's transfer manager for correct host<->device transfers? if
+  // false, we can use ThenMemcpy() instead.
+  bool transfer_as_literal_;
 };
 
 // Builds dummy OpKernel registrations on 'device' for the JIT operators
index c936222..b57f82f 100644 (file)
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/xla_device_context.h"
 
+#include "tensorflow/compiler/jit/xla_launch_util.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -52,7 +53,12 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
 
 void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
 
-XlaTransferManager::XlaTransferManager(se::Stream* stream) : stream_(stream) {}
+XlaTransferManager::XlaTransferManager(
+    se::Stream* stream, XlaTensorInfoManager* tensor_info_manager,
+    bool transfer_as_literal)
+    : stream_(stream),
+      tensor_info_manager_(tensor_info_manager),
+      transfer_as_literal_(transfer_as_literal) {}
 
 void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                                Device* device,
@@ -72,13 +78,19 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
     se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes);
 
     Status status;
-    stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
-    // TODO(hpucha): Make this asynchronous.
-    Status block_status = stream_->BlockHostUntilDone();
-    if (!block_status.ok()) {
-      status = xla::InternalError(
-          "Failed to complete data transfer on stream %p: %s", stream_,
-          block_status.error_message().c_str());
+    if (transfer_as_literal_) {
+      status = xla::Unimplemented(
+          "XlaTransferManager::CopyCPUTensorToDevice not implemented for "
+          "literals");
+    } else {
+      stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
+      // TODO(hpucha): Make this asynchronous.
+      Status block_status = stream_->BlockHostUntilDone();
+      if (!block_status.ok()) {
+        status = xla::InternalError(
+            "Failed to complete data transfer on stream %p: %s", stream_,
+            block_status.error_message().c_str());
+      }
     }
 
     done(status);
@@ -108,13 +120,19 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
     void* dst_ptr = DMAHelper::base(cpu_tensor);
 
     Status status;
-    stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
-    // TODO(hpucha): Make this asynchronous.
-    Status block_status = stream_->BlockHostUntilDone();
-    if (!block_status.ok()) {
-      status = xla::InternalError(
-          "Failed to complete data transfer on stream %p: %s", stream_,
-          block_status.error_message().c_str());
+    if (transfer_as_literal_) {
+      status = xla::Unimplemented(
+          "XlaTransferManager::CopyDeviceTensorToCPU not implemented for "
+          "literals");
+    } else {
+      stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
+      // TODO(hpucha): Make this asynchronous.
+      Status block_status = stream_->BlockHostUntilDone();
+      if (!block_status.ok()) {
+        status = xla::InternalError(
+            "Failed to complete data transfer on stream %p: %s", stream_,
+            block_status.error_message().c_str());
+      }
     }
 
     done(status);
@@ -125,7 +143,10 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   done(Status::OK());
 }
 
-XlaDeviceContext::XlaDeviceContext(se::Stream* stream) : manager_(stream) {}
+XlaDeviceContext::XlaDeviceContext(se::Stream* stream,
+                                   XlaTensorInfoManager* tensor_info_manager,
+                                   bool transfer_as_literal)
+    : manager_(stream, tensor_info_manager, transfer_as_literal) {}
 
 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                              Device* device,
index c4edcd4..df02f4e 100644 (file)
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <memory>
 
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/framework/allocator.h"
@@ -49,7 +50,9 @@ class XlaDeviceAllocator : public Allocator {
 // Helper class for managing data transfers between host and XLA devices.
 class XlaTransferManager {
  public:
-  explicit XlaTransferManager(perftools::gputools::Stream* stream);
+  explicit XlaTransferManager(perftools::gputools::Stream* stream,
+                              XlaTensorInfoManager* tensor_info_manager,
+                              bool transfer_as_literal);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor, StatusCallback done) const;
@@ -62,6 +65,10 @@ class XlaTransferManager {
   // Stream obtained from a Device, used to transfer tensors between
   // CPU and device.
   perftools::gputools::Stream* stream_;
+  // The tensor info manager, for access to sideband information about tensors.
+  XlaTensorInfoManager* tensor_info_manager_;
+  // True if we must use XLA's TransferManager for correct device transfers.
+  bool transfer_as_literal_;
 };
 
 // DeviceContext for operators assigned to XlaDevice devices. The
@@ -69,7 +76,9 @@ class XlaTransferManager {
 // wraps the methods in XlaTransferManager.
 class XlaDeviceContext : public DeviceContext {
  public:
-  explicit XlaDeviceContext(perftools::gputools::Stream* stream);
+  explicit XlaDeviceContext(perftools::gputools::Stream* stream,
+                            XlaTensorInfoManager* tensor_info_manager,
+                            bool transfer_as_literal);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor,
index 2326070..383ed87 100644 (file)
@@ -39,9 +39,10 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
   (void)registrations;
 
   std::unique_ptr<XlaDevice> device;
-  Status status = XlaDevice::Create(
-      "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix,
-      /*register_device_for_compilation=*/true, &device);
+  Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0,
+                                    DEVICE_GPU_XLA_JIT, options, name_prefix,
+                                    /*register_device_for_compilation=*/true,
+                                    /*transfer_as_literal=*/false, &device);
   if (!status.ok()) {
     // Treat failures as non-fatal; there might not be a GPU in the machine.
     VLOG(1) << "Failed to create XLA_GPU device: " << status;
index 8322dd2..689fa32 100644 (file)
@@ -56,74 +56,56 @@ XlaAllocator::XlaAllocator(const gpu::Platform* platform,
                            OpKernelContext* op_context)
     : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
 
-XlaAllocator::~XlaAllocator() = default;
+XlaAllocator::~XlaAllocator() { CHECK(allocated_.empty()); }
 
 xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
     int device_ordinal, uint64 size, bool retry_on_failure) {
-  AllocatorAttributes allocator_attrs;
-  allocator_attrs.set_on_host(false);
-
-  AllocationAttributes allocation_attrs;
-  allocation_attrs.no_retry_on_failure = !retry_on_failure;
-
-  Tensor t;
-  Status status = op_context_->allocate_temp(
-      DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
-      allocation_attrs);
-  if (!status.ok()) {
-    VLOG(2) << "Allocation failed " << size;
-    return status;
-  }
-  void* data =
-      reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
-  tensors_[data] = t;
+  void* data = op_context_->device()->GetAllocator({})->AllocateRaw(
+      Allocator::kAllocatorAlignment, size);
+  allocated_.insert(data);
   return gpu::DeviceMemoryBase(data, size);
 }
 
-Status XlaAllocator::RegisterArgument(const Tensor* t) {
-  void* data =
-      reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
-  tensors_[data] = *t;
-  return Status::OK();
-}
+void XlaAllocator::Release(void* ptr) { allocated_.erase(ptr); }
 
 Status XlaAllocator::Deallocate(int device_ordinal,
                                 gpu::DeviceMemoryBase* mem) {
-  if (mem->opaque() != nullptr) {
-    if (tensors_.erase(mem->opaque()) == 0) {
-      return tensorflow::errors::InvalidArgument("Unknown tensor address");
-    }
+  if (allocated_.count(mem->opaque())) {
+    op_context_->device()->GetAllocator({})->DeallocateRaw(mem->opaque());
+    allocated_.erase(mem->opaque());
   }
   return Status::OK();
 }
 
-Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
-                                          DataType dtype,
-                                          const TensorShape& shape,
-                                          Tensor* out_tensor) const {
-  void* ptr = const_cast<void*>(buffer.opaque());
-  auto it = tensors_.find(ptr);
-  if (it == tensors_.end()) {
-    return errors::InvalidArgument("Unknown tensor address");
-  }
-  const Tensor& tensor = it->second;
-
-  int64 output_size = DataTypeSize(dtype) * shape.num_elements();
-  if (tensor.TotalBytes() == output_size) {
-    out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
-  } else {
-    Tensor slice = tensor.Slice(0, output_size);
-    out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
-  }
-  return Status::OK();
+namespace {
+// Return the 'index''th subtree of the given ShapedBuffer as a ShapedBuffer.
+xla::ShapedBuffer ExtractSubShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
+                                         int index) {
+  xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape(
+      shaped_buffer.on_host_shape(), index);
+  xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape(
+      shaped_buffer.on_device_shape(), index);
+
+  xla::ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
+                                      shaped_buffer.platform(),
+                                      shaped_buffer.device_ordinal());
+
+  auto& shape_tree = shaped_buffer.buffers();
+  auto& sub_shape_tree = sub_shaped_buffer.buffers();
+  sub_shape_tree.CopySubtreeFrom(shape_tree,
+                                 /*source_base_index=*/{index},
+                                 /*target_base_index=*/{});
+  return sub_shaped_buffer;
 }
+}  // namespace
 
 XlaComputationLaunchContext::XlaComputationLaunchContext(
     int64 num_resource_args, xla::LocalClient* client,
-    XlaAllocator* xla_allocator)
+    XlaAllocator* xla_allocator, XlaTensorInfoManager* tensor_info_manager)
     : num_resource_args_(num_resource_args),
       client_(client),
-      xla_allocator_(xla_allocator) {}
+      xla_allocator_(xla_allocator),
+      tensor_info_manager_(tensor_info_manager) {}
 
 void XlaComputationLaunchContext::PopulateInputs(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
@@ -145,29 +127,35 @@ void XlaComputationLaunchContext::PopulateInputs(
       t = &(ctx->input(arg_num));
     }
 
-    gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
-        const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
-
     const xla::Shape on_device_shape =
         client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
-    CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
-        << "On-device shape "
-        << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
-        << " not the same as on-host shape "
-        << xla::ShapeUtil::HumanStringWithLayout(shape);
-    arg_buffers_[i] = xla::MakeUnique<xla::ShapedBuffer>(
-        /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(),
-        client_->default_device_ordinal());
-    arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
-    arg_ptrs_[i] = arg_buffers_[i].get();
-
-    OP_REQUIRES_OK(ctx, xla_allocator_->RegisterArgument(t));
+    if (xla::ShapeUtil::IsTuple(on_device_shape)) {
+      CHECK(tensor_info_manager_);
+      const XlaTensorInfo* tensor_info =
+          tensor_info_manager_->GetTensorInfo(*t);
+      CHECK(tensor_info && tensor_info->has_shaped_buffer());
+      arg_ptrs_[i] =
+          const_cast<xla::ShapedBuffer*>(&tensor_info->shaped_buffer());
+    } else {
+      CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
+          << "On-device shape "
+          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
+          << " not the same as on-host shape "
+          << xla::ShapeUtil::HumanStringWithLayout(shape);
+      gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
+          const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
+      arg_buffers_[i] = xla::MakeUnique<xla::ShapedBuffer>(
+          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
+          client_->platform(), client_->default_device_ordinal());
+      arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
+      arg_ptrs_[i] = arg_buffers_[i].get();
+    }
   }
 }
 
 void XlaComputationLaunchContext::PopulateOutputs(
     OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
-    std::unique_ptr<xla::ShapedBuffer> output) {
+    std::unique_ptr<xla::ScopedShapedBuffer> output) {
   gpu::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
 
@@ -180,6 +168,11 @@ void XlaComputationLaunchContext::PopulateOutputs(
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
+    AllocatorAttributes alloc_attrs = ctx->output_alloc_attr(i);
+    Allocator* allocator = ctx->device()->GetAllocator(alloc_attrs);
+    if (tensor_info_manager_ && !alloc_attrs.on_host()) {
+      allocator = tensor_info_manager_;
+    }
     if (kernel->outputs[i].is_constant) {
       // Output is a constant.
       const Tensor& const_tensor = kernel->outputs[i].constant_value;
@@ -204,11 +197,19 @@ void XlaComputationLaunchContext::PopulateOutputs(
       VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
 
       gpu::DeviceMemoryBase buffer = output->buffer({output_num});
-      Tensor output_tensor;
-      // Looks up the owning Tensor by buffer address.
-      OP_REQUIRES_OK(ctx, xla_allocator_->MakeTensorFromBuffer(
-                              buffer, ctx->expected_output_dtype(i), shape,
-                              &output_tensor));
+      Tensor output_tensor = XlaTensorBuffer::MakeTensor(
+          ctx->expected_output_dtype(i), shape, buffer, allocator);
+      xla_allocator_->Release(buffer.opaque());
+
+      xla::Shape output_shape = xla::ShapeUtil::GetTupleElementShape(
+          output->on_device_shape(), output_num);
+      if (xla::ShapeUtil::IsTuple(output_shape)) {
+        CHECK(tensor_info_manager_);
+        XlaTensorInfo* tensor_info =
+            tensor_info_manager_->GetOrCreateTensorInfo(output_tensor);
+        tensor_info->set_shaped_buffer(
+            ExtractSubShapedBuffer(*output, output_num));
+      }
       ctx->set_output(i, output_tensor);
       ++output_num;
     }
@@ -221,6 +222,10 @@ void XlaComputationLaunchContext::PopulateOutputs(
   // Apply variable updates, if any.
   VLOG(2) << "Applying variable updates";
   for (int i = 0; i < kernel->resource_updates.size(); ++i) {
+    Allocator* allocator = ctx->device()->GetAllocator({});
+    if (tensor_info_manager_) {
+      allocator = tensor_info_manager_;
+    }
     const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
     OP_REQUIRES(ctx,
                 write.input_index >= 0 && write.input_index < ctx->num_inputs(),
@@ -243,11 +248,19 @@ void XlaComputationLaunchContext::PopulateOutputs(
     mutex_lock ml(*variable->mu());
     OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
                 errors::Internal("Mismatched type in variable write"));
-
-    // Looks up the owning Tensor by buffer address.
-    OP_REQUIRES_OK(ctx,
-                   xla_allocator_->MakeTensorFromBuffer(
-                       buffer, write.type, write.shape, variable->tensor()));
+    *variable->tensor() =
+        XlaTensorBuffer::MakeTensor(write.type, write.shape, buffer, allocator);
+    xla_allocator_->Release(buffer.opaque());
+
+    xla::Shape output_shape = xla::ShapeUtil::GetTupleElementShape(
+        output->on_device_shape(), output_num);
+    if (xla::ShapeUtil::IsTuple(output_shape)) {
+      CHECK(tensor_info_manager_);
+      XlaTensorInfo* tensor_info =
+          tensor_info_manager_->GetOrCreateTensorInfo(*variable->tensor());
+      tensor_info->set_shaped_buffer(
+          ExtractSubShapedBuffer(*output, output_num));
+    }
     ++output_num;
   }
 }
index 9fd356f..8694f6c 100644 (file)
@@ -19,8 +19,10 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
 
 #include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/allocation_description.pb.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/variable_ops.h"
@@ -52,16 +54,8 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
   Status Deallocate(int device_ordinal,
                     perftools::gputools::DeviceMemoryBase* mem) override;
 
-  // Register an Tensor (input or resource variable) with the allocator. If
-  // the operation returns an alias to one of its inputs, then the allocator
-  // needs to be able to handle it.
-  Status RegisterArgument(const Tensor* t);
-
-  // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
-  // interpreted as having data type 'dtype' and shape 'shape'.
-  Status MakeTensorFromBuffer(perftools::gputools::DeviceMemoryBase buffer,
-                              DataType dtype, const TensorShape& shape,
-                              Tensor* out_tensor) const;
+  // Un-track 'ptr' - do not delete it on destruction.
+  void Release(void* ptr);
 
   // The Tensorflow BFC allocator used on GPU allows host-side deallocation
   // before GPU execution takes place. Tensorflow uses the ordering of the main
@@ -74,11 +68,7 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
 
  private:
   OpKernelContext* const op_context_;
-
-  // Map from pointer address to the owning Tensor; used by
-  // MakeTensorFromBuffer. Also used to automatically release Tensors when the
-  // allocator is freed.
-  std::unordered_map<void*, Tensor> tensors_;
+  std::unordered_set<void*> allocated_;
 };
 
 // Helper class to perform the marshalling of TensorFlow inputs and outputs to
@@ -86,7 +76,8 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
 class XlaComputationLaunchContext {
  public:
   XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
-                              XlaAllocator* xla_allocator);
+                              XlaAllocator* xla_allocator,
+                              XlaTensorInfoManager* tensor_info_manager);
 
   // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
   // `variables` is a map from TensorFlow argument number to resource variable.
@@ -97,7 +88,7 @@ class XlaComputationLaunchContext {
   // Given the XLA output in `output`, populate all outputs of `ctx`.
   void PopulateOutputs(OpKernelContext* ctx,
                        const XlaCompiler::CompilationResult* kernel,
-                       std::unique_ptr<xla::ShapedBuffer> output);
+                       std::unique_ptr<xla::ScopedShapedBuffer> output);
 
   // Return the argument list. Only valid after PopulateInputs() has been
   // called.
@@ -107,10 +98,52 @@ class XlaComputationLaunchContext {
   int64 num_resource_args_;
   xla::LocalClient* client_;
   XlaAllocator* xla_allocator_;
+  XlaTensorInfoManager* tensor_info_manager_;
   std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
   std::vector<xla::ShapedBuffer*> arg_ptrs_;
 };
 
+// A simple TensorBuffer implementation that allows us to create Tensors that
+// take ownership of pre-allocated memory.
+class XlaTensorBuffer : public TensorBuffer {
+ public:
+  XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size,
+                  Allocator* allocator)
+      : expected_size_(expected_size),
+        actual_size_(actual_size),
+        allocator_(allocator) {
+    data_ = const_cast<void*>(ptr);
+  }
+
+  ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); }
+
+  void* data() const override { return data_; }
+  size_t size() const override { return expected_size_; }
+
+  TensorBuffer* root_buffer() override { return this; }
+
+  void FillAllocationDescription(AllocationDescription* proto) const override {
+    proto->set_allocated_bytes(actual_size_);
+  }
+
+  static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
+                           perftools::gputools::DeviceMemoryBase buffer,
+                           Allocator* allocator) {
+    size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
+    auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
+                                              buffer.size(), allocator);
+    Tensor t(dtype, shape, tensor_buffer);
+    tensor_buffer->Unref();
+    return t;
+  }
+
+ private:
+  void* data_;
+  size_t expected_size_;
+  size_t actual_size_;
+  Allocator* allocator_;
+};
+
 }  // namespace tensorflow
 
 #endif
diff --git a/tensorflow/compiler/jit/xla_tensor_info.cc b/tensorflow/compiler/jit/xla_tensor_info.cc
new file mode 100644 (file)
index 0000000..0ce18c2
--- /dev/null
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
+
+namespace tensorflow {
+
+const XlaTensorInfo* XlaTensorInfoManager::GetTensorInfo(
+    const void* device_ptr) const {
+  mutex_lock lock(lock_);
+  auto iterator = tensor_infos_.find(device_ptr);
+  return (iterator == tensor_infos_.end()) ? nullptr
+                                           : tensor_infos_.at(device_ptr).get();
+}
+
+XlaTensorInfo* XlaTensorInfoManager::GetOrCreateTensorInfo(
+    const void* device_ptr) {
+  mutex_lock lock(lock_);
+  auto iterator = tensor_infos_.find(device_ptr);
+  if (iterator != tensor_infos_.end()) {
+    return iterator->second.get();
+  }
+  auto iterator_and_inserted =
+      tensor_infos_.emplace(device_ptr, MakeUnique<XlaTensorInfo>());
+  CHECK(iterator_and_inserted.second);
+  return iterator_and_inserted.first->second.get();
+}
+
+const XlaTensorInfo* XlaTensorInfoManager::GetTensorInfo(const Tensor& tensor) {
+  return GetTensorInfo(tensor.tensor_data().data());
+}
+
+XlaTensorInfo* XlaTensorInfoManager::GetOrCreateTensorInfo(
+    const Tensor& tensor) {
+  return GetOrCreateTensorInfo(tensor.tensor_data().data());
+}
+
+void XlaTensorInfoManager::DeallocateRaw(void* ptr) {
+  wrapped()->DeallocateRaw(ptr);
+  mutex_lock lock(lock_);
+  tensor_infos_.erase(ptr);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_tensor_info.h b/tensorflow/compiler/jit/xla_tensor_info.h
new file mode 100644 (file)
index 0000000..0b0736b
--- /dev/null
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_INFO_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_INFO_H_
+
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Information about a tensor. The XlaTensorInfoManager can maintain one of
+// these per device Tensor.
+class XlaTensorInfo {
+ public:
+  XlaTensorInfo() {}
+
+  // Some Tensors can have complex on-device shapes, including tuple shapes. To
+  // manage the memory for these tensors a ShapedBuffer may be required.
+
+  // Return true if this TensorInfo contains a ShapedBuffer.
+  bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
+  // Return the contained ShapedBuffer.
+  // REQUIRES: has_shaped_buffer()
+  const xla::ShapedBuffer& shaped_buffer() const { return *shaped_buffer_; }
+  // Mutates the TensorInfo to set the ShapedBuffer.
+  void set_shaped_buffer(xla::ShapedBuffer shaped_buffer) {
+    shaped_buffer_.reset(new xla::ShapedBuffer(std::move(shaped_buffer)));
+  }
+
+ private:
+  // The optional contained ShapedBuffer.
+  std::unique_ptr<xla::ShapedBuffer> shaped_buffer_;
+};
+
+// Manages XlaTensorInfo objects. This class is also an Allocator, so that
+// XlaTensorInfo objects can be deleted when their Tensor is deallocated.
+class XlaTensorInfoManager : public AllocatorWrapper {
+ public:
+  // Creates a new XlaTensorInfoManager, delegating all DeallocateRaw calls to
+  // allocator.
+  XlaTensorInfoManager(Allocator* allocator) : AllocatorWrapper(allocator) {}
+
+  // Returns the XlaTensorInfo for the given device memory pointer or nullptr if
+  // none exists.
+  const XlaTensorInfo* GetTensorInfo(const void* device_ptr) const;
+  // Returns the XlaTensorInfo for the device memory pointer extracted from
+  // tensor or nullptr if none exists.
+  const XlaTensorInfo* GetTensorInfo(const Tensor& tensor);
+
+  // Returns the XlaTensorInfo for the given device memory pointer, creating one
+  // if necessary.
+  XlaTensorInfo* GetOrCreateTensorInfo(const Tensor& tensor);
+  // Returns the XlaTensorInfo for the device memory pointer extracted from
+  // tensor, creating one if necessary.
+  XlaTensorInfo* GetOrCreateTensorInfo(const void* device_ptr);
+
+  // Allocator interface
+  void DeallocateRaw(void* ptr) override;
+
+ private:
+  mutable mutex lock_;
+  // The managed tensor infos. The mapped value is a unique_ptr so that returned
+  // references are stable over rehashes.
+  std::unordered_map<const void*, std::unique_ptr<XlaTensorInfo>> tensor_infos_
+      GUARDED_BY(lock_);
+};
+}  // namespace tensorflow
+
+#endif
index 9ae4bb5..4d10f7e 100644 (file)
@@ -483,6 +483,8 @@ class Tensor {
   friend class TensorTestHelper;      // For access to set_shape
   friend class OpKernelContext;       // For access to RefCountIsOne().
   friend class ScopedAllocator;       // For access to buf_.
+  friend class XlaTensorBuffer;  // For access to the private constructor taking
+                                 // the buffer
   template <typename Device, typename T>
   friend class AssignVariableOp;  // For access to RefCountIsOne().
   template <typename Device, typename T>