Extended scatter operations to work with a scalar update parameter and added scatter...
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Mar 2018 23:00:14 +0000 (16:00 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Sun, 25 Mar 2018 11:21:07 +0000 (04:21 -0700)
PiperOrigin-RevId: 190289664

34 files changed:
tensorflow/core/api_def/base_api/api_def_ResourceScatterAdd.pbtxt
tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_ScatterAdd.pbtxt
tensorflow/core/api_def/base_api/api_def_ScatterDiv.pbtxt
tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/base_api/api_def_ScatterMul.pbtxt
tensorflow/core/api_def/base_api/api_def_ScatterSub.pbtxt
tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt [new file with mode: 0644]
tensorflow/core/kernels/resource_variable_ops.cc
tensorflow/core/kernels/scatter_functor.cc
tensorflow/core/kernels/scatter_functor.h
tensorflow/core/kernels/scatter_functor_gpu.cu.cc
tensorflow/core/kernels/scatter_functor_gpu.cu.h
tensorflow/core/kernels/scatter_op.cc
tensorflow/core/kernels/scatter_op_gpu.cu.cc
tensorflow/core/kernels/scatter_op_test.cc
tensorflow/core/ops/resource_variable_ops.cc
tensorflow/core/ops/state_ops.cc
tensorflow/docs_src/api_guides/python/state_ops.md
tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/kernel_tests/scatter_ops_test.py
tensorflow/python/ops/standard_ops.py
tensorflow/python/ops/state_ops.py
tensorflow/tools/api/golden/tensorflow.pbtxt

index 9e0de08..4eb6eb4 100644 (file)
@@ -34,7 +34,7 @@ This operation computes
 Duplicate entries are handled correctly: if multiple `indices` reference
 the same location, their contributions add.
 
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
 
 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
 <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterDiv.pbtxt
new file mode 100644 (file)
index 0000000..47148f7
--- /dev/null
@@ -0,0 +1,43 @@
+op {
+  graph_op_name: "ResourceScatterDiv"
+  in_arg {
+    name: "resource"
+    description: <<END
+Should be from a `Variable` node.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+  }
+  in_arg {
+    name: "updates"
+    description: <<END
+A tensor of values to divide `ref` by.
+END
+  }
+  summary: "Divides sparse updates into the variable referenced by `resource`."
+  description: <<END
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] /= updates[...]
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] /= updates[i, ...]
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions divide.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMax.pbtxt
new file mode 100644 (file)
index 0000000..71f06d9
--- /dev/null
@@ -0,0 +1,43 @@
+op {
+  graph_op_name: "ResourceScatterMax"
+  in_arg {
+    name: "resource"
+    description: <<END
+Should be from a `Variable` node.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+  }
+  in_arg {
+    name: "updates"
+    description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+  }
+  summary: "Reduces sparse updates into the variable referenced by `resource` using the `max` operation."
+  description: <<END
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] = max(ref[indices, ...], updates[...])
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions are combined.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMin.pbtxt
new file mode 100644 (file)
index 0000000..08e40ee
--- /dev/null
@@ -0,0 +1,43 @@
+op {
+  graph_op_name: "ResourceScatterMin"
+  in_arg {
+    name: "resource"
+    description: <<END
+Should be from a `Variable` node.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+  }
+  in_arg {
+    name: "updates"
+    description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+  }
+  summary: "Reduces sparse updates into the variable referenced by `resource` using the `min` operation."
+  description: <<END
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] = min(ref[indices, ...], updates[...])
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions are combined.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterMul.pbtxt
new file mode 100644 (file)
index 0000000..5c63549
--- /dev/null
@@ -0,0 +1,43 @@
+op {
+  graph_op_name: "ResourceScatterMul"
+  in_arg {
+    name: "resource"
+    description: <<END
+Should be from a `Variable` node.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+  }
+  in_arg {
+    name: "updates"
+    description: <<END
+A tensor of values to multiply `ref` by.
+END
+  }
+  summary: "Multiplies sparse updates into the variable referenced by `resource`."
+  description: <<END
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] *= updates[...]
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] *= updates[i, ...]
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions multiply.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterSub.pbtxt
new file mode 100644 (file)
index 0000000..e71e60c
--- /dev/null
@@ -0,0 +1,43 @@
+op {
+  graph_op_name: "ResourceScatterSub"
+  in_arg {
+    name: "resource"
+    description: <<END
+Should be from a `Variable` node.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+  }
+  in_arg {
+    name: "updates"
+    description: <<END
+A tensor of updated values to subtract from `ref`.
+END
+  }
+  summary: "Subtracts sparse updates from the variable referenced by `resource`."
+  description: <<END
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] -= updates[...]
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] -= updates[i, ...]
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their (negated) contributions add.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+</div>
+END
+}
index 4b5201f..9da9d09 100644 (file)
@@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value.
 Duplicate entries are handled correctly: if multiple `indices` reference
 the same location, their contributions add.
 
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
 
 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
index 771cf0b..8e99718 100644 (file)
@@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value.
 Duplicate entries are handled correctly: if multiple `indices` reference
 the same location, their contributions divide.
 
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMax.pbtxt
new file mode 100644 (file)
index 0000000..7b52dad
--- /dev/null
@@ -0,0 +1,60 @@
+op {
+  graph_op_name: "ScatterMax"
+  in_arg {
+    name: "ref"
+    description: <<END
+Should be from a `Variable` node.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+  }
+  in_arg {
+    name: "updates"
+    description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+  }
+  out_arg {
+    name: "output_ref"
+    description: <<END
+Same as `ref`.  Returned as a convenience for operations that want
+to use the updated values after the update is done.
+END
+  }
+  attr {
+    name: "use_locking"
+    description: <<END
+If True, the update will be protected by a lock;
+otherwise the behavior is undefined, but may exhibit less contention.
+END
+  }
+  summary: "Reduces sparse updates into a variable reference using the `max` operation."
+  description: <<END
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] = max(ref[indices, ...], updates[...])
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the updated value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions combine.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
+</div>
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterMin.pbtxt
new file mode 100644 (file)
index 0000000..721ac0f
--- /dev/null
@@ -0,0 +1,60 @@
+op {
+  graph_op_name: "ScatterMin"
+  in_arg {
+    name: "ref"
+    description: <<END
+Should be from a `Variable` node.
+END
+  }
+  in_arg {
+    name: "indices"
+    description: <<END
+A tensor of indices into the first dimension of `ref`.
+END
+  }
+  in_arg {
+    name: "updates"
+    description: <<END
+A tensor of updated values to reduce into `ref`.
+END
+  }
+  out_arg {
+    name: "output_ref"
+    description: <<END
+Same as `ref`.  Returned as a convenience for operations that want
+to use the updated values after the update is done.
+END
+  }
+  attr {
+    name: "use_locking"
+    description: <<END
+If True, the update will be protected by a lock;
+otherwise the behavior is undefined, but may exhibit less contention.
+END
+  }
+  summary: "Reduces sparse updates into a variable reference using the `min` operation."
+  description: <<END
+This operation computes
+
+    # Scalar indices
+    ref[indices, ...] = min(ref[indices, ...], updates[...])
+
+    # Vector indices (for each i)
+    ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])
+
+    # High rank indices (for each i, ..., j)
+    ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], updates[i, ..., j, ...])
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the updated value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions combine.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" alt>
+</div>
+END
+}
index a51f571..b9e293b 100644 (file)
@@ -53,6 +53,6 @@ This makes it easier to chain operations that need to use the reset value.
 Duplicate entries are handled correctly: if multiple `indices` reference
 the same location, their contributions multiply.
 
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
 END
 }
index c0d3a4a..d12b3e6 100644 (file)
@@ -51,7 +51,7 @@ This makes it easier to chain operations that need to use the reset value.
 Duplicate entries are handled correctly: if multiple `indices` reference
 the same location, their (negated) contributions add.
 
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
 
 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterSub.png" alt>
index c44dbbd..4804908 100644 (file)
@@ -54,7 +54,7 @@ If values in `ref` is to be updated more than once, because there are
 duplicate entries in `indices`, the order at which the updates happen
 for each value is undefined.
 
-Requires `updates.shape = indices.shape + ref.shape[1:]`.
+Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`.
 
 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterDiv.pbtxt
new file mode 100644 (file)
index 0000000..56b5a46
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ResourceScatterDiv"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMax.pbtxt
new file mode 100644 (file)
index 0000000..8119bcc
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ResourceScatterMax"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMin.pbtxt
new file mode 100644 (file)
index 0000000..d874aef
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ResourceScatterMin"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterMul.pbtxt
new file mode 100644 (file)
index 0000000..365a37f
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ResourceScatterMul"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterSub.pbtxt
new file mode 100644 (file)
index 0000000..72dc5bf
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ResourceScatterSub"
+  visibility: HIDDEN
+}
index aecad01..e134e47 100644 (file)
@@ -619,22 +619,35 @@ class ResourceScatterUpdateOp : public OpKernel {
     if (N > 0) {
       auto indices_flat = indices.flat<Index>();
       auto params_flat = params->flat_outer_dims<T>();
-      int64 num_updates = updates.NumElements();
-      OP_REQUIRES(c, num_updates % N == 0,
-                  errors::InvalidArgument(
-                      "shape of indices (", indices.shape().DebugString(),
-                      ") is not compatible with the shape of updates (",
-                      updates.shape().DebugString(), ")"));
-      auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
-
-      functor::ScatterFunctor<Device, T, Index, op> functor;
-      const Index bad_i = functor(c, c->template eigen_device<Device>(),
-                                  params_flat, updates_flat, indices_flat);
-      OP_REQUIRES(c, bad_i < 0,
-                  errors::InvalidArgument(
-                      "indices", SliceDebugString(indices.shape(), bad_i),
-                      " = ", indices_flat(bad_i), " is not in [0, ",
-                      params->dim_size(0), ")"));
+      if (TensorShapeUtils::IsScalar(updates.shape())) {
+        const auto update = updates.scalar<T>();
+
+        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+        const Index bad_i = functor(c, c->template eigen_device<Device>(),
+                                    params_flat, update, indices_flat);
+        OP_REQUIRES(c, bad_i < 0,
+                    errors::InvalidArgument(
+                        "indices", SliceDebugString(indices.shape(), bad_i),
+                        " = ", indices_flat(bad_i), " is not in [0, ",
+                        params->dim_size(0), ")"));
+      } else {
+        int64 num_updates = updates.NumElements();
+        OP_REQUIRES(c, num_updates % N == 0,
+                    errors::InvalidArgument(
+                        "shape of indices (", indices.shape().DebugString(),
+                        ") is not compatible with the shape of updates (",
+                        updates.shape().DebugString(), ")"));
+        auto updates_flat = updates.shaped<T, 2>({N, num_updates / N});
+
+        functor::ScatterFunctor<Device, T, Index, op> functor;
+        const Index bad_i = functor(c, c->template eigen_device<Device>(),
+                                    params_flat, updates_flat, indices_flat);
+        OP_REQUIRES(c, bad_i < 0,
+                    errors::InvalidArgument(
+                        "indices", SliceDebugString(indices.shape(), bad_i),
+                        " = ", indices_flat(bad_i), " is not in [0, ",
+                        params->dim_size(0), ")"));
+      }
     }
   }
 };
@@ -652,35 +665,51 @@ class ResourceScatterUpdateOp : public OpKernel {
   REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
   REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
 
-// TODO(apassos) add the other types here.
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev)                \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev)                \
   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterAdd",    \
                           scatter_op::UpdateOp::ADD);         \
+  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterSub",    \
+                          scatter_op::UpdateOp::SUB);         \
+  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMul",    \
+                          scatter_op::UpdateOp::MUL);         \
+  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterDiv",    \
+                          scatter_op::UpdateOp::DIV);         \
   REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterUpdate", \
                           scatter_op::UpdateOp::ASSIGN);
+#define REGISTER_SCATTER_MINMAX(type, dev)                 \
+  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMin", \
+                          scatter_op::UpdateOp::MIN);      \
+  REGISTER_SCATTER_KERNEL(type, dev, "ResourceScatterMax", \
+                          scatter_op::UpdateOp::MAX);
 
 // Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
-  REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+  REGISTER_SCATTER_ARITHMETIC(type, CPU);
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
 
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
 
 REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
                         scatter_op::UpdateOp::ASSIGN);
 
 // Registers GPU kernels.
 #if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
-  REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+  REGISTER_SCATTER_ARITHMETIC(type, GPU);
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
 
 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
 
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
 
 #endif  // GOOGLE_CUDA
 
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
 #undef REGISTER_SCATTER_KERNEL
 #undef REGISTER_SCATTER_KERNEL_INDEX
 
index 7eba828..cf54081 100644 (file)
@@ -26,21 +26,30 @@ typedef Eigen::GpuDevice GPUDevice;
 namespace functor {
 
 // Forward declarations of the functor specializations for GPU.
-#define DECLARE_GPU_SPECS_OP(T, Index, op)                   \
-  template <>                                                \
-  Index ScatterFunctor<GPUDevice, T, Index, op>::operator()( \
-      OpKernelContext* c, const GPUDevice& d,                \
-      typename TTypes<T>::Matrix params,                     \
-      typename TTypes<T>::ConstMatrix updates,               \
-      typename TTypes<Index>::ConstFlat indices);            \
-  extern template struct ScatterFunctor<GPUDevice, T, Index, op>;
+#define DECLARE_GPU_SPECS_OP(T, Index, op)                         \
+  template <>                                                      \
+  Index ScatterFunctor<GPUDevice, T, Index, op>::operator()(       \
+      OpKernelContext* c, const GPUDevice& d,                      \
+      typename TTypes<T>::Matrix params,                           \
+      typename TTypes<T>::ConstMatrix updates,                     \
+      typename TTypes<Index>::ConstFlat indices);                  \
+  extern template struct ScatterFunctor<GPUDevice, T, Index, op>;  \
+  template <>                                                      \
+  Index ScatterScalarFunctor<GPUDevice, T, Index, op>::operator()( \
+      OpKernelContext* c, const GPUDevice& d,                      \
+      typename TTypes<T>::Matrix params,                           \
+      const typename TTypes<T>::ConstScalar update,                \
+      typename TTypes<Index>::ConstFlat indices);                  \
+  extern template struct ScatterScalarFunctor<GPUDevice, T, Index, op>;
 
 #define DECLARE_GPU_SPECS_INDEX(T, Index)                       \
   DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
   DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD);    \
   DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);    \
   DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL);    \
-  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);    \
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN);    \
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
 
 #define DECLARE_GPU_SPECS(T)         \
   DECLARE_GPU_SPECS_INDEX(T, int32); \
index 079f15e..5266664 100644 (file)
@@ -18,6 +18,8 @@ limitations under the License.
 
 #include <type_traits>
 
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/platform/types.h"
@@ -33,7 +35,7 @@ typedef Eigen::SyclDevice SYCLDevice;
 
 namespace scatter_op {
 
-enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
+enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };
 
 namespace internal {
 
@@ -45,6 +47,10 @@ struct Assign<scatter_op::UpdateOp::ASSIGN> {
   static void Run(Params p, Update u) {
     p = u;
   }
+  template <typename Params, typename Update>
+  static void RunScalar(Params p, Update u) {
+    p.setConstant(u);
+  }
 };
 template <>
 struct Assign<scatter_op::UpdateOp::ADD> {
@@ -52,6 +58,10 @@ struct Assign<scatter_op::UpdateOp::ADD> {
   static void Run(Params p, Update u) {
     p += u;
   }
+  template <typename Params, typename Update>
+  static void RunScalar(Params p, Update u) {
+    p = p + u;
+  }
 };
 template <>
 struct Assign<scatter_op::UpdateOp::SUB> {
@@ -59,6 +69,10 @@ struct Assign<scatter_op::UpdateOp::SUB> {
   static void Run(Params p, Update u) {
     p -= u;
   }
+  template <typename Params, typename Update>
+  static void RunScalar(Params p, Update u) {
+    p = p + static_cast<Update>(-u);
+  }
 };
 template <>
 struct Assign<scatter_op::UpdateOp::MUL> {
@@ -66,6 +80,10 @@ struct Assign<scatter_op::UpdateOp::MUL> {
   static void Run(Params p, Update u) {
     p *= u;
   }
+  template <typename Params, typename Update>
+  static void RunScalar(Params p, Update u) {
+    p = p * u;
+  }
 };
 template <>
 struct Assign<scatter_op::UpdateOp::DIV> {
@@ -73,6 +91,34 @@ struct Assign<scatter_op::UpdateOp::DIV> {
   static void Run(Params p, Update u) {
     p /= u;
   }
+  template <typename Params, typename Update>
+  static void RunScalar(Params p, Update u) {
+    p = p / u;
+  }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::MIN> {
+  // This method requires that Params and Update are tensor types.
+  template <typename Params, typename Update>
+  static void Run(Params p, Update u) {
+    p = p.cwiseMin(u);
+  }
+  // Same thing, but for Update being a scalar type.
+  template <typename Params, typename Update>
+  static void RunScalar(Params p, Update u) {
+    p = p.cwiseMin(u);
+  }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::MAX> {
+  template <typename Params, typename Update>
+  static void Run(Params p, Update u) {
+    p = p.cwiseMax(u);
+  }
+  template <typename Params, typename Update>
+  static void RunScalar(Params p, Update u) {
+    p = p.cwiseMax(u);
+  }
 };
 
 #ifdef TENSORFLOW_USE_SYCL
@@ -117,6 +163,22 @@ struct AssignSYCL<scatter_op::UpdateOp::DIV> {
     p.device(d) = p / u;
   }
 };
+
+template <>
+struct AssignSYCL<scatter_op::UpdateOp::MIN> {
+  template <typename Device, typename Params, typename Update>
+  static void Run(Device d, Params p, Update u) {
+    p.device(d) = p.cwiseMin(u);
+  }
+};
+
+template <>
+struct AssignSYCL<scatter_op::UpdateOp::MAX> {
+  template <typename Device, typename Params, typename Update>
+  static void Run(Device d, Params p, Update u) {
+    p.device(d) = p.cwiseMax(u);
+  }
+};
 #endif  // TENSORFLOW_USE_SYCL
 
 }  // namespace internal
@@ -241,6 +303,112 @@ struct ScatterFunctorSYCL {
 };
 #endif  // TENSORFLOW_USE_SYCL
 
+// Functor that applies a single scalar `update` with operation `op` to the
+// rows of `params` selected by `indices`.  Only the call signature is
+// declared here; the definitions come from per-device specializations.
+// Judging by the CPU implementation in this header, the return value is -1
+// on success or the position in `indices` of the first out-of-range entry.
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor {
+  Index operator()(OpKernelContext* c, const Device& d,
+                   typename TTypes<T>::Matrix params,
+                   const typename TTypes<T>::ConstScalar update,
+                   typename TTypes<Index>::ConstFlat indices);
+};
+
+// Generic implementation of the scalar scatter: walks `indices` in order and
+// combines the scalar `update` into row params[index] via
+// Assign<op>::RunScalar.
+//
+// Returns -1 when every index passes the bounds check against
+// params.dimension(0); otherwise returns the position in `indices` of the
+// first failing entry.  Rows visited before that entry have already been
+// updated.  `c` and `d` are not used by this implementation.
+template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorBase {
+  Index operator()(OpKernelContext* c, const Device& d,
+                   typename TTypes<T>::Matrix params,
+                   const typename TTypes<T>::ConstScalar update,
+                   typename TTypes<Index>::ConstFlat indices) {
+    // indices and params sizes were validated in DoCompute().
+    const Index N = static_cast<Index>(indices.size());
+    const Index limit = static_cast<Index>(params.dimension(0));
+    for (Index i = 0; i < N; i++) {
+      // Grab the index and check its validity.  An earlier version of the
+      // code checked it and then grabbed it from memory a second time, which
+      // was a security risk since it could have changed in between.
+      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Broadcast update to params[index]
+      scatter_op::internal::Assign<op>::RunScalar(
+          params.template chip<0>(index), update());
+    }
+    return -1;
+  }
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+// SYCL specialization of ScatterScalarFunctorBase: for every entry of
+// `indices`, combines the scalar `update` into row params[index] using
+// AssignSYCL<op> on device `d`.
+//
+// Returns -1 when every index is in [0, params.dimension(0)); otherwise
+// returns the position in `indices` of the first out-of-range entry (rows
+// handled before it have already been modified).
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
+  Index operator()(OpKernelContext* c, const SYCLDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   const typename TTypes<T>::ConstScalar update,
+                   typename TTypes<Index>::ConstFlat indices) {
+    // indices and params sizes were validated in DoCompute().
+    const Index N = static_cast<Index>(indices.size());
+    const Index limit = static_cast<Index>(params.dimension(0));
+    for (Index i = 0; i < N; i++) {
+      // Grab the index and check its validity.  An earlier version of the
+      // code checked it and then grabbed it from memory a second time, which
+      // was a security risk since it could have changed in between.
+      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Broadcast update to params[index].  AssignSYCL only provides Run()
+      // (it has no RunScalar member), and it must be handed the scalar value
+      // update() rather than the rank-0 tensor map itself.
+      scatter_op::internal::AssignSYCL<op>::Run(
+          d, params.template chip<0>(index), update());
+    }
+    return -1;
+  }
+};
+#endif  // TENSORFLOW_USE_SYCL
+
+// CPU specialization of ScatterScalarFunctorBase for UpdateOp::ASSIGN:
+// writes the scalar `update` into every row of `params` named by `indices`.
+// Returns -1 on success, or the position of the first out-of-range index.
+//
+// NOTE(review): apart from spelling out ASSIGN, this body matches the
+// primary ScatterScalarFunctorBase template statement for statement;
+// confirm whether the specialization is still needed.
+template <typename T, typename Index>
+struct ScatterScalarFunctorBase<CPUDevice, T, Index,
+                                scatter_op::UpdateOp::ASSIGN> {
+  Index operator()(OpKernelContext* c, const CPUDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   const typename TTypes<T>::ConstScalar update,
+                   typename TTypes<Index>::ConstFlat indices) {
+    // indices and params sizes were validated in DoCompute().
+    const Index N = static_cast<Index>(indices.size());
+    const Index limit = static_cast<Index>(params.dimension(0));
+    for (Index i = 0; i < N; i++) {
+      // Grab the index and check its validity.  An earlier version of the
+      // code checked it and then grabbed it from memory a second time, which
+      // was a security risk since it could have changed in between.
+      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Broadcast update to params[index]
+      scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
+          params.template chip<0>(index), update());
+    }
+    return -1;
+  }
+};
+
+// CPU specialization of ScatterScalarFunctor: adds nothing of its own and
+// inherits operator() from ScatterScalarFunctorBase<CPUDevice, ...>.
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor<CPUDevice, T, Index, op>
+    : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};
+
+#ifdef TENSORFLOW_USE_SYCL
+// SYCL functor for scatter with a scalar update: for each entry of
+// `indices`, applies AssignSYCL<op>::Run on device `d` to row params[index]
+// with the scalar value update().
+//
+// Returns -1 on success, or the position in `indices` of the first entry
+// that fails the bounds check against params.dimension(0).
+// Note that `indices` is taken as a mutable Flat here, unlike the ConstFlat
+// used by the other scalar functors in this header.
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctorSYCL {
+  Index operator()(OpKernelContext* c, const SYCLDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   const typename TTypes<T>::ConstScalar update,
+                   typename TTypes<Index>::Flat indices) {
+    // indices and params sizes were validated in DoCompute().
+    const Index N = static_cast<Index>(indices.size());
+    const Index limit = static_cast<Index>(params.dimension(0));
+    for (Index i = 0; i < N; i++) {
+      // Copy the index once so the value that is checked is the value used.
+      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+      if (!FastBoundsCheck(index, limit)) return i;
+      // Broadcast update to params[index]
+      scatter_op::internal::AssignSYCL<op>::Run(
+          d, params.template chip<0>(index), update());
+    }
+    return -1;
+  }
+};
+#endif  // TENSORFLOW_USE_SYCL
+
 }  // namespace functor
 }  // namespace tensorflow
 
index 5297299..59911bf 100644 (file)
@@ -23,15 +23,18 @@ namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \
-  template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+#define DEFINE_GPU_SPECS_OP(T, Index, op)                           \
+  template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
+  template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;
 
 #define DEFINE_GPU_SPECS_INDEX(T, Index)                       \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD);    \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);    \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL);    \
-  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);    \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN);    \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
 
 #define DEFINE_GPU_SPECS(T)         \
   DEFINE_GPU_SPECS_INDEX(T, int32); \
index be18658..70809e4 100644 (file)
@@ -29,12 +29,53 @@ namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
 
+namespace scatter_op_gpu {
+
+// Per-element update step used by the scatter kernels.  The primary template
+// is only declared; each UpdateOp value gets its own specialization whose
+// operator() combines `src` into `*dest`.
+template <typename T, scatter_op::UpdateOp op>
+struct ScatterOpKernelBody;
+
+// ASSIGN: plain store of `src` into `*dest`.
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> {
+  __device__ void operator()(T* dest, T src) const { *dest = src; }
+};
+
+// ADD: delegates to CudaAtomicAdd.
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> {
+  __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); }
+};
+
+// SUB: delegates to CudaAtomicSub.
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> {
+  __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); }
+};
+
+// MUL: delegates to CudaAtomicMul.
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> {
+  __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); }
+};
+
+// DIV: delegates to CudaAtomicDiv.
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> {
+  __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); }
+};
+
+// MIN: delegates to CudaAtomicMin.
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> {
+  __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); }
+};
+
+// MAX: delegates to CudaAtomicMax.
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> {
+  __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); }
+};
+
 template <typename T, typename Index, scatter_op::UpdateOp op>
 __global__ void ScatterOpCustomKernel(T* params, const T* updates,
                                       const Index* indices,
                                       Index first_dim_size, Index updates_size,
                                       Index indices_size) {
   Index update_block = updates_size / indices_size;
+  ScatterOpKernelBody<T, op> body;
   CUDA_1D_KERNEL_LOOP(i, updates_size) {
     int indices_i = i / update_block;
     int updates_i = i;
@@ -44,31 +85,33 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates,
       continue;
     }
     int params_i = param_first_index * update_block + (i % update_block);
-    switch (op) {
-      case scatter_op::UpdateOp::ASSIGN: {
-        params[params_i] = ldg(updates + updates_i);
-        break;
-      }
-      case scatter_op::UpdateOp::ADD: {
-        CudaAtomicAdd(params + params_i, ldg(updates + updates_i));
-        break;
-      }
-      case scatter_op::UpdateOp::SUB: {
-        CudaAtomicSub(params + params_i, ldg(updates + updates_i));
-        break;
-      }
-      case scatter_op::UpdateOp::MUL: {
-        CudaAtomicMul(params + params_i, ldg(updates + updates_i));
-        break;
-      }
-      case scatter_op::UpdateOp::DIV: {
-        CudaAtomicDiv(params + params_i, ldg(updates + updates_i));
-        break;
-      }
+    body(&params[params_i], ldg(updates + updates_i));
+  }
+}
+
+// Kernel for scatter with a single scalar update value.
+//
+// Conceptually the scalar `*update` is broadcast to a virtual updates tensor
+// of `synthesized_updates_size` elements (indices_size rows of `update_block`
+// elements each).  Element i of that virtual tensor is combined, via
+// ScatterOpKernelBody<T, op>, into row indices[i / update_block] of `params`
+// at column i % update_block.  Entries of `indices` outside
+// [0, first_dim_size) are skipped silently.
+template <typename T, typename Index, scatter_op::UpdateOp op>
+__global__ void ScatterScalarOpCustomKernel(T* params, const T* update,
+                                            const Index* indices,
+                                            Index first_dim_size,
+                                            Index indices_size,
+                                            Index synthesized_updates_size) {
+  Index update_block = synthesized_updates_size / indices_size;
+  ScatterOpKernelBody<T, op> body;
+  // The update value does not depend on the loop variable, so load it once.
+  const T update_val = *update;
+  CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) {
+    const Index indices_i = i / update_block;
+    // Keep the index in its own type: narrowing a 64-bit index to int before
+    // the range check could wrap an out-of-range value into the valid range.
+    const Index param_first_index = indices[indices_i];
+    if (!(param_first_index >= 0 && param_first_index < first_dim_size)) {
+      // Ignore indices that are out of range.
+      continue;
     }
+    // Offsets into params can exceed the int range even when the number of
+    // updates does not, so compute the offset in Index as well.
+    const Index params_i =
+        param_first_index * update_block + (i % update_block);
+    body(&params[params_i], update_val);
   }
 }
 
+}  // namespace scatter_op_gpu
+
 namespace functor {
 // Specialization for a GPU device.
 template <typename T, typename Index, scatter_op::UpdateOp op>
@@ -84,7 +127,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
     const Index indices_size = indices.size();
     const Index updates_size = updates.size();
     CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d);
-    ScatterOpCustomKernel<T, Index, op>
+    scatter_op_gpu::ScatterOpCustomKernel<T, Index, op>
         <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
             params.data(), updates.data(), indices.data(), first_dim_size,
             updates_size, indices_size);
@@ -92,6 +135,27 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
   }
 };
 
+// GPU specialization of ScatterScalarFunctor: launches
+// ScatterScalarOpCustomKernel on the device stream with one work item per
+// (index, column) pair, i.e. indices.size() * params.dimension(1) items.
+//
+// Always returns -1: no index range check result is reported back to the
+// host (see the TODO below).
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
+  Index operator()(OpKernelContext* c, const GPUDevice& d,
+                   typename TTypes<T>::Matrix params,
+                   const typename TTypes<T>::ConstScalar update,
+                   typename TTypes<Index>::ConstFlat indices) {
+    // TODO(b/31801742): Implement indices range check. The hardest part is
+    // with returning a value after the range check, as we do not want to do
+    // device to host memcpy during a stream.
+    const Index first_dim_size = params.dimension(0);
+    const Index indices_size = indices.size();
+    // Size of the updates tensor that the scalar stands in for.
+    const Index synthesized_updates_size = indices_size * params.dimension(1);
+    CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d);
+    scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op>
+        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+            params.data(), update.data(), indices.data(), first_dim_size,
+            indices_size, synthesized_updates_size);
+    return -1;
+  }
+};
+
 }  // namespace functor
 }  // namespace tensorflow
 
index 2821653..0fbde76 100644 (file)
@@ -38,6 +38,7 @@ typedef Eigen::SyclDevice SYCLDevice;
 // Check updates.shape = [] or updates.shape = indices.shape + params.shape[1:]
 static bool ValidShapes(const Tensor& params, const Tensor& updates,
                         const Tensor& indices) {
+  if (updates.dims() == 0) return true;
   if (updates.dims() != indices.dims() + params.dims() - 1) return false;
   for (int d = 0; d < indices.dims(); d++) {
     if (updates.dim_size(d) != indices.dim_size(d)) {
@@ -61,11 +62,11 @@ static void DoValidationChecking(OpKernelContext* c, const Tensor& params,
                                       params.shape().DebugString()));
   OP_REQUIRES(
       c, ValidShapes(params, updates, indices),
-      errors::InvalidArgument(
-          "Must have updates.shape = indices.shape + params.shape[1:], got ",
-          "updates.shape ", updates.shape().DebugString(), ", indices.shape ",
-          indices.shape().DebugString(), ", params.shape ",
-          params.shape().DebugString()));
+      errors::InvalidArgument("Must have updates.shape = indices.shape + "
+                              "params.shape[1:] or updates.shape = [], got ",
+                              "updates.shape ", updates.shape().DebugString(),
+                              ", indices.shape ", indices.shape().DebugString(),
+                              ", params.shape ", params.shape().DebugString()));
 }
 
 template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
@@ -122,16 +123,31 @@ class ScatterUpdateOp : public OpKernel {
     if (N > 0) {
       auto indices_flat = indices.flat<Index>();
       auto params_flat = params.flat_outer_dims<T>();
-      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
-      functor::ScatterFunctor<Device, T, Index, op> functor;
-      const Index bad_i = functor(c, c->template eigen_device<Device>(),
-                                  params_flat, updates_flat, indices_flat);
-      OP_REQUIRES(
-          c, bad_i < 0,
-          errors::InvalidArgument(
-              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+      if (TensorShapeUtils::IsScalar(updates.shape()) ||
+          IsLegacyScalar(updates.shape())) {
+        const auto update = updates.scalar<T>();
+        functor::ScatterScalarFunctor<Device, T, Index, op> functor;
+        const Index bad_i = functor(c, c->template eigen_device<Device>(),
+                                    params_flat, update, indices_flat);
+        OP_REQUIRES(c, bad_i < 0,
+                    errors::InvalidArgument(
+                        "indices", SliceDebugString(indices.shape(), bad_i),
+                        " = ", indices_flat(bad_i), " is not in [0, ",
+                        params.dim_size(0), ")"));
+      } else {
+        auto updates_flat =
+            updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+        functor::ScatterFunctor<Device, T, Index, op> functor;
+        const Index bad_i = functor(c, c->template eigen_device<Device>(),
+                                    params_flat, updates_flat, indices_flat);
+        OP_REQUIRES(c, bad_i < 0,
+                    errors::InvalidArgument(
+                        "indices", SliceDebugString(indices.shape(), bad_i),
+                        " = ", indices_flat(bad_i), " is not in [0, ",
+                        params.dim_size(0), ")"));
+      }
     }
   }
 };
@@ -195,16 +211,31 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
 
       auto indices_flat = indices_host.flat<Index>();
       auto params_flat = params.flat_outer_dims<T>();
-      auto updates_flat = updates.shaped<T, 2>({N, updates.NumElements() / N});
-
-      functor::ScatterFunctorSYCL<T, Index, op> functor;
-      const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
-                                  params_flat, updates_flat, indices_flat);
-      OP_REQUIRES(
-          c, bad_i < 0,
-          errors::InvalidArgument(
-              "indices", SliceDebugString(indices.shape(), bad_i), " = ",
-              indices_flat(bad_i), " is not in [0, ", params.dim_size(0), ")"));
+
+      if (TensorShapeUtils::IsScalar(updates.shape())) {
+        const auto update = updates.scalar<T>();
+
+        functor::ScatterScalarFunctorSYCL<T, Index, op> functor;
+        const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+                                    params_flat, update, indices_flat);
+        OP_REQUIRES(c, bad_i < 0,
+                    errors::InvalidArgument(
+                        "indices", SliceDebugString(indices.shape(), bad_i),
+                        " = ", indices_flat(bad_i), " is not in [0, ",
+                        params.dim_size(0), ")"));
+      } else {
+        auto updates_flat =
+            updates.shaped<T, 2>({N, updates.NumElements() / N});
+
+        functor::ScatterFunctorSYCL<T, Index, op> functor;
+        const Index bad_i = functor(c, c->template eigen_device<SYCLDevice>(),
+                                    params_flat, updates_flat, indices_flat);
+        OP_REQUIRES(c, bad_i < 0,
+                    errors::InvalidArgument(
+                        "indices", SliceDebugString(indices.shape(), bad_i),
+                        " = ", indices_flat(bad_i), " is not in [0, ",
+                        params.dim_size(0), ")"));
+      }
     }
   }
 };
@@ -221,54 +252,71 @@ class ScatterUpdateOp<SYCLDevice, T, Index, op> : public OpKernel {
   REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
   REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
 
-#define REGISTER_SCATTER_ARITHEMTIC(type, dev)                                 \
+#define REGISTER_SCATTER_ARITHMETIC(type, dev)                                 \
   REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
   REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
   REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
   REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
 
+#define REGISTER_SCATTER_MINMAX(type, dev)                                     \
+  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMin", scatter_op::UpdateOp::MIN); \
+  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMax", scatter_op::UpdateOp::MAX);
+
 #define REGISTER_SCATTER_UPDATE(type, dev)            \
   REGISTER_SCATTER_KERNEL(type, dev, "ScatterUpdate", \
                           scatter_op::UpdateOp::ASSIGN);
 
 // Registers CPU kernels.
-#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
-  REGISTER_SCATTER_ARITHEMTIC(type, CPU);
+#define REGISTER_SCATTER_ARITHMETIC_CPU(type) \
+  REGISTER_SCATTER_ARITHMETIC(type, CPU);
+
+#define REGISTER_SCATTER_MINMAX_CPU(type) REGISTER_SCATTER_MINMAX(type, CPU);
 
 #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
 
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_CPU);
 TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
 
 // Registers GPU kernels.
 #if GOOGLE_CUDA
-#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
-  REGISTER_SCATTER_ARITHEMTIC(type, GPU);
+#define REGISTER_SCATTER_ARITHMETIC_GPU(type) \
+  REGISTER_SCATTER_ARITHMETIC(type, GPU);
+
+#define REGISTER_SCATTER_MINMAX_GPU(type) REGISTER_SCATTER_MINMAX(type, GPU);
 
 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
 
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
 
 #endif  // GOOGLE_CUDA
 
 // Registers SYCL kernels.
 #if TENSORFLOW_USE_SYCL
-#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \
-  REGISTER_SCATTER_ARITHEMTIC(type, SYCL);
+#define REGISTER_SCATTER_ARITHMETIC_SYCL(type) \
+  REGISTER_SCATTER_ARITHMETIC(type, SYCL);
+
+#define REGISTER_SCATTER_MINMAX_SYCL(type) REGISTER_SCATTER_MINMAX(type, SYCL);
 
 #define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);
 
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_SYCL);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_SYCL);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL);
 
-#undef REGISTER_SCATTER_ARITHEMTIC_SYCL
+#undef REGISTER_SCATTER_ARITHMETIC_SYCL
+#undef REGISTER_SCATTER_MINMAX_SYCL
 #undef REGISTER_SCATTER_UPDATE_SYCL
 #endif  // TENSORFLOW_USE_SYCL
 
-#undef REGISTER_SCATTER_ARITHEMTIC
-#undef REGISTER_SCATTER_ARITHEMTIC_CPU
-#undef REGISTER_SCATTER_ARITHEMTIC_GPU
+#undef REGISTER_SCATTER_ARITHMETIC
+#undef REGISTER_SCATTER_ARITHMETIC_CPU
+#undef REGISTER_SCATTER_ARITHMETIC_GPU
+#undef REGISTER_SCATTER_MINMAX
+#undef REGISTER_SCATTER_MINMAX_CPU
+#undef REGISTER_SCATTER_MINMAX_GPU
 #undef REGISTER_SCATTER_UPDATE
 #undef REGISTER_SCATTER_UPDATE_CPU
 #undef REGISTER_SCATTER_UPDATE_GPU
index 0b43704..0df3293 100644 (file)
@@ -24,15 +24,18 @@ namespace tensorflow {
 typedef Eigen::GpuDevice GPUDevice;
 
 // Instantiates functor specializations for GPU.
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \
-  template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+#define DEFINE_GPU_SPECS_OP(T, Index, op)                           \
+  template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
+  template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;
 
 #define DEFINE_GPU_SPECS_INDEX(T, Index)                       \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD);    \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);    \
   DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL);    \
-  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);    \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN);    \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
 
 #define DEFINE_GPU_SPECS(T)         \
   DEFINE_GPU_SPECS_INDEX(T, int32); \
index 0b8645a..5b3537b 100644 (file)
@@ -185,7 +185,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
   Status s = RunOpKernel();
   EXPECT_TRUE(StringPiece(s.ToString())
                   .contains("Must have updates.shape = indices.shape + "
-                            "params.shape[1:], got "))
+                            "params.shape[1:] or updates.shape = [], got "))
       << s;
 }
 
@@ -202,7 +202,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
   Status s = RunOpKernel();
   EXPECT_TRUE(StringPiece(s.ToString())
                   .contains("Must have updates.shape = indices.shape + "
-                            "params.shape[1:], got "))
+                            "params.shape[1:] or updates.shape = [], got "))
 
       << s;
 }
@@ -219,7 +219,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
   Status s = RunOpKernel();
   EXPECT_TRUE(StringPiece(s.ToString())
                   .contains("Must have updates.shape = indices.shape + "
-                            "params.shape[1:], got "))
+                            "params.shape[1:] or updates.shape = [], got "))
       << s;
 }
 
@@ -300,6 +300,20 @@ static void BM_ScatterDivInt64(int iters, int embedding_size) {
   BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv");
 }
 
+// Benchmark entry points for the "ScatterMin" and "ScatterMax" ops.  Each one
+// forwards `iters` and `embedding_size` unchanged to BM_ScatterHelper,
+// instantiated with int32 or int64 as the function name indicates.
+static void BM_ScatterMinInt32(int iters, int embedding_size) {
+  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMin");
+}
+static void BM_ScatterMinInt64(int iters, int embedding_size) {
+  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMin");
+}
+
+static void BM_ScatterMaxInt32(int iters, int embedding_size) {
+  BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMax");
+}
+static void BM_ScatterMaxInt64(int iters, int embedding_size) {
+  BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMax");
+}
+
 BENCHMARK(BM_ScatterUpdateInt32)
     ->Arg(1)
     ->Arg(10)
@@ -332,5 +346,11 @@ BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
 BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
 BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
 
+BENCHMARK(BM_ScatterMinInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMinInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+BENCHMARK(BM_ScatterMaxInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMaxInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
 }  // namespace
 }  // namespace tensorflow
index 0d8cf78..3d0a6c2 100644 (file)
@@ -167,27 +167,75 @@ REGISTER_OP("ResourceGather")
       return Status::OK();
     });
 
+namespace {
+
+// Shared shape function for the ResourceScatter* ops.
+//
+// Validates the variable handle, then checks the `updates` input (input 2):
+// a rank-0 `updates` is accepted as is; otherwise it must merge with
+// indices.shape + var.shape[1:].  No output shape is set.  Returns the
+// first error from any of the helper calls, or OK.
+Status ResourceScatterUpdateShape(InferenceContext* c) {
+  ShapeAndType handle_shape_and_type;
+  TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type));
+  ShapeHandle var_shape = handle_shape_and_type.shape;
+  ShapeHandle indices_shape = c->input(1);
+
+  ShapeHandle unused_updates_shape;
+  ShapeHandle concat;
+  ShapeHandle var_subshape;
+  // Expected non-scalar updates shape: indices.shape + var.shape[1:].
+  TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
+  TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
+  // Only a rank of exactly 0 skips the merge; any other rank goes through
+  // Merge against the expected shape.
+  TF_RETURN_IF_ERROR(
+      InferenceContext::Rank(c->input(2)) == 0
+          ? Status::OK()
+          : c->Merge(c->input(2), concat, &unused_updates_shape));
+  return Status::OK();
+}
+
+}  // namespace
+
 REGISTER_OP("ResourceScatterAdd")
     .Input("resource: resource")
     .Input("indices: Tindices")
     .Input("updates: dtype")
     .Attr("dtype: numbertype")
     .Attr("Tindices: {int32, int64}")
-    .SetShapeFn([](InferenceContext* c) {
-      ShapeAndType handle_shape_and_type;
-      TF_RETURN_IF_ERROR(
-          ValidateVariableResourceHandle(c, &handle_shape_and_type));
-      ShapeHandle var_shape = handle_shape_and_type.shape;
-      ShapeHandle indices_shape = c->input(1);
+    .SetShapeFn(ResourceScatterUpdateShape);
 
-      ShapeHandle unused_updates_shape;
-      ShapeHandle concat;
-      ShapeHandle var_subshape;
-      TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
-      TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
-      TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
-      return Status::OK();
-    });
+// The ops below share one signature (resource handle, indices, updates) and
+// the ResourceScatterUpdateShape shape function; they declare no outputs.
+REGISTER_OP("ResourceScatterSub")
+    .Input("resource: resource")
+    .Input("indices: Tindices")
+    .Input("updates: dtype")
+    .Attr("dtype: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMul")
+    .Input("resource: resource")
+    .Input("indices: Tindices")
+    .Input("updates: dtype")
+    .Attr("dtype: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterDiv")
+    .Input("resource: resource")
+    .Input("indices: Tindices")
+    .Input("updates: dtype")
+    .Attr("dtype: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .SetShapeFn(ResourceScatterUpdateShape);
+
+// NOTE(review): ResourceScatterMin/Max accept "numbertype", whereas the
+// ref-variable ScatterMin/ScatterMax ops list only
+// {half, bfloat16, float, double, int32, int64}; confirm whether the wider
+// set is intended here.
+REGISTER_OP("ResourceScatterMin")
+    .Input("resource: resource")
+    .Input("indices: Tindices")
+    .Input("updates: dtype")
+    .Attr("dtype: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .SetShapeFn(ResourceScatterUpdateShape);
+
+REGISTER_OP("ResourceScatterMax")
+    .Input("resource: resource")
+    .Input("indices: Tindices")
+    .Input("updates: dtype")
+    .Attr("dtype: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .SetShapeFn(ResourceScatterUpdateShape);
 
 REGISTER_OP("ResourceScatterUpdate")
     .Input("resource: resource")
@@ -195,21 +243,7 @@ REGISTER_OP("ResourceScatterUpdate")
     .Input("updates: dtype")
     .Attr("dtype: type")
     .Attr("Tindices: {int32, int64}")
-    .SetShapeFn([](InferenceContext* c) {
-      ShapeAndType handle_shape_and_type;
-      TF_RETURN_IF_ERROR(
-          ValidateVariableResourceHandle(c, &handle_shape_and_type));
-      ShapeHandle var_shape = handle_shape_and_type.shape;
-      ShapeHandle indices_shape = c->input(1);
-
-      ShapeHandle unused_updates_shape;
-      ShapeHandle concat;
-      ShapeHandle var_subshape;
-      TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
-      TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
-      TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
-      return Status::OK();
-    });
+    .SetShapeFn(ResourceScatterUpdateShape);
 
 REGISTER_OP("MutexV2")
     .Attr("container: string = ''")
index 7a524b6..664f524 100644 (file)
@@ -122,7 +122,10 @@ Status ScatterUpdateShape(InferenceContext* c) {
   ShapeHandle var_subshape;
   TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
   TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
-  TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
+  TF_RETURN_IF_ERROR(
+      InferenceContext::Rank(c->input(2)) == 0
+          ? Status::OK()
+          : c->Merge(c->input(2), concat, &unused_updates_shape));
 
   c->set_output(0, var_shape);
   return Status::OK();
@@ -180,6 +183,26 @@ REGISTER_OP("ScatterDiv")
     .Attr("use_locking: bool = false")
     .SetShapeFn(ScatterUpdateShape);
 
+// ScatterMin / ScatterMax: ref-variable scatter ops with the same inputs,
+// output and attributes as each other.  T is limited to the listed real
+// types, and shapes are checked by ScatterUpdateShape.
+REGISTER_OP("ScatterMin")
+    .Input("ref: Ref(T)")
+    .Input("indices: Tindices")
+    .Input("updates: T")
+    .Output("output_ref: Ref(T)")
+    .Attr("T: {half, bfloat16, float, double, int32, int64}")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("use_locking: bool = false")
+    .SetShapeFn(ScatterUpdateShape);
+
+REGISTER_OP("ScatterMax")
+    .Input("ref: Ref(T)")
+    .Input("indices: Tindices")
+    .Input("updates: T")
+    .Output("output_ref: Ref(T)")
+    .Attr("T: {half, bfloat16, float, double, int32, int64}")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("use_locking: bool = false")
+    .SetShapeFn(ScatterUpdateShape);
+
 REGISTER_OP("ScatterNdUpdate")
     .Input("ref: Ref(T)")
     .Input("indices: Tindices")
index 0d612ee..ec2d877 100644 (file)
@@ -83,6 +83,8 @@ automatically by the optimizers in most cases.
 *   @{tf.scatter_sub}
 *   @{tf.scatter_mul}
 *   @{tf.scatter_div}
+*   @{tf.scatter_min}
+*   @{tf.scatter_max}
 *   @{tf.scatter_nd_update}
 *   @{tf.scatter_nd_add}
 *   @{tf.scatter_nd_sub}
index 563eeff..742564f 100644 (file)
@@ -185,6 +185,204 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
       self.assertEqual(self.evaluate(read), [[3]])
 
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterSub(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[1]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_sub(handle, [0],
+                                                     constant_op.constant(
+                                                         [[2]],
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[-1]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterMul(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[1]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_mul(handle, [0],
+                                                     constant_op.constant(
+                                                         [[5]],
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[5]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterDiv(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[6]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_div(handle, [0],
+                                                     constant_op.constant(
+                                                         [[3]],
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[2]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterMin(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[6]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_min(handle, [0],
+                                                     constant_op.constant(
+                                                         [[3]],
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[3]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterMax(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[6]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_max(handle, [0],
+                                                     constant_op.constant(
+                                                         [[3]],
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[6]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterAddScalar(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[1]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_add(handle, [0],
+                                                     constant_op.constant(
+                                                         2,
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[3]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterSubScalar(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[1]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_sub(handle, [0],
+                                                     constant_op.constant(
+                                                         2,
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[-1]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterMulScalar(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[1]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_mul(handle, [0],
+                                                     constant_op.constant(
+                                                         5,
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[5]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterDivScalar(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[6]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_div(handle, [0],
+                                                     constant_op.constant(
+                                                         3,
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[2]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterMinScalar(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[6]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_min(handle, [0],
+                                                     constant_op.constant(
+                                                         3,
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[3]])
+
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+  def testScatterMaxScalar(self):
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(handle,
+                                                   constant_op.constant(
+                                                       [[6]],
+                                                       dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_max(handle, [0],
+                                                     constant_op.constant(
+                                                         3,
+                                                         dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[6]])
+
   def testScatterUpdateString(self):
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.string, shape=[1, 1])
@@ -196,6 +394,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
     self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]),
                      compat.as_bytes("b"))
 
+  def testScatterUpdateStringScalar(self):
+    handle = resource_variable_ops.var_handle_op(
+        dtype=dtypes.string, shape=[1, 1])
+    self.evaluate(
+        resource_variable_ops.assign_variable_op(handle,
+                                                 constant_op.constant(
+                                                     [["a"]],
+                                                     dtype=dtypes.string)))
+    self.evaluate(
+        resource_variable_ops.resource_scatter_update(handle, [0],
+                                                      constant_op.constant(
+                                                          "b",
+                                                          dtype=dtypes.string)))
+    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
+    self.assertEqual(
+        compat.as_bytes(self.evaluate(read)[0][0]), compat.as_bytes("b"))
+
   # TODO(alive): get this to work in Eager mode.
   def testGPU(self):
     with self.test_session(use_gpu=True):
index 7cdf11d..c70a4ff 100644 (file)
@@ -38,38 +38,100 @@ def _NumpyAdd(ref, indices, updates):
     ref[indx] += updates[i]
 
 
+def _NumpyAddScalar(ref, indices, update):
+  for _, indx in np.ndenumerate(indices):
+    ref[indx] += update
+
+
 def _NumpySub(ref, indices, updates):
   for i, indx in np.ndenumerate(indices):
     ref[indx] -= updates[i]
 
 
+def _NumpySubScalar(ref, indices, update):
+  for _, indx in np.ndenumerate(indices):
+    ref[indx] -= update
+
+
 def _NumpyMul(ref, indices, updates):
   for i, indx in np.ndenumerate(indices):
     ref[indx] *= updates[i]
 
 
+def _NumpyMulScalar(ref, indices, update):
+  for _, indx in np.ndenumerate(indices):
+    ref[indx] *= update
+
+
 def _NumpyDiv(ref, indices, updates):
   for i, indx in np.ndenumerate(indices):
     ref[indx] /= updates[i]
 
 
+def _NumpyDivScalar(ref, indices, update):
+  for _, indx in np.ndenumerate(indices):
+    ref[indx] /= update
+
+
+def _NumpyMin(ref, indices, updates):
+  for i, indx in np.ndenumerate(indices):
+    ref[indx] = np.minimum(ref[indx], updates[i])
+
+
+def _NumpyMinScalar(ref, indices, update):
+  for _, indx in np.ndenumerate(indices):
+    ref[indx] = np.minimum(ref[indx], update)
+
+
+def _NumpyMax(ref, indices, updates):
+  for i, indx in np.ndenumerate(indices):
+    ref[indx] = np.maximum(ref[indx], updates[i])
+
+
+def _NumpyMaxScalar(ref, indices, update):
+  for _, indx in np.ndenumerate(indices):
+    ref[indx] = np.maximum(ref[indx], update)
+
+
 def _NumpyUpdate(ref, indices, updates):
   for i, indx in np.ndenumerate(indices):
     ref[indx] = updates[i]
 
 
+def _NumpyUpdateScalar(ref, indices, update):
+  for _, indx in np.ndenumerate(indices):
+    ref[indx] = update
+
+
 _TF_OPS_TO_NUMPY = {
     state_ops.scatter_update: _NumpyUpdate,
     state_ops.scatter_add: _NumpyAdd,
     state_ops.scatter_sub: _NumpySub,
     state_ops.scatter_mul: _NumpyMul,
     state_ops.scatter_div: _NumpyDiv,
+    state_ops.scatter_min: _NumpyMin,
+    state_ops.scatter_max: _NumpyMax,
+}
+
+_TF_OPS_TO_NUMPY_SCALAR = {
+    state_ops.scatter_update: _NumpyUpdateScalar,
+    state_ops.scatter_add: _NumpyAddScalar,
+    state_ops.scatter_sub: _NumpySubScalar,
+    state_ops.scatter_mul: _NumpyMulScalar,
+    state_ops.scatter_div: _NumpyDivScalar,
+    state_ops.scatter_min: _NumpyMinScalar,
+    state_ops.scatter_max: _NumpyMaxScalar,
 }
 
 
 class ScatterTest(test.TestCase):
 
-  def _VariableRankTest(self, tf_scatter, vtype, itype, repeat_indices=False):
+  def _VariableRankTest(self,
+                        tf_scatter,
+                        vtype,
+                        itype,
+                        repeat_indices=False,
+                        updates_are_scalar=False):
     np.random.seed(8)
     with self.test_session(use_gpu=True):
       for indices_shape in (), (2,), (3, 7), (3, 4, 7):
@@ -89,8 +151,11 @@ class ScatterTest(test.TestCase):
                                   indices[np.random.randint(size // 2)])
             np.random.shuffle(indices)
           indices = indices.reshape(indices_shape)
-          updates = _AsType(
-              np.random.randn(*(indices_shape + extra_shape)), vtype)
+          if updates_are_scalar:
+            updates = _AsType(np.random.randn(), vtype)
+          else:
+            updates = _AsType(
+                np.random.randn(*(indices_shape + extra_shape)), vtype)
 
           # Clips small values to avoid division by zero.
           def clip_small_values(x):
@@ -101,7 +166,10 @@ class ScatterTest(test.TestCase):
 
           # Scatter via numpy
           new = old.copy()
-          np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
+          if updates_are_scalar:
+            np_scatter = _TF_OPS_TO_NUMPY_SCALAR[tf_scatter]
+          else:
+            np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
           np_scatter(new, indices, updates)
           # Scatter via tensorflow
           ref = variables.Variable(old)
@@ -109,25 +177,35 @@ class ScatterTest(test.TestCase):
           tf_scatter(ref, indices, updates).eval()
           self.assertAllClose(ref.eval(), new)
 
-  def _VariableRankTests(self, tf_scatter, repeat_indices=False):
+  def _VariableRankTests(self,
+                         tf_scatter,
+                         repeat_indices=False,
+                         updates_are_scalar=False):
     for vtype in (np.float32, np.float64):
       for itype in (np.int32, np.int64):
-        self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices)
+        self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
+                               updates_are_scalar)
 
   def testVariableRankUpdate(self):
-    self._VariableRankTests(state_ops.scatter_update)
+    self._VariableRankTests(state_ops.scatter_update, False)
 
   def testVariableRankAdd(self):
-    self._VariableRankTests(state_ops.scatter_add)
+    self._VariableRankTests(state_ops.scatter_add, False)
 
   def testVariableRankSub(self):
-    self._VariableRankTests(state_ops.scatter_sub)
+    self._VariableRankTests(state_ops.scatter_sub, False)
 
   def testVariableRankMul(self):
-    self._VariableRankTests(state_ops.scatter_mul)
+    self._VariableRankTests(state_ops.scatter_mul, False)
 
   def testVariableRankDiv(self):
-    self._VariableRankTests(state_ops.scatter_div)
+    self._VariableRankTests(state_ops.scatter_div, False)
+
+  def testVariableRankMin(self):
+    self._VariableRankTests(state_ops.scatter_min, False)
+
+  def testVariableRankMax(self):
+    self._VariableRankTests(state_ops.scatter_max, False)
 
   def testRepeatIndicesAdd(self):
     self._VariableRankTests(state_ops.scatter_add, True)
@@ -141,6 +219,51 @@ class ScatterTest(test.TestCase):
   def testRepeatIndicesDiv(self):
     self._VariableRankTests(state_ops.scatter_div, True)
 
+  def testRepeatIndicesMin(self):
+    self._VariableRankTests(state_ops.scatter_min, True)
+
+  def testRepeatIndicesMax(self):
+    self._VariableRankTests(state_ops.scatter_max, True)
+
+  def testVariableRankUpdateScalar(self):
+    self._VariableRankTests(state_ops.scatter_update, False, True)
+
+  def testVariableRankAddScalar(self):
+    self._VariableRankTests(state_ops.scatter_add, False, True)
+
+  def testVariableRankSubScalar(self):
+    self._VariableRankTests(state_ops.scatter_sub, False, True)
+
+  def testVariableRankMulScalar(self):
+    self._VariableRankTests(state_ops.scatter_mul, False, True)
+
+  def testVariableRankDivScalar(self):
+    self._VariableRankTests(state_ops.scatter_div, False, True)
+
+  def testVariableRankMinScalar(self):
+    self._VariableRankTests(state_ops.scatter_min, False, True)
+
+  def testVariableRankMaxScalar(self):
+    self._VariableRankTests(state_ops.scatter_max, False, True)
+
+  def testRepeatIndicesAddScalar(self):
+    self._VariableRankTests(state_ops.scatter_add, True, True)
+
+  def testRepeatIndicesSubScalar(self):
+    self._VariableRankTests(state_ops.scatter_sub, True, True)
+
+  def testRepeatIndicesMulScalar(self):
+    self._VariableRankTests(state_ops.scatter_mul, True, True)
+
+  def testRepeatIndicesDivScalar(self):
+    self._VariableRankTests(state_ops.scatter_div, True, True)
+
+  def testRepeatIndicesMinScalar(self):
+    self._VariableRankTests(state_ops.scatter_min, True, True)
+
+  def testRepeatIndicesMaxScalar(self):
+    self._VariableRankTests(state_ops.scatter_max, True, True)
+
   def testBooleanScatterUpdate(self):
     if not test.is_gpu_available():
       with self.test_session(use_gpu=False) as session:
index 230b7ef..e90ff07 100644 (file)
@@ -80,6 +80,8 @@ from tensorflow.python.ops.state_ops import scatter_add
 from tensorflow.python.ops.state_ops import scatter_div
 from tensorflow.python.ops.state_ops import scatter_mul
 from tensorflow.python.ops.state_ops import scatter_sub
+from tensorflow.python.ops.state_ops import scatter_min
+from tensorflow.python.ops.state_ops import scatter_max
 from tensorflow.python.ops.state_ops import scatter_update
 from tensorflow.python.ops.state_ops import scatter_nd_add
 from tensorflow.python.ops.state_ops import scatter_nd_sub
index c3ad583..01fc318 100644 (file)
@@ -63,6 +63,8 @@
 @@scatter_nd_update
 @@scatter_sub
 @@scatter_update
+@@scatter_min
+@@scatter_max
 @@sparse_mask
 @@tables_initializer
 @@trainable_variables
index 55b82dd..937044a 100644 (file)
@@ -1689,6 +1689,14 @@ tf_module {
     argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
   member_method {
+    name: "scatter_max"
+    argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
+  member_method {
+    name: "scatter_min"
+    argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
+  member_method {
     name: "scatter_mul"
     argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }