[TF] Copy-on-write for Resource Variant assign op.
authorEugene Brevdo <ebrevdo@google.com>
Mon, 2 Apr 2018 21:48:21 +0000 (14:48 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 2 Apr 2018 21:51:18 +0000 (14:51 -0700)
PiperOrigin-RevId: 191351293

tensorflow/core/kernels/resource_variable_ops.cc

index d1675f2..082b57b 100644 (file)
@@ -252,6 +252,7 @@ class AssignVariableOp : public OpKernel {
     // tensor, we can just adopt the input tensor's buffer instead.
     std::unique_ptr<Tensor> input_alias =
         context->forward_input(1, dtype_, value.shape(), DEVICE_MEMORY, attr);
+
     mutex_lock ml(*variable->mu());
     variable->is_initialized = true;
     if (input_alias) {
@@ -363,9 +364,33 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
                     DataTypeString(variable->tensor()->dtype()), " got ",
                     DataTypeString(DT_VARIANT)));
 
+    AllocatorAttributes attr;
+    attr.set_on_host(true);
+
+    // Copying is unnecessary if we are the last user of the value
+    // tensor, we can just adopt the input tensor's buffer instead.
+    // Note that Variant objects themselves always reside on host.
+    std::unique_ptr<Tensor> input_alias =
+        context->forward_input(1, DT_VARIANT, value.shape(), HOST_MEMORY, attr);
+
     mutex_lock ml(*variable->mu());
     variable->is_initialized = true;
-    *variable->tensor() = Tensor(DT_VARIANT, value.shape());
+    if (input_alias) {
+      *variable->tensor() = *input_alias;
+      return;
+    }
+
+    // Need to copy, but maybe we can re-use variable's buffer?
+    if (!variable->tensor()->RefCountIsOne() ||
+        !variable->tensor()->shape().IsSameSize(value.shape())) {
+      PersistentTensor unused;
+      Tensor* tmp;
+      OP_REQUIRES_OK(context,
+                     context->allocate_persistent(DT_VARIANT, value.shape(),
+                                                  &unused, &tmp, attr));
+      *variable->tensor() = *tmp;
+    }
+
     const auto elements_in = value.flat<Variant>();
     auto elements_out = variable->tensor()->flat<Variant>();
     auto copy_fn = std::bind(&VariantCopyFn<Device>, context,