state_ops.scatter_update(v, [1], [3.0])
self.assertAllEqual([1.0, 3.0], v.numpy())
+ def testScatterAddStateOps(self):
+ with context.eager_mode():
+ v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="add")
+ state_ops.scatter_add(v, [1], [3])
+ self.assertAllEqual([1.0, 5.0], v.numpy())
+
def testScatterUpdateCast(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype),
use_locking, name)]):
return ref.read_value()
+
+
+@tf_export("scatter_add")
+def scatter_add(ref, indices, updates, use_locking=False, name=None):
+ # pylint: disable=line-too-long
+ r"""Adds sparse updates to the variable referenced by `resource`.
+
+ This operation computes
+
+ ```python
+ # Scalar indices
+ ref[indices, ...] += updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] += updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
+ ```
+
+ This operation outputs `ref` after the update is done.
+ This makes it easier to chain operations that need to use the updated value.
+ Duplicate entries are handled correctly: if multiple `indices` reference
+ the same location, their contributions add.
+
+ Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+ <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
+ </div>
+
+ Args:
+ ref: A `Variable`.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into the first dimension of `ref`.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to store in `ref`.
+ use_locking: An optional `bool`. Defaults to `True`.
+ If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_add(ref, indices, updates,
+ use_locking=use_locking, name=name)
+ return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))