author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 30 Mar 2018 09:41:35 +0000 (02:41 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 30 Mar 2018 09:44:15 +0000 (02:44 -0700)
commit 36b4744977d75b75bc969dfd25f8951f1141caa6
tree 1bcd374fe8d4da8f5d27ae9374f66029d8038c9c
parent 17d05bb44cecea741e8a4f7adb8c107af06893b4
Before this change, a Tensor contained a device pointer, and a TensorInfoManager data structure mapped device pointers to XlaTensorInfo objects. The TensorInfoManager also had to be an Allocator so that it would be notified when a Tensor was released.

After this change, a Tensor on an XlaDevice contains an XlaTensor object. The XlaTensor object is the equivalent of the old XlaTensorInfo object.
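The difference between the two layouts can be modeled in a few lines. This is a minimal sketch, not TensorFlow code: `TensorMetadata`, `SideTableManager`, `OldTensor`, `NewXlaTensorObject` and `NewTensor` are hypothetical stand-ins for XlaTensorInfo, TensorInfoManager and the Tensor/XlaTensor pairing. It shows why the old design needed a hash lookup on every access and a deallocation hook, while the new design reaches the metadata directly from the tensor.

```cpp
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for the per-tensor metadata (the real classes carry
// XLA-specific state; a shape string is enough for illustration).
struct TensorMetadata {
  std::string shape_description;
};

// Old layout (sketch): the tensor holds only a raw device pointer, and a
// side table maps that pointer to its metadata.
class SideTableManager {
 public:
  void Register(void* device_ptr, TensorMetadata info) {
    table_[device_ptr] = std::move(info);
  }
  // Every metadata access costs a hash-table lookup.
  const TensorMetadata* Lookup(void* device_ptr) const {
    auto it = table_.find(device_ptr);
    return it == table_.end() ? nullptr : &it->second;
  }
  // Must run when the buffer is freed; this is why the manager had to sit
  // in the allocation path as an allocator wrapper.
  void OnDeallocate(void* device_ptr) { table_.erase(device_ptr); }

 private:
  std::unordered_map<void*, TensorMetadata> table_;
};

struct OldTensor {
  void* device_ptr;
};

// New layout (sketch): the tensor owns an object that carries the metadata,
// so there is no side band to keep in sync and no lookup to perform.
struct NewXlaTensorObject {
  TensorMetadata info;
};

struct NewTensor {
  std::unique_ptr<NewXlaTensorObject> xla_tensor;
};
```

In the new layout, the metadata's lifetime follows the tensor's, so nothing has to intercept deallocation to clean up a table entry.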

This has advantages and drawbacks:
  + We don't need yet another allocator wrapper, as there is no side-band data to manage.
  + No hashtable lookups are required.
  - Because XlaLocalLaunchOp can run either on an XlaDevice or on a TF-classic device, we need a way to tell whether a Tensor is a TF-classic tensor (holding a device pointer) or an XlaTensor; we use a tagged pointer for this.
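A tagged pointer can be sketched with low-bit tagging, assuming (as the sketch does) that every real device pointer and every XlaTensor address is at least 2-byte aligned, so bit 0 is always free. `FakeXlaTensor`, `ToOpaquePointer` and `FromOpaquePointer` are hypothetical names for illustration; this is not presented as the commit's actual implementation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for XlaTensor. alignas(8) guarantees that bit 0 of
// any valid address of this type is zero, so that bit can carry a tag.
struct alignas(8) FakeXlaTensor {
  int id = 0;
};

// Encode: set the low bit so the value can later be recognized as an
// XlaTensor rather than a plain device pointer. The result must never be
// dereferenced directly.
inline void* ToOpaquePointer(FakeXlaTensor* tensor) {
  auto value = reinterpret_cast<std::uintptr_t>(tensor);
  assert((value & 1u) == 0 && "pointer must be at least 2-byte aligned");
  return reinterpret_cast<void*>(value | 1u);
}

// Decode: a tagged value yields the XlaTensor with the tag stripped; an
// untagged value is assumed to be a TF-classic device pointer, so return
// nullptr to signal "not an XlaTensor".
inline FakeXlaTensor* FromOpaquePointer(void* ptr) {
  auto value = reinterpret_cast<std::uintptr_t>(ptr);
  if ((value & 1u) == 0) return nullptr;
  return reinterpret_cast<FakeXlaTensor*>(value & ~std::uintptr_t{1});
}
```

The check costs a single bit test instead of a hash lookup, at the price of relying on an alignment invariant that every allocator feeding these tensors must uphold.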

As part of this, allocate ShapedBuffers directly with the XLA backend's allocator instead of taking a roundabout route in which we:
  1. wrapped the XLA allocator in an XlaDeviceAllocator, and then
  2. wrapped the XlaDeviceAllocator in an XlaAllocator.

This leaves less to go wrong. Ideally we'd use StreamExecutor's allocator here, but it is less useful than XLA's because it doesn't produce helpful OOM messages (it just returns nullptr).
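The difference between an allocator that only returns nullptr and one that reports why it failed can be sketched as follows. All names here (`RawDeviceAllocator`, `AllocationResult`, `AllocateWithMessage`) are hypothetical: `AllocationResult` stands in for a status-carrying return type, and host `malloc` stands in for device memory. It is not the real StreamExecutor or XLA allocator API.

```cpp
#include <cstddef>
#include <cstdlib>
#include <string>

// Hypothetical raw allocator with a fixed capacity. On failure it simply
// returns nullptr, so the caller learns nothing about the cause.
class RawDeviceAllocator {
 public:
  explicit RawDeviceAllocator(std::size_t capacity) : capacity_(capacity) {}

  void* Allocate(std::size_t bytes) {
    if (bytes > capacity_ - used_) return nullptr;  // no reason given
    void* ptr = std::malloc(bytes);  // host memory stands in for device memory
    if (ptr != nullptr) used_ += bytes;
    return ptr;
  }

  void Deallocate(void* ptr, std::size_t bytes) {
    std::free(ptr);
    used_ -= bytes;
  }

  std::size_t capacity() const { return capacity_; }
  std::size_t used() const { return used_; }

 private:
  std::size_t capacity_;
  std::size_t used_ = 0;
};

// Hypothetical status-carrying result: either a pointer or an error message.
struct AllocationResult {
  void* ptr = nullptr;
  std::string error;  // empty on success
  bool ok() const { return ptr != nullptr; }
};

// Wraps the raw allocator and turns a bare nullptr into an OOM message that
// records the request size and the current usage.
AllocationResult AllocateWithMessage(RawDeviceAllocator& raw,
                                     std::size_t bytes) {
  AllocationResult result;
  result.ptr = raw.Allocate(bytes);
  if (result.ptr == nullptr) {
    result.error = "Out of memory: failed to allocate " +
                   std::to_string(bytes) + " bytes (" +
                   std::to_string(raw.used()) + " of " +
                   std::to_string(raw.capacity()) + " bytes in use)";
  }
  return result;
}
```

A message that includes the request size and current usage is what makes an OOM actionable, which is the property the commit message cites in favor of XLA's allocator.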

PiperOrigin-RevId: 191048184
12 files changed:
tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/xla_compile_on_demand_op.cc
tensorflow/compiler/jit/xla_device.cc
tensorflow/compiler/jit/xla_device.h
tensorflow/compiler/jit/xla_device_context.cc
tensorflow/compiler/jit/xla_device_context.h
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/jit/xla_launch_util.h
tensorflow/compiler/jit/xla_tensor.cc [new file with mode: 0644]
tensorflow/compiler/jit/xla_tensor.h [moved from tensorflow/compiler/jit/xla_tensor_info.h with 50% similarity]
tensorflow/compiler/jit/xla_tensor_info.cc [deleted file]