client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)) {}
+ shape_representation_fn_(std::move(shape_representation_fn)) {
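+  // If no shape representation function was provided, fall back to the
+  // identity function, which leaves shapes unchanged.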
+ if (!shape_representation_fn_) {
+ shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) {
+ return shape;
+ };
+ }
+}
Status XlaTransferManager::TransferLiteralToDevice(
const Tensor& host_tensor, Tensor* device_tensor) const {
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
- TensorShape shape;
- if (shape_representation_fn_) {
- shape = shape_representation_fn_(device_tensor->shape(),
- device_tensor->dtype());
- } else {
- shape = device_tensor->shape();
- }
+ TensorShape shape = shape_representation_fn_(device_tensor->shape(),
+ device_tensor->dtype());
if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), shape, client_,
done(Status::OK());
}
+
+void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
+ Tensor* dst_tensor,
+ const StatusCallback& done) {
+ // TODO(phawkins): replace this code with an asynchronous implementation.
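+  // The copy runs synchronously inside a lambda so TF_RETURN_IF_ERROR can be
+  // used; the resulting Status is reported through the 'done' callback.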
+ auto body = [&]() {
+ if (src_tensor.NumElements() == 0) {
+ return Status::OK();
+ }
+ XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
+ XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
+ CHECK(xla_src && xla_dst)
+        << "Missing source or destination tensor for device-to-device copy";
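+    // Allocate the destination buffer lazily, using the representation shape
+    // computed from the source tensor's shape and dtype.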
+ if (!xla_dst->has_shaped_buffer()) {
+ TensorShape shape =
+ shape_representation_fn_(src_tensor.shape(), src_tensor.dtype());
+ TF_RETURN_IF_ERROR(
+ xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
+ stream_->parent()->device_ordinal()));
+ }
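+    // Copy each device sub-buffer of the shaped buffer; the source and
+    // destination sub-buffers must be the same size.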
+ TF_RETURN_IF_ERROR(
+ xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
+ [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+ const se::DeviceMemoryBase& from_buffer =
+ xla_src->shaped_buffer().buffers().element(index);
+ CHECK_EQ(buffer->size(), from_buffer.size());
+ if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer,
+ buffer->size())) {
+              return errors::Internal("Device-to-device memcpy failed");
+ }
+ return Status::OK();
+ }));
+ return Status::OK();
+ };
+ done(body());
+}
+
XlaDeviceContext::XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
done);
}
+
+void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor,
+ Tensor* dst_tensor,
+ const StatusCallback& done) {
+ manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done);
+}
+
} // namespace tensorflow
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done);
+
+ void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
+ const StatusCallback& done);
+
se::Stream* stream() const { return stream_; }
private:
xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
- const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
+
+  void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
+ const StatusCallback& done);
+
se::Stream* stream() const override { return manager_.stream(); }
private:
#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include <memory>
+
#include "tensorflow/compiler/jit/xla_device_context.h"
+#include "tensorflow/compiler/jit/xla_tensor.h"
namespace tensorflow {
<< type_string() << " on an XLA device. This should never happen.";
}
+
+XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c)
+ : AsyncOpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+}
+
+void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context,
+ DoneCallback done) {
+ OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(),
+ errors::InvalidArgument(
+ "Variable and value dtypes don't match; respectively, ",
+ dtype_, " and ", context->input(1).dtype()),
+ done);
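+  // Look up the variable's resource, creating it on first use with a
+  // persistent tensor of the input value's shape and dtype.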
+ Var* variable = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ LookupOrCreateResource<Var>(
+ context, HandleFromInput(context, 0), &variable,
+ [this, context](Var** ptr) {
+ *ptr = new Var(dtype_);
+ PersistentTensor unused;
+ Tensor* tmp;
+ AllocatorAttributes attr;
+ TF_RETURN_IF_ERROR(context->allocate_persistent(
+ dtype_, context->input(1).shape(), &unused, &tmp, attr));
+ *(*ptr)->tensor() = *tmp;
+ return Status::OK();
+ }),
+ done);
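+  // LookupOrCreateResource returns 'variable' with a reference held; release
+  // it when this function returns.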
+ core::ScopedUnref s(variable);
+
+ OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_,
+ errors::InvalidArgument(
+ "Trying to assign variable with wrong dtype. Expected ",
+                        DataTypeString(variable->tensor()->dtype()), ", got ",
+ DataTypeString(dtype_)),
+ done);
+
+ const Tensor& value = context->input(1);
+ AllocatorAttributes attr;
+
+  // Copying is unnecessary if we are the last user of the value tensor; we
+  // can just adopt the input tensor's buffer instead.
+ std::unique_ptr<Tensor> input_alias = context->forward_input(
+ 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_,
+ value.shape(), DEVICE_MEMORY, attr);
+ mutex_lock ml(*variable->mu());
+ variable->is_initialized = true;
+ if (input_alias) {
+ *variable->tensor() = *input_alias;
+ done();
+ return;
+ }
+
+ // Need to copy, but maybe we can re-use variable's buffer?
+ if (!XlaTensor::RefCountIsOne(*variable->tensor()) ||
+ !variable->tensor()->shape().IsSameSize(value.shape())) {
+ // Copy to new buffer
+ PersistentTensor unused;
+ Tensor* tmp;
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_persistent(dtype_, value.shape(),
+ &unused, &tmp, attr),
+ done);
+ *variable->tensor() = *tmp;
+ }
+
+ XlaDeviceContext* device_context =
+ static_cast<XlaDeviceContext*>(context->op_device_context());
+
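+  // Hold a reference on the variable until the asynchronous copy completes.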
+ variable->Ref();
+ device_context->CopyDeviceTensorToDevice(
+ value, variable->tensor(), [context, variable, done](Status status) {
+ variable->Unref();
+ context->SetStatus(status);
+ done();
+ });
+}
+
} // namespace tensorflow
void Compute(OpKernelContext* ctx) override;
};
+
+class XlaAssignVariableOp : public AsyncOpKernel {
+ public:
+ explicit XlaAssignVariableOp(OpKernelConstruction* c);
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+ DataType dtype_;
+};
+
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \
.Device(DEVICE) \
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
+ XlaAssignVariableOp); \
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
ControlTriggerOp); \
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
rtol=1e-4)
self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
+
+  def testWriteOfAliasedTensor(self):
+ for dtype in self.numeric_types:
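+      # Use complex literals so that complex dtypes receive nontrivial
+      # imaginary parts; astype() drops the imaginary part for real dtypes.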
+ init = np.array([[1, 2j], [3, 4]]).astype(dtype)
+ update = np.array([[7, 1j], [2, 11]]).astype(dtype)
+ with self.test_session() as sess, self.test_scope():
+ v = resource_variable_ops.ResourceVariable(init)
+ sess.run(variables.variables_initializer([v]))
+ p = array_ops.placeholder(dtype)
+ q = array_ops.identity(p)
+ x = v.read_value()
+ # Writes the value of 'p' to 'v', but keeps a reference to the original
+ # value of 'v' so the variable update cannot reuse its buffer.
+ with ops.control_dependencies([x]):
+ y = v.assign(q)
+ result = sess.run([x, y, q], {p: update})
+ self.assertAllClose(init, result[0])
+ self.assertAllClose(update, result[1])
+ self.assertAllClose(update, result[2])
+
class StridedSliceAssignChecker(object):
"""Compares the results of a slice assignment using Tensorflow and numpy."""
ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1)));
}
};
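+// Register AssignVariableOp for compilation only; on XLA devices the op is
+// executed at the op-kernel level by XlaAssignVariableOp instead.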
-REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp);
+REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp);
class AssignAddVariableOp : public XlaOpKernel {
public: