)
cc_library(
+ name = "xla_tensor_info",
+ srcs = ["xla_tensor_info.cc"],
+ hdrs = ["xla_tensor_info.h"],
+ deps = [
+ ":common",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
name = "xla_device",
srcs = [
"xla_device.cc",
":common",
":jit_compilation_passes",
":xla_launch_util",
+ ":xla_tensor_info",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
deps = [
":common",
":xla_compilation_cache",
+ ":xla_tensor_info",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
+ const XlaDevice::Metadata* metadata;
+ Status s = XlaDevice::GetMetadata(ctx, &metadata);
+
+ XlaTensorInfoManager* tensor_info_manager = nullptr;
+ if (s.ok()) {
+ tensor_info_manager = &metadata->tensor_info_manager();
+ }
+
// Get the platform_id_ for XLA_* devices.
if (platform_id_ == nullptr) {
- const XlaDevice::Metadata* metadata;
- Status s = XlaDevice::GetMetadata(ctx, &metadata);
if (s.ok()) {
platform_id_ = metadata->platform()->id();
}
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(num_resource_args_, client,
- &xla_allocator);
+ XlaComputationLaunchContext launch_context(
+ num_resource_args_, client, &xla_allocator, tensor_info_manager);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
- launch_context.PopulateOutputs(ctx, kernel,
- run_result.ConsumeValueOrDie()->release());
+ launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
VLOG(1) << "Done";
}
(void)registrations;
std::unique_ptr<XlaDevice> device;
- TF_RETURN_IF_ERROR(XlaDevice::Create(
- "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix,
- /*register_device_for_compilation=*/true, &device));
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+ DEVICE_CPU_XLA_JIT, options, name_prefix,
+ /*register_device_for_compilation=*/true,
+ /*transfer_as_literal=*/false, &device));
devices->push_back(device.release());
return Status::OK();
}
const string& platform_name, const string& device_name, int device_ordinal,
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix, bool register_device_for_compilation,
- std::unique_ptr<XlaDevice>* device) {
+ bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
device->reset(new XlaDevice(options, attrs, device_ordinal,
DeviceType(jit_device_name),
- platform.ValueOrDie()));
+ platform.ValueOrDie(), transfer_as_literal));
return Status::OK();
}
-XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
- const DeviceType& device_type)
+XlaDevice::Metadata::Metadata(
+ int device_ordinal, se::Platform* platform, const DeviceType& device_type,
+ std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager)
: device_ordinal_(device_ordinal),
device_type_(device_type),
- platform_(platform) {}
+ platform_(platform),
+ tensor_info_manager_(*tensor_info_manager) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
return device_type_;
}
+XlaTensorInfoManager& XlaDevice::Metadata::tensor_info_manager() const {
+ return *tensor_info_manager_;
+}
+
/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
const Metadata** metadata) {
XlaDevice* xla_device =
XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceAttributes& attrs, int device_ordinal,
- const DeviceType& jit_device_name, se::Platform* platform)
+ const DeviceType& jit_device_name, se::Platform* platform,
+ bool transfer_as_literal)
: LocalDevice(options, attrs),
- xla_metadata_(device_ordinal, platform, jit_device_name),
+ xla_metadata_(
+ device_ordinal, platform, jit_device_name,
+ // Pass tensor_info_manager_ by reference as it is initialized lazily.
+ &tensor_info_manager_),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
- platform_(platform) {}
+ platform_(platform),
+ tensor_info_manager_(nullptr),
+ transfer_as_literal_(transfer_as_literal) {}
XlaDevice::~XlaDevice() {}
xla::Backend* backend = client()->mutable_backend();
xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
backend, device_ordinal_);
+ tensor_info_manager_.reset(new XlaTensorInfoManager(xla_allocator_));
}
return xla_allocator_;
}
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
- auto ctx = new XlaDeviceContext(stream);
+ // Call GetAllocator for the side effect of ensuring that the allocator and
+ // the XlaTensorInfoManager are created.
+ (void)GetAllocator({});
+ auto ctx = new XlaDeviceContext(stream, tensor_info_manager_.get(),
+ transfer_as_literal_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
- XlaTransferManager manager(stream);
+ XlaTransferManager manager(stream, tensor_info_manager_.get(),
+ transfer_as_literal_);
manager.CopyCPUTensorToDevice(&parsed, this, ©,
[&n, &status](const Status& s) {
status = s;
#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
class Metadata {
public:
Metadata(int device_ordinal, perftools::gputools::Platform* platform,
- const DeviceType& device_type);
+ const DeviceType& device_type,
+ std::unique_ptr<XlaTensorInfoManager>* tensor_info_manager);
// The index of the device on this host.
int device_ordinal() const;
perftools::gputools::Platform* platform() const;
xla::LocalClient* client() const;
const DeviceType& jit_device_type() const;
+ XlaTensorInfoManager& tensor_info_manager() const;
private:
const int device_ordinal_;
const DeviceType device_type_;
perftools::gputools::Platform* platform_; // Not owned.
+ std::unique_ptr<XlaTensorInfoManager>& tensor_info_manager_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
// Factory function. 'platform_name' is the name of the XLA platform.
// 'device_name' is the name of the Tensorflow device to create.
// 'jit_device_name' is the name of the corresponding JIT device.
+ // 'transfer_as_literal' is true if device<->host transfers must be done using
+ // XLA's TransferLiteral{To,From}Device interface. If false, we can use
+ // ThenMemcpy instead.
static Status Create(const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
bool register_device_for_compilation,
+ bool transfer_as_literal,
std::unique_ptr<XlaDevice>* device);
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
- ::perftools::gputools::Platform* platform);
+ ::perftools::gputools::Platform* platform,
+ bool transfer_as_literal);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
xla::Backend::StreamPtr stream_;
+ // Manages sideband data about tensors, in particular the on-device shape
+ // tree when a tensor needs multiple device buffers to represent it (for
+ // example, tuple shapes).
+ // This is a unique_ptr because XlaTensorInfoManager is not copy-constructible
+ // and we need to initialize it lazily (as we also lazily initialize the
+ // underlying allocator).
+ std::unique_ptr<XlaTensorInfoManager> tensor_info_manager_;
+ // True if host<->device transfers must go through XLA's transfer manager for
+ // correctness; if false, we can use ThenMemcpy() instead.
+ bool transfer_as_literal_;
};
// Builds dummy OpKernel registrations on 'device' for the JIT operators
#include "tensorflow/compiler/jit/xla_device_context.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
-XlaTransferManager::XlaTransferManager(se::Stream* stream) : stream_(stream) {}
+XlaTransferManager::XlaTransferManager(
+ se::Stream* stream, XlaTensorInfoManager* tensor_info_manager,
+ bool transfer_as_literal)
+ : stream_(stream),
+ tensor_info_manager_(tensor_info_manager),
+ transfer_as_literal_(transfer_as_literal) {}
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes);
Status status;
- stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
- // TODO(hpucha): Make this asynchronous.
- Status block_status = stream_->BlockHostUntilDone();
- if (!block_status.ok()) {
- status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
- block_status.error_message().c_str());
+ if (transfer_as_literal_) {
+ status = xla::Unimplemented(
+ "XlaTransferManager::CopyCPUTensorToDevice not implemented for "
+ "literals");
+ } else {
+ stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
+ // TODO(hpucha): Make this asynchronous.
+ Status block_status = stream_->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ status = xla::InternalError(
+ "Failed to complete data transfer on stream %p: %s", stream_,
+ block_status.error_message().c_str());
+ }
}
done(status);
void* dst_ptr = DMAHelper::base(cpu_tensor);
Status status;
- stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
- // TODO(hpucha): Make this asynchronous.
- Status block_status = stream_->BlockHostUntilDone();
- if (!block_status.ok()) {
- status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
- block_status.error_message().c_str());
+ if (transfer_as_literal_) {
+ status = xla::Unimplemented(
+ "XlaTransferManager::CopyDeviceTensorToCPU not implemented for "
+ "literals");
+ } else {
+ stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
+ // TODO(hpucha): Make this asynchronous.
+ Status block_status = stream_->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ status = xla::InternalError(
+ "Failed to complete data transfer on stream %p: %s", stream_,
+ block_status.error_message().c_str());
+ }
}
done(status);
done(Status::OK());
}
-XlaDeviceContext::XlaDeviceContext(se::Stream* stream) : manager_(stream) {}
+XlaDeviceContext::XlaDeviceContext(se::Stream* stream,
+ XlaTensorInfoManager* tensor_info_manager,
+ bool transfer_as_literal)
+ : manager_(stream, tensor_info_manager, transfer_as_literal) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
#include <memory>
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/allocator.h"
// Helper class for managing data transfers between host and XLA devices.
class XlaTransferManager {
public:
- explicit XlaTransferManager(perftools::gputools::Stream* stream);
+ explicit XlaTransferManager(perftools::gputools::Stream* stream,
+ XlaTensorInfoManager* tensor_info_manager,
+ bool transfer_as_literal);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
// Stream obtained from a Device, used to transfer tensors between
// CPU and device.
perftools::gputools::Stream* stream_;
+ // The tensor info manager, for access to sideband information about tensors.
+ XlaTensorInfoManager* tensor_info_manager_;
+ // True if we must use XLA's TransferManager for correct device transfers.
+ bool transfer_as_literal_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
// wraps the methods in XlaTransferManager.
class XlaDeviceContext : public DeviceContext {
public:
- explicit XlaDeviceContext(perftools::gputools::Stream* stream);
+ explicit XlaDeviceContext(perftools::gputools::Stream* stream,
+ XlaTensorInfoManager* tensor_info_manager,
+ bool transfer_as_literal);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
(void)registrations;
std::unique_ptr<XlaDevice> device;
- Status status = XlaDevice::Create(
- "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix,
- /*register_device_for_compilation=*/true, &device);
+ Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0,
+ DEVICE_GPU_XLA_JIT, options, name_prefix,
+ /*register_device_for_compilation=*/true,
+ /*transfer_as_literal=*/false, &device);
if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << status;
OpKernelContext* op_context)
: xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
-XlaAllocator::~XlaAllocator() = default;
+XlaAllocator::~XlaAllocator() { CHECK(allocated_.empty()); }
xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
- AllocatorAttributes allocator_attrs;
- allocator_attrs.set_on_host(false);
-
- AllocationAttributes allocation_attrs;
- allocation_attrs.no_retry_on_failure = !retry_on_failure;
-
- Tensor t;
- Status status = op_context_->allocate_temp(
- DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
- allocation_attrs);
- if (!status.ok()) {
- VLOG(2) << "Allocation failed " << size;
- return status;
- }
- void* data =
- reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
- tensors_[data] = t;
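+ // Allocate through the device's own allocator and remember the pointer so
+ // that Deallocate() only frees buffers this XlaAllocator still owns.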
+ void* data = op_context_->device()->GetAllocator({})->AllocateRaw(
+ Allocator::kAllocatorAlignment, size);
+ allocated_.insert(data);
return gpu::DeviceMemoryBase(data, size);
}
-Status XlaAllocator::RegisterArgument(const Tensor* t) {
- void* data =
- reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
- tensors_[data] = *t;
- return Status::OK();
-}
+void XlaAllocator::Release(void* ptr) { allocated_.erase(ptr); }
Status XlaAllocator::Deallocate(int device_ordinal,
gpu::DeviceMemoryBase* mem) {
- if (mem->opaque() != nullptr) {
- if (tensors_.erase(mem->opaque()) == 0) {
- return tensorflow::errors::InvalidArgument("Unknown tensor address");
- }
+ if (allocated_.count(mem->opaque())) {
+ op_context_->device()->GetAllocator({})->DeallocateRaw(mem->opaque());
+ allocated_.erase(mem->opaque());
}
return Status::OK();
}
-Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
- DataType dtype,
- const TensorShape& shape,
- Tensor* out_tensor) const {
- void* ptr = const_cast<void*>(buffer.opaque());
- auto it = tensors_.find(ptr);
- if (it == tensors_.end()) {
- return errors::InvalidArgument("Unknown tensor address");
- }
- const Tensor& tensor = it->second;
-
- int64 output_size = DataTypeSize(dtype) * shape.num_elements();
- if (tensor.TotalBytes() == output_size) {
- out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
- } else {
- Tensor slice = tensor.Slice(0, output_size);
- out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
- }
- return Status::OK();
+namespace {
+// Returns the 'index'-th subtree of 'shaped_buffer' as a ShapedBuffer.
+xla::ShapedBuffer ExtractSubShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
+ int index) {
+ xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape(
+ shaped_buffer.on_host_shape(), index);
+ xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape(
+ shaped_buffer.on_device_shape(), index);
+
+ xla::ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape,
+ shaped_buffer.platform(),
+ shaped_buffer.device_ordinal());
+
+ auto& shape_tree = shaped_buffer.buffers();
+ auto& sub_shape_tree = sub_shaped_buffer.buffers();
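+ // Copy the device buffer handles (not the device memory itself) for the
+ // subtree rooted at 'index' into the root of the new ShapedBuffer.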
+ sub_shape_tree.CopySubtreeFrom(shape_tree,
+ /*source_base_index=*/{index},
+ /*target_base_index=*/{});
+ return sub_shaped_buffer;
}
+} // namespace
XlaComputationLaunchContext::XlaComputationLaunchContext(
int64 num_resource_args, xla::LocalClient* client,
- XlaAllocator* xla_allocator)
+ XlaAllocator* xla_allocator, XlaTensorInfoManager* tensor_info_manager)
: num_resource_args_(num_resource_args),
client_(client),
- xla_allocator_(xla_allocator) {}
+ xla_allocator_(xla_allocator),
+ tensor_info_manager_(tensor_info_manager) {}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
t = &(ctx->input(arg_num));
}
- gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
- const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
-
const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
- CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
- << "On-device shape "
- << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
- << " not the same as on-host shape "
- << xla::ShapeUtil::HumanStringWithLayout(shape);
- arg_buffers_[i] = xla::MakeUnique<xla::ShapedBuffer>(
- /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(),
- client_->default_device_ordinal());
- arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
- arg_ptrs_[i] = arg_buffers_[i].get();
-
- OP_REQUIRES_OK(ctx, xla_allocator_->RegisterArgument(t));
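+ // Tuple-shaped arguments span multiple device buffers, so recover the
+ // argument's ShapedBuffer from its sideband XlaTensorInfo instead of
+ // wrapping the flat tensor_data() pointer.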
+ if (xla::ShapeUtil::IsTuple(on_device_shape)) {
+ CHECK(tensor_info_manager_);
+ const XlaTensorInfo* tensor_info =
+ tensor_info_manager_->GetTensorInfo(*t);
+ CHECK(tensor_info && tensor_info->has_shaped_buffer());
+ arg_ptrs_[i] =
+ const_cast<xla::ShapedBuffer*>(&tensor_info->shaped_buffer());
+ } else {
+ CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
+ << "On-device shape "
+ << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
+ << " not the same as on-host shape "
+ << xla::ShapeUtil::HumanStringWithLayout(shape);
+ gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
+ const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
+ arg_buffers_[i] = xla::MakeUnique<xla::ShapedBuffer>(
+ /*on_host_shape=*/shape, /*on_device_shape=*/shape,
+ client_->platform(), client_->default_device_ordinal());
+ arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
+ arg_ptrs_[i] = arg_buffers_[i].get();
+ }
}
}
void XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- std::unique_ptr<xla::ShapedBuffer> output) {
+ std::unique_ptr<xla::ScopedShapedBuffer> output) {
gpu::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
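+ // Allocate device-memory outputs through the XlaTensorInfoManager (which
+ // wraps the device allocator) so any sideband info attached to an output
+ // is erased when its buffer is finally deallocated.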
+ AllocatorAttributes alloc_attrs = ctx->output_alloc_attr(i);
+ Allocator* allocator = ctx->device()->GetAllocator(alloc_attrs);
+ if (tensor_info_manager_ && !alloc_attrs.on_host()) {
+ allocator = tensor_info_manager_;
+ }
if (kernel->outputs[i].is_constant) {
// Output is a constant.
const Tensor& const_tensor = kernel->outputs[i].constant_value;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
gpu::DeviceMemoryBase buffer = output->buffer({output_num});
- Tensor output_tensor;
- // Looks up the owning Tensor by buffer address.
- OP_REQUIRES_OK(ctx, xla_allocator_->MakeTensorFromBuffer(
- buffer, ctx->expected_output_dtype(i), shape,
- &output_tensor));
+ Tensor output_tensor = XlaTensorBuffer::MakeTensor(
+ ctx->expected_output_dtype(i), shape, buffer, allocator);
+ xla_allocator_->Release(buffer.opaque());
+
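+ // If this output is itself tuple-shaped on the device, record its
+ // ShapedBuffer in the tensor's sideband info so it can later be passed
+ // back to XLA as an argument.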
+ xla::Shape output_shape = xla::ShapeUtil::GetTupleElementShape(
+ output->on_device_shape(), output_num);
+ if (xla::ShapeUtil::IsTuple(output_shape)) {
+ CHECK(tensor_info_manager_);
+ XlaTensorInfo* tensor_info =
+ tensor_info_manager_->GetOrCreateTensorInfo(output_tensor);
+ tensor_info->set_shaped_buffer(
+ ExtractSubShapedBuffer(*output, output_num));
+ }
ctx->set_output(i, output_tensor);
++output_num;
}
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
+ Allocator* allocator = ctx->device()->GetAllocator({});
+ if (tensor_info_manager_) {
+ allocator = tensor_info_manager_;
+ }
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
mutex_lock ml(*variable->mu());
OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
errors::Internal("Mismatched type in variable write"));
-
- // Looks up the owning Tensor by buffer address.
- OP_REQUIRES_OK(ctx,
- xla_allocator_->MakeTensorFromBuffer(
- buffer, write.type, write.shape, variable->tensor()));
+ *variable->tensor() =
+ XlaTensorBuffer::MakeTensor(write.type, write.shape, buffer, allocator);
+ xla_allocator_->Release(buffer.opaque());
+
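+ // As with regular outputs, record the ShapedBuffer for tuple-shaped
+ // variable updates.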
+ xla::Shape output_shape = xla::ShapeUtil::GetTupleElementShape(
+ output->on_device_shape(), output_num);
+ if (xla::ShapeUtil::IsTuple(output_shape)) {
+ CHECK(tensor_info_manager_);
+ XlaTensorInfo* tensor_info =
+ tensor_info_manager_->GetOrCreateTensorInfo(*variable->tensor());
+ tensor_info->set_shaped_buffer(
+ ExtractSubShapedBuffer(*output, output_num));
+ }
++output_num;
}
}
#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
Status Deallocate(int device_ordinal,
perftools::gputools::DeviceMemoryBase* mem) override;
- // Register an Tensor (input or resource variable) with the allocator. If
- // the operation returns an alias to one of its inputs, then the allocator
- // needs to be able to handle it.
- Status RegisterArgument(const Tensor* t);
-
- // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
- // interpreted as having data type 'dtype' and shape 'shape'.
- Status MakeTensorFromBuffer(perftools::gputools::DeviceMemoryBase buffer,
- DataType dtype, const TensorShape& shape,
- Tensor* out_tensor) const;
+ // Stops tracking 'ptr'; ownership has been transferred elsewhere (e.g. into
+ // a Tensor), so Deallocate() will no longer free it.
+ void Release(void* ptr);
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
private:
OpKernelContext* const op_context_;
-
- // Map from pointer address to the owning Tensor; used by
- // MakeTensorFromBuffer. Also used to automatically release Tensors when the
- // allocator is freed.
- std::unordered_map<void*, Tensor> tensors_;
+ std::unordered_set<void*> allocated_;
};
// Helper class to perform the marshalling of TensorFlow inputs and outputs to
class XlaComputationLaunchContext {
public:
XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
- XlaAllocator* xla_allocator);
+ XlaAllocator* xla_allocator,
+ XlaTensorInfoManager* tensor_info_manager);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
// Given the XLA output in `output`, populate all outputs of `ctx`.
void PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- std::unique_ptr<xla::ShapedBuffer> output);
+ std::unique_ptr<xla::ScopedShapedBuffer> output);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
int64 num_resource_args_;
xla::LocalClient* client_;
XlaAllocator* xla_allocator_;
+ XlaTensorInfoManager* tensor_info_manager_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};
+// A simple TensorBuffer implementation that allows us to create Tensors that
+// take ownership of pre-allocated memory.
+class XlaTensorBuffer : public TensorBuffer {
+ public:
+ XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size,
+ Allocator* allocator)
+ : expected_size_(expected_size),
+ actual_size_(actual_size),
+ allocator_(allocator) {
+ data_ = const_cast<void*>(ptr);
+ }
+
+ ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); }
+
+ void* data() const override { return data_; }
+ size_t size() const override { return expected_size_; }
+
+ TensorBuffer* root_buffer() override { return this; }
+
+ void FillAllocationDescription(AllocationDescription* proto) const override {
+ proto->set_allocated_bytes(actual_size_);
+ }
+
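+ // Wraps 'buffer' in a Tensor of the given type and shape. The returned
+ // Tensor takes ownership of the buffer and releases it through 'allocator'
+ // when its last reference goes away.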
+ static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
+ perftools::gputools::DeviceMemoryBase buffer,
+ Allocator* allocator) {
+ size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
+ auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
+ buffer.size(), allocator);
+ Tensor t(dtype, shape, tensor_buffer);
+ tensor_buffer->Unref();
+ return t;
+ }
+
+ private:
+ void* data_;
+ size_t expected_size_;
+ size_t actual_size_;
+ Allocator* allocator_;
+};
+
} // namespace tensorflow
#endif
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_tensor_info.h"
+
+namespace tensorflow {
+
+const XlaTensorInfo* XlaTensorInfoManager::GetTensorInfo(
+ const void* device_ptr) const {
+ mutex_lock lock(lock_);
+ auto iterator = tensor_infos_.find(device_ptr);
+ return (iterator == tensor_infos_.end()) ? nullptr : iterator->second.get();
+}
+
+XlaTensorInfo* XlaTensorInfoManager::GetOrCreateTensorInfo(
+ const void* device_ptr) {
+ mutex_lock lock(lock_);
+ auto iterator = tensor_infos_.find(device_ptr);
+ if (iterator != tensor_infos_.end()) {
+ return iterator->second.get();
+ }
+ auto iterator_and_inserted = tensor_infos_.emplace(
+ device_ptr, std::unique_ptr<XlaTensorInfo>(new XlaTensorInfo));
+ CHECK(iterator_and_inserted.second);
+ return iterator_and_inserted.first->second.get();
+}
+
+const XlaTensorInfo* XlaTensorInfoManager::GetTensorInfo(const Tensor& tensor) {
+ return GetTensorInfo(tensor.tensor_data().data());
+}
+
+XlaTensorInfo* XlaTensorInfoManager::GetOrCreateTensorInfo(
+ const Tensor& tensor) {
+ return GetOrCreateTensorInfo(tensor.tensor_data().data());
+}
+
+void XlaTensorInfoManager::DeallocateRaw(void* ptr) {
+ wrapped()->DeallocateRaw(ptr);
+ mutex_lock lock(lock_);
+ tensor_infos_.erase(ptr);
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_INFO_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_INFO_H_
+
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Information about a tensor. The XlaTensorInfoManager can maintain one of
+// these per device Tensor.
+class XlaTensorInfo {
+ public:
+ XlaTensorInfo() {}
+
+ // Some Tensors can have complex on-device shapes, including tuple shapes. To
+ // manage the memory for these tensors, a ShapedBuffer may be required.
+
+ // Return true if this TensorInfo contains a ShapedBuffer.
+ bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
+ // Return the contained ShapedBuffer.
+ // REQUIRES: has_shaped_buffer()
+ const xla::ShapedBuffer& shaped_buffer() const { return *shaped_buffer_; }
+ // Mutates the TensorInfo to set the ShapedBuffer.
+ void set_shaped_buffer(xla::ShapedBuffer shaped_buffer) {
+ shaped_buffer_.reset(new xla::ShapedBuffer(std::move(shaped_buffer)));
+ }
+
+ private:
+ // The optional contained ShapedBuffer.
+ std::unique_ptr<xla::ShapedBuffer> shaped_buffer_;
+};
+
+// Manages XlaTensorInfo objects. This class is also an Allocator, so that
+// XlaTensorInfo objects can be deleted when their Tensor is deallocated.
+class XlaTensorInfoManager : public AllocatorWrapper {
+ public:
+ // Creates a new XlaTensorInfoManager, delegating all DeallocateRaw calls to
+ // 'allocator'.
+ explicit XlaTensorInfoManager(Allocator* allocator)
+     : AllocatorWrapper(allocator) {}
+
+ // Returns the XlaTensorInfo for the given device memory pointer or nullptr if
+ // none exists.
+ const XlaTensorInfo* GetTensorInfo(const void* device_ptr) const;
+ // Returns the XlaTensorInfo for the device memory pointer extracted from
+ // tensor or nullptr if none exists.
+ const XlaTensorInfo* GetTensorInfo(const Tensor& tensor);
+
+ // Returns the XlaTensorInfo for the device memory pointer extracted from
+ // tensor, creating one if necessary.
+ XlaTensorInfo* GetOrCreateTensorInfo(const Tensor& tensor);
+ // Returns the XlaTensorInfo for the given device memory pointer, creating
+ // one if necessary.
+ XlaTensorInfo* GetOrCreateTensorInfo(const void* device_ptr);
+
+ // Allocator interface
+ void DeallocateRaw(void* ptr) override;
+
+ private:
+ mutable mutex lock_;
+ // The managed tensor infos. The mapped value is a unique_ptr so that returned
+ // references are stable over rehashes.
+ std::unordered_map<const void*, std::unique_ptr<XlaTensorInfo>> tensor_infos_
+ GUARDED_BY(lock_);
+};
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_XLA_TENSOR_INFO_H_
friend class TensorTestHelper; // For access to set_shape
friend class OpKernelContext; // For access to RefCountIsOne().
friend class ScopedAllocator; // For access to buf_.
+ friend class XlaTensorBuffer; // For access to the private constructor
+                               // taking a TensorBuffer.
template <typename Device, typename T>
friend class AssignVariableOp; // For access to RefCountIsOne().
template <typename Device, typename T>