Move inplace update operators.
author     Patrick Nguyen <drpng@google.com>
           Fri, 6 Apr 2018 23:33:11 +0000 (16:33 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Fri, 6 Apr 2018 23:35:38 +0000 (16:35 -0700)
The ops are not part of the public API.
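
In rough terms, the new Python helpers behave as follows (an illustrative
sketch, not taken from this change's tests):

    # y is a fresh copy of x; x itself is never mutated.
    y = inplace_ops.inplace_update(x, [3], v)  # y = copy(x); y[3, :] = v
    y = inplace_ops.inplace_add(x, [3], v)     # y = copy(x); y[3, :] += v
    y = inplace_ops.inplace_sub(x, [3], v)     # y = copy(x); y[3, :] -= v
    # The alias_inplace_* variants skip the copy and return an alias of x.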

PiperOrigin-RevId: 191957660

18 files changed:
tensorflow/core/api_def/base_api/api_def_DeepCopy.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_Empty.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_InplaceSub.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_InplaceUpdate.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_DeepCopy.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_Empty.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_InplaceAdd.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_InplaceSub.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_InplaceUpdate.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/inplace_ops.cc
tensorflow/core/kernels/inplace_ops_functor.h
tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
tensorflow/core/ops/array_ops.cc
tensorflow/python/BUILD
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/inplace_ops_test.py [new file with mode: 0644]
tensorflow/python/ops/inplace_ops.py [new file with mode: 0644]

diff --git a/tensorflow/core/api_def/base_api/api_def_DeepCopy.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeepCopy.pbtxt
new file mode 100644 (file)
index 0000000..fe0fc38
--- /dev/null
@@ -0,0 +1,15 @@
+op {
+  graph_op_name: "DeepCopy"
+  in_arg {
+    name: "x"
+    description: "The source tensor of type `T`."
+  }
+  out_arg {
+    name: "y"
+    description: <<END
+A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
+is not an alias of `x`.
+END
+  }
+  summary: "Makes a copy of `x`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Empty.pbtxt b/tensorflow/core/api_def/base_api/api_def_Empty.pbtxt
new file mode 100644 (file)
index 0000000..746f561
--- /dev/null
@@ -0,0 +1,23 @@
+op {
+  graph_op_name: "Empty"
+  in_arg {
+    name: "shape"
+    description: "1-D. Represents the shape of the output tensor."
+  }
+  attr {
+    name: "init"
+    description:
+        "If True, initialize the returned tensor with the default value "
+        "of dtype.  Otherwise, the implementation is free not to initialize"
+        "the tensor's content."
+  }
+  out_arg {
+    name: "output"
+    description: "A `Tensor` of type `T`."
+  }
+  summary: <<END
+Creates a tensor with the given shape.
+
+This operation creates a tensor of `shape` and `dtype`.
+END
+}
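
For illustration, the effect of the `init` attr as exercised by the new unit
test, using the `inplace_ops.empty` wrapper added in this change (a sketch,
assuming the usual dtypes import):

    val = inplace_ops.empty((2, 3), dtypes.float32, init=True)   # zero-filled
    val = inplace_ops.empty((2, 3), dtypes.float32, init=False)  # contents arbitrary
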
diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt
new file mode 100644 (file)
index 0000000..3654286
--- /dev/null
@@ -0,0 +1,28 @@
+op {
+  graph_op_name: "InplaceAdd"
+  in_arg {
+    name: "x"
+    description: "A `Tensor` of type T."
+  }
+  in_arg {
+    name: "i"
+    description: "A vector. Indices into the left-most dimension of `x`."
+  }
+  in_arg {
+    name: "v"
+    description:
+        "A `Tensor` of type T. Same dimension sizes as x except "
+        "the first dimension, which must be the same as i's size."
+  }
+  out_arg {
+    name: "y"
+    description:
+        "A `Tensor` of type T. An alias of `x`. The content "
+        "of `y` is undefined if there are duplicates in `i`."
+  }
+  summary: <<END
+Adds `v` into specified rows of `x`.
+
+Computes `y = x; y[i, :] += v; return y`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceSub.pbtxt
new file mode 100644 (file)
index 0000000..a9480b4
--- /dev/null
@@ -0,0 +1,28 @@
+op {
+  graph_op_name: "InplaceSub"
+  in_arg {
+    name: "x"
+    description: "A `Tensor` of type T."
+  }
+  in_arg {
+    name: "i"
+    description: "A vector. Indices into the left-most dimension of `x`."
+  }
+  in_arg {
+    name: "v"
+    description:
+        "A `Tensor` of type T. Same dimension sizes as x except "
+        "the first dimension, which must be the same as i's size."
+  }
+  out_arg {
+    name: "y"
+    description:
+        "A `Tensor` of type T. An alias of `x`. The content "
+        "of `y` is undefined if there are duplicates in `i`."
+  }
+  summary: <<END
+Subtracts `v` from specified rows of `x`.
+
+Computes `y = x; y[i, :] -= v; return y`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceUpdate.pbtxt
new file mode 100644 (file)
index 0000000..2fcd365
--- /dev/null
@@ -0,0 +1,28 @@
+op {
+  graph_op_name: "InplaceUpdate"
+  in_arg {
+    name: "x"
+    description: "A tensor of type `T`."
+  }
+  in_arg {
+    name: "i"
+    description: "A vector. Indices into the left-most dimension of `x`."
+  }
+  in_arg {
+    name: "v"
+    description:
+        "A `Tensor` of type T. Same dimension sizes as x except "
+        "the first dimension, which must be the same as i's size."
+  }
+  out_arg {
+    name: "y"
+    description:
+        "A `Tensor` of type T. An alias of `x`. The content "
+        "of `y` is undefined if there are duplicates in `i`."
+  }
+  summary: <<END
+Updates specified rows with values in `v`.
+
+Computes `x[i, :] = v; return x`.
+END
+}
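
A NumPy model of the raw ops' documented row semantics (an illustrative
sketch only; unlike the copying Python wrappers added below, the raw ops
alias their input, and `y` is undefined when `i` contains duplicates):

    import numpy as np

    def raw_inplace_update(x, i, v):
        x[i, :] = v  # use += / -= here to model InplaceAdd / InplaceSub
        return x     # an alias of the input, as the op docs state

    x = np.ones((7, 3), np.float32)
    y = raw_inplace_update(x, [3], 2 * np.ones((1, 3), np.float32))
    assert y is x    # the output aliases x
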
diff --git a/tensorflow/core/api_def/python_api/api_def_DeepCopy.pbtxt b/tensorflow/core/api_def/python_api/api_def_DeepCopy.pbtxt
new file mode 100644 (file)
index 0000000..2d5ed2b
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DeepCopy"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Empty.pbtxt b/tensorflow/core/api_def/python_api/api_def_Empty.pbtxt
new file mode 100644 (file)
index 0000000..0b86352
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Empty"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InplaceAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_InplaceAdd.pbtxt
new file mode 100644 (file)
index 0000000..390e3bb
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "InplaceAdd"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InplaceSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_InplaceSub.pbtxt
new file mode 100644 (file)
index 0000000..af9634f
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "InplaceSub"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InplaceUpdate.pbtxt b/tensorflow/core/api_def/python_api/api_def_InplaceUpdate.pbtxt
new file mode 100644 (file)
index 0000000..5fa9d77
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "InplaceUpdate"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index a71d047..ef6ce05 100644 (file)
@@ -213,13 +213,13 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
 
 typedef Eigen::GpuDevice GPUDevice;
 
-#define REGISTER_EMPTY(type)                                  \
+#define REGISTER_PARALLEL_CONCAT_START(type)                  \
   REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                               .Device(DEVICE_GPU)             \
                               .TypeConstraint<type>("dtype"), \
                           ParallelConcatStart<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_EMPTY)
-#undef REGISTER_EMPTY
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT_START)
+#undef REGISTER_PARALLEL_CONCAT_START
 
 #define REGISTER_PARALLEL_CONCAT(type)                                     \
   REGISTER_KERNEL_BUILDER(                                                 \
@@ -248,5 +248,295 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
                         ParallelConcatUpdate<CPUDevice>);
 #endif
 
+class InplaceOpBase : public OpKernel {
+ public:
+  explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    auto x = ctx->input(0);
+    auto i = ctx->input(1);
+    auto v = ctx->input(2);
+
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(i.shape()),
+                errors::InvalidArgument("i must be a vector. ",
+                                        i.shape().DebugString()));
+    OP_REQUIRES(ctx, x.dims() == v.dims(),
+                errors::InvalidArgument(
+                    "x and v shape doesn't match (ranks differ): ",
+                    x.shape().DebugString(), " vs. ", v.shape().DebugString()));
+    for (int i = 1; i < x.dims(); ++i) {
+      OP_REQUIRES(
+          ctx, x.dim_size(i) == v.dim_size(i),
+          errors::InvalidArgument("x and v shape doesn't match at index ", i,
+                                  " : ", x.shape().DebugString(), " vs. ",
+                                  v.shape().DebugString()));
+    }
+    OP_REQUIRES(ctx, i.dim_size(0) == v.dim_size(0),
+                errors::InvalidArgument(
+                    "i and v shape doesn't match at index 0: ",
+                    i.shape().DebugString(), " vs. ", v.shape().DebugString()));
+
+    Tensor y = x;  // This creates an alias intentionally.
+    OP_REQUIRES_OK(ctx, DoCompute(ctx, i, v, &y));
+    ctx->set_output(0, y);
+  }
+
+ protected:
+  virtual Status DoCompute(OpKernelContext* ctx, const Tensor& i,
+                           const Tensor& v, Tensor* y) = 0;
+};
+
+}  // end namespace
+
+namespace functor {
+
+template <typename T>
+void DoInplaceOp(const CPUDevice& d, InplaceOpType op, const Tensor& i,
+                 const Tensor& v, Tensor* y) {
+  auto Ti = i.flat<int32>();
+  auto Tv = v.flat_outer_dims<T>();
+  auto Ty = y->flat_outer_dims<T>();
+  auto nrows = Ty.dimension(0);
+  for (int64 j = 0; j < Ti.size(); ++j) {
+    auto r = (Ti(j) % nrows + nrows) % nrows;  // Guard index range.
+    switch (op) {
+      case I_UPDATE:
+        Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
+        break;
+      case I_ADD:
+        Ty.template chip<0>(r).device(d) += Tv.template chip<0>(j);
+        break;
+      case I_SUB:
+        Ty.template chip<0>(r).device(d) -= Tv.template chip<0>(j);
+        break;
+    }
+  }
+}
+
+// String type only supports inplace update.
+void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i,
+                             const Tensor& v, Tensor* y) {
+  auto Ti = i.flat<int32>();
+  auto Tv = v.flat_outer_dims<string>();
+  auto Ty = y->flat_outer_dims<string>();
+  auto nrows = Ty.dimension(0);
+  for (int64 j = 0; j < Ti.size(); ++j) {
+    auto r = (Ti(j) % nrows + nrows) % nrows;  // Guard index range.
+    Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
+  }
+}
+
+template <>
+Status DoInplace(const CPUDevice& device, InplaceOpType op, const Tensor& i,
+                 const Tensor& v, Tensor* y) {
+  CHECK_EQ(v.dtype(), y->dtype());
+  if (op == I_UPDATE) {
+    if (v.dtype() == DT_STRING) {
+      DoInplaceStringUpdateOp(device, i, v, y);
+      return Status::OK();
+    } else if (v.dtype() == DT_BOOL) {
+      DoInplaceOp<bool>(device, op, i, v, y);
+      return Status::OK();
+    }
+  }
+  switch (v.dtype()) {
+#define CASE(type)                          \
+  case DataTypeToEnum<type>::value:         \
+    DoInplaceOp<type>(device, op, i, v, y); \
+    break;
+    TF_CALL_NUMBER_TYPES(CASE);
+#undef CASE
+    default:
+      return errors::InvalidArgument("Unsupported data type: ", v.dtype());
+  }
+  return Status::OK();
+}
+
+}  // end namespace functor
+
+namespace {
+template <typename Device, functor::InplaceOpType op>
+class InplaceOp : public InplaceOpBase {
+ public:
+  explicit InplaceOp(OpKernelConstruction* ctx) : InplaceOpBase(ctx) {}
+
+ protected:
+  Status DoCompute(OpKernelContext* ctx, const Tensor& i, const Tensor& v,
+                   Tensor* y) override {
+    const auto& d = ctx->eigen_device<Device>();
+    return ::tensorflow::functor::DoInplace(d, op, i, v, y);
+  }
+};
+
+class CopyOpBase : public OpKernel {
+ public:
+  explicit CopyOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    auto x = ctx->input(0);
+    Tensor* y;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
+    OP_REQUIRES_OK(ctx, DoCompute(ctx, x, y));
+  }
+
+ protected:
+  virtual Status DoCompute(OpKernelContext* ctx, const Tensor& x,
+                           Tensor* y) = 0;
+};
+
+template <typename Device>
+class CopyOp : public CopyOpBase {
+ public:
+  explicit CopyOp(OpKernelConstruction* ctx) : CopyOpBase(ctx) {}
+
+ protected:
+  Status DoCompute(OpKernelContext* ctx, const Tensor& x, Tensor* y) override {
+    const auto& d = ctx->eigen_device<Device>();
+    return ::tensorflow::functor::DoCopy(d, x, y);
+  }
+};
+
+}  // end namespace
+
+namespace functor {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <>
+Status DoCopy(const CPUDevice& device, const Tensor& x, Tensor* y) {
+  CHECK_EQ(x.dtype(), y->dtype());
+  switch (x.dtype()) {
+#define CASE(type)                                   \
+  case DataTypeToEnum<type>::value:                  \
+    y->flat<type>().device(device) = x.flat<type>(); \
+    break;
+
+    TF_CALL_NUMBER_TYPES(CASE);
+    TF_CALL_bool(CASE);
+#undef CASE
+    default:
+      return errors::InvalidArgument("Unsupported data type: ", x.dtype());
+  }
+  return Status::OK();
+}
+
+}  // end namespace functor
+
+namespace {
+template <typename Device, typename T>
+class EmptyOp : public OpKernel {
+ public:
+  explicit EmptyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& shape = ctx->input(0);
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsVector(shape.shape()),
+        errors::InvalidArgument("shape must be a vector of int32, got shape ",
+                                shape.shape().DebugString()));
+    auto dims = shape.flat<int32>();
+    TensorShape out_shape;
+    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
+                            reinterpret_cast<const int32*>(dims.data()),
+                            dims.size(), &out_shape));
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
+
+    if (init_) {
+      functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
+                                           out->flat<T>());
+    }
+  }
+
+ private:
+  bool init_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("InplaceUpdate").Device(DEVICE_CPU),
+                        InplaceOp<CPUDevice, functor::I_UPDATE>);
+REGISTER_KERNEL_BUILDER(Name("InplaceAdd").Device(DEVICE_CPU),
+                        InplaceOp<CPUDevice, functor::I_ADD>);
+REGISTER_KERNEL_BUILDER(Name("InplaceSub").Device(DEVICE_CPU),
+                        InplaceOp<CPUDevice, functor::I_SUB>);
+REGISTER_KERNEL_BUILDER(Name("DeepCopy").Device(DEVICE_CPU), CopyOp<CPUDevice>);
+
+#define REGISTER_EMPTY(type, dev)                             \
+  REGISTER_KERNEL_BUILDER(Name("Empty")                       \
+                              .Device(DEVICE_##dev)           \
+                              .HostMemory("shape")            \
+                              .TypeConstraint<type>("dtype"), \
+                          EmptyOp<dev##Device, type>)
+
+REGISTER_EMPTY(float, CPU)
+REGISTER_EMPTY(double, CPU)
+REGISTER_EMPTY(Eigen::half, CPU)
+REGISTER_EMPTY(string, CPU)
+REGISTER_EMPTY(int32, CPU)
+REGISTER_EMPTY(int64, CPU)
+REGISTER_EMPTY(bool, CPU)
+
+#if GOOGLE_CUDA
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER(TYPE)                                                    \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+      InplaceOp<GPUDevice, functor::I_UPDATE>);                           \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),    \
+      InplaceOp<GPUDevice, functor::I_ADD>);                              \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("InplaceSub").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),    \
+      InplaceOp<GPUDevice, functor::I_SUB>);                              \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("DeepCopy").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),      \
+      CopyOp<GPUDevice>);
+
+REGISTER(float);
+REGISTER(double);
+REGISTER(Eigen::half);
+REGISTER(int64);
+
+REGISTER_KERNEL_BUILDER(Name("InplaceUpdate")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("x")
+                            .HostMemory("i")
+                            .HostMemory("v")
+                            .HostMemory("y")
+                            .TypeConstraint<int32>("T"),
+                        InplaceOp<CPUDevice, functor::I_UPDATE>);
+REGISTER_KERNEL_BUILDER(Name("InplaceAdd")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("x")
+                            .HostMemory("i")
+                            .HostMemory("v")
+                            .HostMemory("y")
+                            .TypeConstraint<int32>("T"),
+                        InplaceOp<CPUDevice, functor::I_ADD>);
+REGISTER_KERNEL_BUILDER(Name("InplaceSub")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("x")
+                            .HostMemory("i")
+                            .HostMemory("v")
+                            .HostMemory("y")
+                            .TypeConstraint<int32>("T"),
+                        InplaceOp<CPUDevice, functor::I_SUB>);
+
+REGISTER_KERNEL_BUILDER(Name("DeepCopy")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("x")
+                            .HostMemory("y")
+                            .TypeConstraint<int32>("T"),
+                        CopyOp<CPUDevice>);
+REGISTER_EMPTY(float, GPU);
+REGISTER_EMPTY(double, GPU);
+REGISTER_EMPTY(Eigen::half, GPU);
+REGISTER_EMPTY(int64, GPU);
+
+#endif  // GOOGLE_CUDA
+
 }  // end namespace
 }  // end namespace tensorflow
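
Both the CPU loop above and the CUDA kernel later in this change guard row
indices with `(i % nrows + nrows) % nrows`. A small Python sketch of the
wrap-around this implies (the double modulo mirrors C++, where `%` of a
negative operand can itself be negative):

    def guard_index(i, nrows):
        # Maps any integer row index, including negatives, into [0, nrows).
        return (i % nrows + nrows) % nrows

    assert guard_index(-1, 7) == 6  # inplace_update(x, [-1], v) hits the last row
    assert guard_index(8, 7) == 1   # out-of-range indices wrap rather than error
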
diff --git a/tensorflow/core/kernels/inplace_ops_functor.h b/tensorflow/core/kernels/inplace_ops_functor.h
index 53529f5..b806787 100644 (file)
@@ -26,6 +26,23 @@ template <typename Device>
 Status DoParallelConcat(const Device& device, const Tensor& value, int32 loc,
                         Tensor* output);
 
+// Inplace update/add/sub values in 'y'. It computes
+//   y[i, :] = v if op is I_UPDATE
+//   y[i, :] += v if op is I_ADD
+//   y[i, :] -= v if op is I_SUB
+// Returns an error if the operation fails.
+enum InplaceOpType {
+  I_UPDATE,  // y[i, :] = v
+  I_ADD,     // y[i, :] += v
+  I_SUB,     // y[i, :] -= v
+};
+template <typename Device>
+Status DoInplace(const Device& device, InplaceOpType op, const Tensor& i,
+                 const Tensor& v, Tensor* y);
+// Copies x into y.
+template <typename Device>
+Status DoCopy(const Device& device, const Tensor& x, Tensor* y);
+
 }  // end namespace functor
 }  // end namespace tensorflow
 
diff --git a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
index 8467360..f1616b1 100644 (file)
@@ -77,6 +77,103 @@ Status DoParallelConcat(const Device& d, const Tensor& value, int32 loc,
   return Status::OK();
 }
 
+template <typename T, InplaceOpType op>
+__global__ void DoInplaceOpKernel(int nthreads, const int64 rows,
+                                  const int64 cols, const int64 n, const T* src,
+                                  const int32* rowids, T* dst) {
+  CUDA_1D_KERNEL_LOOP(idx, nthreads) {
+    int64 r = idx / cols;
+    int64 c = idx % cols;
+    r = (rowids[r] % rows + rows) % rows;  // Guard index range.
+    T* p = dst + r * cols + c;
+    const T* q = src + idx;
+    switch (op) {
+      case I_UPDATE:
+        *p = ldg(q);
+        break;
+      case I_ADD:
+        *p += ldg(q);
+        break;
+      case I_SUB:
+        *p -= ldg(q);
+        break;
+    }
+  }
+}
+
+template <typename T>
+void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i,
+                 const Tensor& v, Tensor* y) {
+  const int64 nelem = v.NumElements();
+  CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
+  auto Ty = y->flat_outer_dims<T>();
+  const int64 nrows = Ty.dimension(0);
+  const int64 ncols = Ty.dimension(1);
+  const int64 n = i.NumElements();
+  const T* src = v.flat<T>().data();
+  // TODO(sjhwang): Check that first dimension fits in int32 range.
+  const int32* rowids = i.flat<int32>().data();
+  T* dst = y->flat<T>().data();
+  switch (op) {
+    case I_UPDATE:
+      DoInplaceOpKernel<T, I_UPDATE>
+          <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+              cfg.virtual_thread_count, nrows, ncols, n, src, rowids, dst);
+      break;
+    case I_ADD:
+      DoInplaceOpKernel<T, I_ADD>
+          <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+              cfg.virtual_thread_count, nrows, ncols, n, src, rowids, dst);
+      break;
+    case I_SUB:
+      DoInplaceOpKernel<T, I_SUB>
+          <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+              cfg.virtual_thread_count, nrows, ncols, n, src, rowids, dst);
+      break;
+  }
+}
+
+template <>
+Status DoInplace(const Device& d, InplaceOpType op, const Tensor& i,
+                 const Tensor& v, Tensor* y) {
+  CHECK_EQ(v.dtype(), y->dtype());
+  switch (v.dtype()) {
+#define CASE(type)                     \
+  case DataTypeToEnum<type>::value:    \
+    DoInplaceOp<type>(d, op, i, v, y); \
+    break;
+
+    CASE(float)
+    CASE(double)
+    CASE(Eigen::half)
+    CASE(int64)
+#undef CASE
+    default:
+      return errors::InvalidArgument("Unsupported data type: ", v.dtype());
+  }
+  return Status::OK();
+}
+
+template <>
+Status DoCopy(const Device& d, const Tensor& x, Tensor* y) {
+  CHECK_EQ(x.dtype(), y->dtype());
+  switch (x.dtype()) {
+#define CASE(type)                              \
+  case DataTypeToEnum<type>::value:             \
+    y->flat<type>().device(d) = x.flat<type>(); \
+    break;
+
+    CASE(float)
+    CASE(double)
+    CASE(Eigen::half)
+    CASE(int64)
+#undef CASE
+    default:
+      return errors::InvalidArgument("Unsupported dtype: ", x.dtype());
+  }
+  return Status::OK();
+}
+
 }  // end namespace functor
 }  // namespace tensorflow
 #endif  // GOOGLE_CUDA
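
For reference, a Python model of the per-element work that
`DoInplaceOpKernel` above distributes across CUDA threads (illustrative
only; `src` and `dst` stand for the flat `v` and `y` buffers):

    def do_inplace_model(src, rowids, dst, rows, cols):
        for idx in range(len(src)):                # one CUDA thread per element
            r, c = divmod(idx, cols)
            rr = (rowids[r] % rows + rows) % rows  # guard index range
            dst[rr * cols + c] = src[idx]          # I_UPDATE; += / -= for add / sub
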
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 4b119e2..2a8b9f9 100644 (file)
@@ -27,6 +27,7 @@ namespace tensorflow {
 using shape_inference::DimensionHandle;
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
 
 namespace {
 
@@ -341,6 +342,50 @@ REGISTER_OP("Pack")
       return Status::OK();
     });
 
+REGISTER_OP("DeepCopy")
+    .Input("x: T")
+    .Output("y: T")
+    .Attr("T: type")
+    .SetIsStateful()
+    .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("InplaceUpdate")
+    .Input("x: T")
+    .Input("i: int32")
+    .Input("v: T")
+    .Output("y: T")
+    .Attr("T: type")
+    .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("InplaceAdd")
+    .Input("x: T")
+    .Input("i: int32")
+    .Input("v: T")
+    .Output("y: T")
+    .Attr("T: type")
+    .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("InplaceSub")
+    .Input("x: T")
+    .Input("i: int32")
+    .Input("v: T")
+    .Output("y: T")
+    .Attr("T: type")
+    .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("Empty")
+    .Input("shape: int32")
+    .Output("output: dtype")
+    .Attr("dtype: type")
+    .Attr("init: bool = false")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle out;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+      c->set_output(0, out);
+      return Status::OK();
+    });
+
 // --------------------------------------------------------------------------
 REGISTER_OP("Unpack")
     .Input("value: T")
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index a8f1318..01962fc 100644 (file)
@@ -1616,7 +1616,10 @@ py_library(
 
 py_library(
     name = "array_ops",
-    srcs = ["ops/array_ops.py"],
+    srcs = [
+        "ops/array_ops.py",
+        "ops/inplace_ops.py",
+    ],
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops_gen",
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 6c34ea1..3033b48 100644 (file)
@@ -1191,6 +1191,22 @@ cuda_py_test(
 )
 
 cuda_py_test(
+    name = "inplace_ops_test",
+    size = "small",
+    srcs = ["inplace_ops_test.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+    ],
+    shard_count = 10,
+)
+
+cuda_py_test(
     name = "batch_matmul_op_test",
     size = "small",
     srcs = ["batch_matmul_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py
new file mode 100644 (file)
index 0000000..0f95e13
--- /dev/null
@@ -0,0 +1,198 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for inplace_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import inplace_ops
+from tensorflow.python.platform import test as test_lib
+
+
+class InplaceOpsTest(test_util.TensorFlowTestCase):
+
+  def testBasicUpdate(self):
+    for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]:
+      with self.test_session(use_gpu=True):
+        x = array_ops.ones([7, 3], dtype)
+        y = np.ones([7, 3], dtype.as_numpy_dtype)
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_update(x, [3], array_ops.ones([1, 3], dtype))
+        y[3, :] = 1
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_update(x, [-1],
+                                       array_ops.ones([1, 3], dtype) * 2)
+        y[-1, :] = 2
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_update(x, 5, array_ops.ones([3], dtype) * 7)
+        y[5, :] = 7
+        self.assertAllClose(x.eval(), y)
+
+  def testBasicUpdateBool(self):
+    with self.test_session(use_gpu=True):
+      x = array_ops.ones([7, 3], dtypes.bool)
+      y = np.ones([7, 3], dtypes.bool.as_numpy_dtype)
+      self.assertAllClose(x.eval(), y)
+      x = inplace_ops.inplace_update(x, [3], array_ops.ones([1, 3],
+                                                            dtypes.bool))
+      y[3, :] = True
+      self.assertAllClose(x.eval(), y)
+      x = inplace_ops.inplace_update(x, [-1],
+                                     array_ops.zeros([1, 3], dtypes.bool))
+      y[-1, :] = False
+      self.assertAllClose(x.eval(), y)
+      x = inplace_ops.inplace_update(x, 5, array_ops.zeros([3], dtypes.bool))
+      y[5, :] = False
+      self.assertAllClose(x.eval(), y)
+
+  def testBasicAdd(self):
+    for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]:
+      with self.test_session(use_gpu=True):
+        x = array_ops.ones([7, 3], dtype)
+        y = np.ones([7, 3], dtype.as_numpy_dtype)
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_add(x, [3], array_ops.ones([1, 3], dtype))
+        y[3, :] += 1
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_add(x, [-1], array_ops.ones([1, 3], dtype) * 2)
+        y[-1, :] += 2
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_add(x, 5, array_ops.ones([3], dtype) * 7)
+        y[5, :] += 7
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_add(x, None, array_ops.ones([7, 3], dtype) * 99)
+        y[:, :] += 99
+        self.assertAllClose(x.eval(), y)
+
+  def testBasicSub(self):
+    for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]:
+      with self.test_session(use_gpu=True):
+        x = array_ops.ones([7, 3], dtype)
+        y = np.ones([7, 3], dtype.as_numpy_dtype)
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_sub(x, [3], array_ops.ones([1, 3], dtype))
+        y[3, :] -= 1
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_sub(x, [-1], array_ops.ones([1, 3], dtype) * 2)
+        y[-1, :] -= 2
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_sub(x, 5, array_ops.ones([3], dtype) * 7)
+        y[5, :] -= 7
+        self.assertAllClose(x.eval(), y)
+        x = inplace_ops.inplace_sub(x, None, array_ops.ones([7, 3], dtype) * 99)
+        y[:, :] -= 99
+        self.assertAllClose(x.eval(), y)
+
+  def testRandom(self):
+    with self.test_session(use_gpu=True):
+      d0, d1, d2 = 100, 3, 5
+      x = array_ops.zeros([d0, d1, d2])
+      y = np.zeros([d0, d1, d2])
+      for _ in xrange(20):
+        idx = np.random.choice(d0, d0 // 10, replace=False)
+        val = np.random.randint(10, size=(d0 // 10, d1, d2))
+        op = np.random.randint(3)
+        if op == 0:
+          x = inplace_ops.inplace_update(x, idx, val)
+          y[idx, :] = val
+        elif op == 1:
+          x = inplace_ops.inplace_add(x, idx, val)
+          y[idx, :] += val
+        elif op == 2:
+          x = inplace_ops.inplace_sub(x, idx, val)
+          y[idx, :] -= val
+        self.assertAllClose(x.eval(), y)
+
+  def testRandom1D(self):
+    with self.test_session(use_gpu=True):
+      d0 = 100
+      x = array_ops.zeros([d0])
+      y = np.zeros([d0])
+      for _ in xrange(20):
+        idx = np.random.choice(d0, d0 // 10, replace=False)
+        val = np.random.randint(10, size=(d0 // 10))
+        op = np.random.randint(3)
+        if op == 0:
+          x = inplace_ops.inplace_update(x, idx, val)
+          y[idx] = val
+        elif op == 1:
+          x = inplace_ops.inplace_add(x, idx, val)
+          y[idx] += val
+        elif op == 2:
+          x = inplace_ops.inplace_sub(x, idx, val)
+          y[idx] -= val
+        self.assertAllClose(x.eval(), y)
+
+  def testAlias(self):
+    with self.test_session(use_gpu=True) as sess:
+      x = array_ops.ones([2, 3])
+      y = inplace_ops.alias_inplace_add(x, [0], [[1, 2, 3]])
+      with ops.control_dependencies([y]):
+        z = array_ops.identity(x)
+        _, vy, vz = sess.run([x, y, z])
+      self.assertAllClose(vy, vz)
+
+  def testError(self):
+    with self.test_session():
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   "must be a vector"):
+        _ = inplace_ops.inplace_update([[1.]], [[0]], [[10]]).eval()
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   "x and v shape doesn't match"):
+        _ = inplace_ops.inplace_update([[1.]], [0], [10]).eval()
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   "i and v shape doesn't match"):
+        _ = inplace_ops.inplace_update([[1.]], [0, 1], [[10]]).eval()
+
+  def testEmpty(self):
+    for dtype in [
+        dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool
+    ]:
+      with self.test_session(use_gpu=True):
+        test_shapes = [(), (1,), (2, 3), (0, 2), (2, 3, 5), (2, 0, 5)]
+        for shape in test_shapes:
+          val = inplace_ops.empty(shape, dtype).eval()
+          self.assertEqual(val.shape, shape)
+          self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+          val = inplace_ops.empty(shape, dtype, init=True).eval()
+          self.assertEqual(val.shape, shape)
+          self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+          self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype))
+          val = inplace_ops.empty_like(array_ops.zeros(shape, dtype)).eval()
+          self.assertEqual(val.shape, shape)
+          self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+          val = inplace_ops.empty_like(
+              array_ops.zeros(shape, dtype), init=True).eval()
+          self.assertEqual(val.shape, shape)
+          self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+          self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype))
+
+        val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval()
+        self.assertEqual(val.tolist(), [[b"", b""]])
+
+        val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval()
+        self.assertEqual(val.tolist(), [[b"", b""]])
+
+
+if __name__ == "__main__":
+  test_lib.main()
diff --git a/tensorflow/python/ops/inplace_ops.py b/tensorflow/python/ops/inplace_ops.py
new file mode 100644 (file)
index 0000000..e5b0000
--- /dev/null
@@ -0,0 +1,227 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inplace operations.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import math_ops
+
+
+def _inplace_helper(x, i, v, op):
+  """Applies an inplace op on (x, i, v).
+
+  op is one of gen_array_ops.inplace_update,
+  gen_array_ops.inplace_add, or gen_array_ops.inplace_sub.
+
+  If i is None, x and v must be the same shape. Computes
+    x op v;
+  If i is a scalar, x must have rank one higher than v. Computes
+    x[i, :] op v;
+  Otherwise, x and v must have the same rank. Computes
+    x[i, :] op v;
+
+  Args:
+    x: A Tensor.
+    i: None, a scalar or a vector.
+    v: A Tensor.
+    op: gen_array_ops.inplace_update, inplace_add, or inplace_sub.
+
+  Returns:
+    Returns x.
+
+  """
+  x = ops.convert_to_tensor(x)
+  v = ops.convert_to_tensor(v, x.dtype)
+  if i is None:
+    # Full tensor.
+    return array_ops.reshape(
+        op(array_ops.reshape(x, [1, -1]), [0], array_ops.reshape(v, [1, -1])),
+        array_ops.shape(x))
+  i = math_ops.to_int32(i)
+  if i.get_shape().ndims == 0:
+    # Single 0-dim update.
+    return op(x, array_ops.reshape(i, [1]), array_ops.expand_dims(v, 0))
+  return op(x, i, v)
+
+
+def alias_inplace_update(x, i, v):
+  """Applies an inplace update on input x at index i with value v. Aliases x.
+
+  If i is None, x and v must be the same shape. Computes
+    x = v;
+  If i is a scalar, x must have rank one higher than v. Computes
+    x[i, :] = v;
+  Otherwise, x and v must have the same rank. Computes
+    x[i, :] = v;
+
+  Args:
+    x: A Tensor.
+    i: None, a scalar or a vector.
+    v: A Tensor.
+
+  Returns:
+    Returns x.
+
+  """
+  return _inplace_helper(x, i, v, gen_array_ops.inplace_update)
+
+
+def alias_inplace_add(x, i, v):
+  """Applies an inplace add on input x at index i with value v. Aliases x.
+
+  If i is None, x and v must be the same shape. Computes
+    x += v;
+  If i is a scalar, x must have rank one higher than v. Computes
+    x[i, :] += v;
+  Otherwise, x and v must have the same rank. Computes
+    x[i, :] += v;
+
+  Args:
+    x: A Tensor.
+    i: None, a scalar or a vector.
+    v: A Tensor.
+
+  Returns:
+    Returns x.
+
+  """
+  return _inplace_helper(x, i, v, gen_array_ops.inplace_add)
+
+
+def alias_inplace_sub(x, i, v):
+  """Applies an inplace sub on input x at index i with value v. Aliases x.
+
+  If i is None, x and v must be the same shape. Computes
+    x -= v;
+  If i is a scalar, x must have rank one higher than v. Computes
+    x[i, :] -= v;
+  Otherwise, x and v must have the same rank. Computes
+    x[i, :] -= v;
+
+  Args:
+    x: A Tensor.
+    i: None, a scalar or a vector.
+    v: A Tensor.
+
+  Returns:
+    Returns x.
+
+  """
+  return _inplace_helper(x, i, v, gen_array_ops.inplace_sub)
+
+
+def empty_like(x, init=None):
+  """Returns a non-initialized tensor with the same shape and dtype as x.
+
+  Args:
+    x: A Tensor.
+    init: If True, initialize the returned tensor with the default
+      value of x.dtype. Otherwise, do not initialize. Defaults to
+      None.
+
+  Returns:
+    A tensor y, whose dtype and shape are the same as those of x.
+    y is guaranteed not to be an alias of x. Unless init is True, y may
+    contain arbitrary data upon return.
+
+  """
+  x = ops.convert_to_tensor(x)
+  return gen_array_ops.empty(array_ops.shape(x), x.dtype, init=init)
+
+
+def inplace_update(x, i, v):
+  """Applies an inplace update on input x at index i with value v.
+
+  Note that this function is not actually inplace; it allocates
+  a copy of x. Its utility is not in avoiding memory copies but in
+  expressing a sparse update.
+
+  If i is None, x and v must be the same shape. Computes
+    y = x; y = v;
+  If i is a scalar, x must have rank one higher than v. Computes
+    y = x; y[i, :] = v;
+  Otherwise, x and v must have the same rank. Computes
+    y = x; y[i, :] = v;
+
+  Args:
+    x: A Tensor.
+    i: None, a scalar or a vector.
+    v: A Tensor.
+
+  Returns:
+    Returns y, which is guaranteed not to be an alias of x.
+
+  """
+  return alias_inplace_update(gen_array_ops.deep_copy(x), i, v)
+
+
+def inplace_add(x, i, v):
+  """Applies an inplace add on input x at index i with value v.
+
+  Note that this function is not actually inplace; it allocates
+  a copy of x. Its utility is not in avoiding memory copies but in
+  expressing a sparse update.
+
+  If i is None, x and v must be the same shape. Computes
+    y = x; y += v;
+  If i is a scalar, x must have rank one higher than v. Computes
+    y = x; y[i, :] += v;
+  Otherwise, x and v must have the same rank. Computes
+    y = x; y[i, :] += v;
+
+  Args:
+    x: A Tensor.
+    i: None, a scalar or a vector.
+    v: A Tensor.
+
+  Returns:
+    Returns y, which is guaranteed not to be an alias of x.
+
+  """
+  return alias_inplace_add(gen_array_ops.deep_copy(x), i, v)
+
+
+def inplace_sub(x, i, v):
+  """Applies an inplace sub on input x at index i with value v.
+
+  Note that this function is not actually inplace; it allocates
+  a copy of x. Its utility is not in avoiding memory copies but in
+  expressing a sparse update.
+
+  If i is None, x and v must be the same shape. Computes
+    y = x; y -= v;
+  If i is a scalar, x must have rank one higher than v. Computes
+    y = x; y[i, :] -= v;
+  Otherwise, x and v must have the same rank. Computes
+    y = x; y[i, :] -= v;
+
+  Args:
+    x: A Tensor.
+    i: None, a scalar or a vector.
+    v: A Tensor.
+
+  Returns:
+    Returns y, which is guaranteed not to be an alias of x.
+
+  """
+  return alias_inplace_sub(gen_array_ops.deep_copy(x), i, v)
+
+empty = gen_array_ops.empty
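
Putting the new module together, a minimal end-to-end sketch against a build
that contains this change (TF 1.x session API):

    import tensorflow as tf
    from tensorflow.python.ops import inplace_ops

    with tf.Session() as sess:
      x = tf.zeros([4, 3])
      # Vector index: v has the same rank as x.
      y = inplace_ops.inplace_add(x, [1, 2], tf.ones([2, 3]))
      # Scalar index: v has rank one lower than x.
      z = inplace_ops.inplace_update(y, 0, [9., 9., 9.])
      print(sess.run(z))  # row 0 -> 9s, rows 1 and 2 -> 1s, row 3 -> 0s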