Expose scatter_add for resource variables.
authorAdria Puigdomenech <adriap@google.com>
Wed, 4 Apr 2018 20:02:27 +0000 (13:02 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 20:04:43 +0000 (13:04 -0700)
PiperOrigin-RevId: 191634030

tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt [new file with mode: 0644]
tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/ops/state_ops.py

diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterAdd.pbtxt
new file mode 100644 (file)
index 0000000..4f5b6de
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ScatterAdd"
+  visibility: HIDDEN
+}
index c31d5a1..edc6326 100644 (file)
@@ -802,6 +802,12 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       state_ops.scatter_update(v, [1], [3.0])
       self.assertAllEqual([1.0, 3.0], v.numpy())
 
+  def testScatterAddStateOps(self):
+    with context.eager_mode():
+      v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="add")
+      state_ops.scatter_add(v, [1], [3])
+      self.assertAllEqual([1.0, 5.0], v.numpy())
+
   def testScatterUpdateCast(self):
     with context.eager_mode():
       v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
index 01fc318..f6a11ca 100644 (file)
@@ -423,3 +423,55 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
       ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype),
       use_locking, name)]):
     return ref.read_value()
+
+
+@tf_export("scatter_add")
+def scatter_add(ref, indices, updates, use_locking=False, name=None):
+  # pylint: disable=line-too-long
+  r"""Adds sparse updates to the variable referenced by `resource`.
+
+  This operation computes
+
+  ```python
+      # Scalar indices
+      ref[indices, ...] += updates[...]
+
+      # Vector indices (for each i)
+      ref[indices[i], ...] += updates[i, ...]
+
+      # High rank indices (for each i, ..., j)
+      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
+  ```
+
+  This operation outputs `ref` after the update is done.
+  This makes it easier to chain operations that need to use the updated value.
+  Duplicate entries are handled correctly: if multiple `indices` reference
+  the same location, their contributions add.
+
+  Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+  </div>
+
+  Args:
+    ref: A `Variable`.
+    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+      A tensor of indices into the first dimension of `ref`.
+    updates: A `Tensor`. Must have the same type as `ref`.
+      A tensor of updated values to store in `ref`.
+    use_locking: An optional `bool`. Defaults to `True`.
+      If True, the assignment will be protected by a lock;
+      otherwise the behavior is undefined, but may exhibit less contention.
+    name: A name for the operation (optional).
+
+  Returns:
+    Same as `ref`.  Returned as a convenience for operations that want
+    to use the updated values after the update is done.
+  """
+  if ref.dtype._is_ref_dtype:
+    return gen_state_ops.scatter_add(ref, indices, updates,
+                                     use_locking=use_locking, name=name)
+  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
+      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+      name=name))