[TF:XLA] Add direct implementation of AssignVariableOp for XLA devices.
Author:     Peter Hawkins <phawkins@google.com>
AuthorDate: Sat, 26 May 2018 00:29:37 +0000 (17:29 -0700)
Commit:     TensorFlower Gardener <gardener@tensorflow.org>
CommitDate: Sat, 26 May 2018 00:32:18 +0000 (17:32 -0700)
This allows us to avoid an XLA compilation and tensor copies when assigning to a variable placed on an XLA device.

PiperOrigin-RevId: 198127062

tensorflow/compiler/jit/xla_device_context.cc
tensorflow/compiler/jit/xla_device_context.h
tensorflow/compiler/jit/xla_device_ops.cc
tensorflow/compiler/jit/xla_device_ops.h
tensorflow/compiler/tests/variable_ops_test.py
tensorflow/compiler/tf2xla/kernels/variable_ops.cc

diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index c718125..71e63b1 100644
@@ -54,7 +54,13 @@ XlaTransferManager::XlaTransferManager(
       client_(client),
       transfer_manager_(client->backend().transfer_manager()),
       transfer_as_literal_(transfer_as_literal),
-      shape_representation_fn_(std::move(shape_representation_fn)) {}
+      shape_representation_fn_(std::move(shape_representation_fn)) {
+  if (!shape_representation_fn_) {
+    shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) {
+      return shape;
+    };
+  }
+}
 
 Status XlaTransferManager::TransferLiteralToDevice(
     const Tensor& host_tensor, Tensor* device_tensor) const {
@@ -113,13 +119,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
     XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
     CHECK(xla_tensor);
 
-    TensorShape shape;
-    if (shape_representation_fn_) {
-      shape = shape_representation_fn_(device_tensor->shape(),
-                                       device_tensor->dtype());
-    } else {
-      shape = device_tensor->shape();
-    }
+    TensorShape shape = shape_representation_fn_(device_tensor->shape(),
+                                                 device_tensor->dtype());
     if (!xla_tensor->has_shaped_buffer()) {
       Status s = xla_tensor->AllocateShapedBuffer(
           device_tensor->dtype(), shape, client_,
@@ -203,6 +204,42 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   done(Status::OK());
 }
 
+void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
+                                                  Tensor* dst_tensor,
+                                                  const StatusCallback& done) {
+  // TODO(phawkins): replace this code with an asynchronous implementation.
+  auto body = [&]() {
+    if (src_tensor.NumElements() == 0) {
+      return Status::OK();
+    }
+    XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
+    XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
+    CHECK(xla_src && xla_dst)
+        << "Missing destination tensor for device-to-device copy";
+    if (!xla_dst->has_shaped_buffer()) {
+      TensorShape shape =
+          shape_representation_fn_(src_tensor.shape(), src_tensor.dtype());
+      TF_RETURN_IF_ERROR(
+          xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
+                                        stream_->parent()->device_ordinal()));
+    }
+    TF_RETURN_IF_ERROR(
+        xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
+            [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+              const se::DeviceMemoryBase& from_buffer =
+                  xla_src->shaped_buffer().buffers().element(index);
+              CHECK_EQ(buffer->size(), from_buffer.size());
+              if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer,
+                                                        buffer->size())) {
+                return errors::Internal("Device to device memcpy failed");
+              }
+              return Status::OK();
+            }));
+    return Status::OK();
+  };
+  done(body());
+}
+
 XlaDeviceContext::XlaDeviceContext(
     se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn)
@@ -224,4 +261,10 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                  done);
 }
 
+void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor,
+                                                Tensor* dst_tensor,
+                                                const StatusCallback& done) {
+  manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 9af9655..ee346e5 100644
@@ -55,6 +55,10 @@ class XlaTransferManager {
   void CopyDeviceTensorToCPU(const Tensor* device_tensor,
                              StringPiece tensor_name, Device* device,
                              Tensor* cpu_tensor, StatusCallback done);
+
+  void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
+                                const StatusCallback& done);
+
   se::Stream* stream() const { return stream_; }
 
  private:
@@ -72,7 +76,7 @@ class XlaTransferManager {
   xla::TransferManager* transfer_manager_;
   // True if we must use XLA's TransferManager for correct device transfers.
   const bool transfer_as_literal_;
-  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+  XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
 };
 
 // DeviceContext for operators assigned to XlaDevice devices. The
@@ -90,6 +94,9 @@ class XlaDeviceContext : public DeviceContext {
   void CopyDeviceTensorToCPU(const Tensor* device_tensor,
                              StringPiece tensor_name, Device* device,
                              Tensor* cpu_tensor, StatusCallback done) override;
+  void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
+                                const StatusCallback& done);
+
   se::Stream* stream() const override { return manager_.stream(); }
 
  private:
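
The new CopyDeviceTensorToDevice entry points are callback-based. Below is a
minimal sketch of driving the API synchronously (not part of this commit;
CopyDeviceTensorToDeviceSync is a hypothetical helper, and the callback
signature is taken from the declarations above):

#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/core/platform/notification.h"

namespace tensorflow {

// Hypothetical helper: blocks until the device-to-device copy reports status.
Status CopyDeviceTensorToDeviceSync(XlaDeviceContext* device_context,
                                    const Tensor& src, Tensor* dst) {
  Status copy_status;
  Notification done;
  device_context->CopyDeviceTensorToDevice(src, dst, [&](const Status& s) {
    copy_status = s;
    done.Notify();
  });
  // The current implementation completes synchronously (see the TODO in
  // xla_device_context.cc), but waiting on the callback keeps callers
  // correct if the copy is later made asynchronous.
  done.WaitForNotification();
  return copy_status;
}

}  // namespace tensorflow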
diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc
index f68dba6..5ecb1af 100644
@@ -15,7 +15,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/xla_device_ops.h"
 
+#include <memory>
+
 #include "tensorflow/compiler/jit/xla_device_context.h"
+#include "tensorflow/compiler/jit/xla_tensor.h"
 
 namespace tensorflow {
 
@@ -26,4 +29,82 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {
              << type_string() << " on an XLA device. This should never happen.";
 }
 
+XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c)
+    : AsyncOpKernel(c) {
+  OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+}
+
+void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context,
+                                       DoneCallback done) {
+  OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(),
+                    errors::InvalidArgument(
+                        "Variable and value dtypes don't match; respectively, ",
+                        dtype_, " and ", context->input(1).dtype()),
+                    done);
+  Var* variable = nullptr;
+  OP_REQUIRES_OK_ASYNC(
+      context,
+      LookupOrCreateResource<Var>(
+          context, HandleFromInput(context, 0), &variable,
+          [this, context](Var** ptr) {
+            *ptr = new Var(dtype_);
+            PersistentTensor unused;
+            Tensor* tmp;
+            AllocatorAttributes attr;
+            TF_RETURN_IF_ERROR(context->allocate_persistent(
+                dtype_, context->input(1).shape(), &unused, &tmp, attr));
+            *(*ptr)->tensor() = *tmp;
+            return Status::OK();
+          }),
+      done);
+  core::ScopedUnref s(variable);
+
+  OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_,
+                    errors::InvalidArgument(
+                        "Trying to assign variable with wrong dtype. Expected ",
+                        DataTypeString(variable->tensor()->dtype()), " got ",
+                        DataTypeString(dtype_)),
+                    done);
+
+  const Tensor& value = context->input(1);
+  AllocatorAttributes attr;
+
+  // Copying is unnecessary if we are the last user of the value tensor; we
+  // can just adopt the input tensor's buffer instead.
+  std::unique_ptr<Tensor> input_alias = context->forward_input(
+      1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_,
+      value.shape(), DEVICE_MEMORY, attr);
+  mutex_lock ml(*variable->mu());
+  variable->is_initialized = true;
+  if (input_alias) {
+    *variable->tensor() = *input_alias;
+    done();
+    return;
+  }
+
+  // Need to copy, but maybe we can re-use variable's buffer?
+  if (!XlaTensor::RefCountIsOne(*variable->tensor()) ||
+      !variable->tensor()->shape().IsSameSize(value.shape())) {
+    // Copy to new buffer
+    PersistentTensor unused;
+    Tensor* tmp;
+    OP_REQUIRES_OK_ASYNC(context,
+                         context->allocate_persistent(dtype_, value.shape(),
+                                                      &unused, &tmp, attr),
+                         done);
+    *variable->tensor() = *tmp;
+  }
+
+  XlaDeviceContext* device_context =
+      static_cast<XlaDeviceContext*>(context->op_device_context());
+
+  variable->Ref();
+  device_context->CopyDeviceTensorToDevice(
+      value, variable->tensor(), [context, variable, done](Status status) {
+        variable->Unref();
+        context->SetStatus(status);
+        done();
+      });
+}
+
 }  // namespace tensorflow
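
The kernel above picks among three paths, cheapest first: adopt the input's
buffer outright, reuse the variable's existing buffer, or allocate a fresh
one; the latter two end in a device-to-device copy. A minimal standalone
sketch of the zero-copy fast path follows (illustrative only, not commit
code; TryAdoptValueBuffer is a hypothetical name, and the caller is assumed
to hold variable->mu() just as the kernel does):

// Returns true iff the variable adopted input 1's buffer without a copy.
// forward_input() yields an aliasing Tensor only when this kernel holds the
// sole remaining reference to the input; otherwise it returns nullptr.
bool TryAdoptValueBuffer(OpKernelContext* context, Var* variable,
                         DataType dtype, const Tensor& value) {
  std::unique_ptr<Tensor> alias = context->forward_input(
      /*input_index=*/1,
      /*output_index=*/OpKernelContext::Params::kNoReservation, dtype,
      value.shape(), DEVICE_MEMORY, AllocatorAttributes());
  if (alias == nullptr) {
    return false;  // buffer is shared elsewhere; the caller must copy
  }
  *variable->tensor() = *alias;  // zero-copy: adopt the input's buffer
  return true;
}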
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 5363257..b27c32e 100644
@@ -42,6 +42,15 @@ class XlaDeviceDummyOp : public OpKernel {
   void Compute(OpKernelContext* ctx) override;
 };
 
+class XlaAssignVariableOp : public AsyncOpKernel {
+ public:
+  explicit XlaAssignVariableOp(OpKernelConstruction* c);
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+  DataType dtype_;
+};
+
 #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
   REGISTER_KERNEL_BUILDER(Name("XlaLaunch")               \
                               .Device(DEVICE)             \
@@ -78,6 +87,9 @@ class XlaDeviceDummyOp : public OpKernel {
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"),            \
       ReadVariableOp);                                                         \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"),          \
+      XlaAssignVariableOp);                                                    \
   REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE),               \
                           ControlTriggerOp);                                   \
   REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"),    \
diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index 8ecad00..2c09b03 100644
@@ -187,6 +187,25 @@ class VariableOpsTest(XLATestCase):
           rtol=1e-4)
       self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4)
 
+  def testWriteOfAliasedTensor(self):
+    for dtype in self.numeric_types:
+      init = np.array([[1, 2j], [3, 4]]).astype(dtype)
+      update = np.array([[7, 1j], [2, 11]]).astype(dtype)
+      with self.test_session() as sess, self.test_scope():
+        v = resource_variable_ops.ResourceVariable(init)
+        sess.run(variables.variables_initializer([v]))
+        p = array_ops.placeholder(dtype)
+        q = array_ops.identity(p)
+        x = v.read_value()
+        # Writes the value of 'p' to 'v', but keeps a reference to the original
+        # value of 'v' so the variable update cannot reuse its buffer.
+        with ops.control_dependencies([x]):
+          y = v.assign(q)
+        result = sess.run([x, y, q], {p: update})
+        self.assertAllClose(init, result[0])
+        self.assertAllClose(update, result[1])
+        self.assertAllClose(update, result[2])
+
 
 class StridedSliceAssignChecker(object):
   """Compares the results of a slice assignment using Tensorflow and numpy."""
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 631cd44..a163fa0 100644
@@ -67,7 +67,7 @@ class AssignVariableOp : public XlaOpKernel {
                    ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1)));
   }
 };
-REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp);
+REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp);
 
 class AssignAddVariableOp : public XlaOpKernel {
  public:
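
With the CompilationOnly() restriction, AssignVariableOp now resolves to two
different kernels depending on how it executes. Collected from the diffs
above for contrast (DEVICE is the parameter of the registration macro in
xla_device_ops.h):

// Inside a compiled XLA cluster: the symbolic compile-time kernel.
REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp);

// Op-by-op execution on an XLA device: the new direct kernel, which does a
// device-to-device buffer copy with no XLA compilation step.
REGISTER_KERNEL_BUILDER(
    Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"),
    XlaAssignVariableOp);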