Summary:
This PR allows `gather` to optionally return sparse gradients, as requested in #16329. It also allows the autograd engine to accumulate sparse gradients in place when it is safe to do so.
I've commented out the `size.size()` check in `SparseTensor.cpp` that also caused #17152; the check does not seem to serve a useful purpose, but please correct me if I'm wrong and a better fix is required.
Motivating example:
For this commonly used label smoothing loss function
```
def label_smoothing_opt(x, target):
    padding_idx = 0
    smoothing = 0.1
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    pad_mask = (target == padding_idx)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse_grad=True).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss.sum()
```
the backward pass goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse gradients, for 9K tokens x 30K vocab. That is a single-digit percent end-to-end improvement, and it also reduces peak memory usage.
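For context, here is a rough sketch of how the quoted backward timing could be reproduced (the 9216 x 30000 shape is an assumption standing in for "9K tokens x 30K vocab", and a CUDA device is assumed):
```
import time
import torch

x = torch.randn(9216, 30000, device='cuda', requires_grad=True)
target = torch.randint(0, 30000, (9216,), device='cuda')

loss = label_smoothing_opt(x, target)  # the function defined above
torch.cuda.synchronize()
t0 = time.perf_counter()
loss.backward()
torch.cuda.synchronize()
print('backward: %.1f ms' % ((time.perf_counter() - t0) * 1e3))
```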
Shout-out to core devs: adding python-exposed functions with keyword arguments through native_functions.yaml is very easy now!
cc gchanan apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17182
Differential Revision: D14158431
Pulled By: gchanan
fbshipit-source-id: c8b654611534198025daaf7a634482b3151fbade
Tensor index_select(int64_t dim, const Tensor & index) const;
Tensor masked_select(const Tensor & mask) const;
Tensor nonzero() const;
- Tensor gather(int64_t dim, const Tensor & index) const;
+ Tensor gather(int64_t dim, const Tensor & index, bool sparse_grad=false) const;
Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
std::tuple<Tensor,Tensor> gels(const Tensor & A) const;
inline Tensor Tensor::nonzero() const {
return type().nonzero(*this);
}
-inline Tensor Tensor::gather(int64_t dim, const Tensor & index) const {
- return type().gather(*this, dim, index);
+inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad) const {
+ return type().gather(*this, dim, index, sparse_grad);
}
inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const {
return type().addcmul(*this, tensor1, tensor2, value);
virtual Tensor index_select(const Tensor & self, int64_t dim, const Tensor & index) const = 0;
virtual Tensor masked_select(const Tensor & self, const Tensor & mask) const = 0;
virtual Tensor nonzero(const Tensor & self) const = 0;
- virtual Tensor gather(const Tensor & self, int64_t dim, const Tensor & index) const = 0;
+ virtual Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) const = 0;
virtual Tensor addcmul(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
virtual Tensor addcdiv(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
virtual std::tuple<Tensor,Tensor> gels(const Tensor & self, const Tensor & A) const = 0;
return _self.clone().masked_fill_(mask, source);
}
+Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
+  // special-case scalar input and/or index
+  if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0, grad.numel()}, index.options()), grad, self.sizes());
+  if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(index.view({1, 1}), grad, self.sizes());
+  Tensor sparse_ind = at::empty({self.ndimension(), grad.numel()}, self.options().dtype(at::kLong));
+  int64_t n_above = grad.numel();
+  int64_t n_below = 1;
+  if (dim < 0) dim += self.ndimension();
+  for (int i = 0; i < self.ndimension(); i++) {
+    n_above /= grad.size(i);
+    if (i == dim) {
+      // along the gathered dimension, the coordinate is the gather index itself
+      sparse_ind[i] = index.reshape(-1);
+    } else {
+      // other dimensions follow the row-major coordinate pattern of grad:
+      // each value repeated n_above times, the whole cycle repeated n_below times
+      sparse_ind[i] = at::arange(grad.size(i), self.options().dtype(at::kLong)).unsqueeze(1).expand({grad.size(i), n_above}).reshape(-1).repeat(n_below);
+    }
+    n_below *= grad.size(i);
+  }
+  return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes());
+}
+
}} // at::native
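For readers who prefer Python, here is a rough equivalent of the index construction above (a hypothetical reference helper, with the scalar special cases omitted and the public `torch.sparse_coo_tensor` standing in for the unsafe internal constructor):
```
import torch

def gather_sparse_backward_ref(self, dim, index, grad):
    # one row of COO indices per dimension of `self`
    dim = dim % self.dim()
    sparse_ind = torch.empty(self.dim(), grad.numel(), dtype=torch.long)
    n_above = grad.numel()
    n_below = 1
    for i in range(self.dim()):
        n_above //= grad.size(i)
        if i == dim:
            # along the gathered dimension, the coordinate is the gather index
            sparse_ind[i] = index.reshape(-1)
        else:
            # row-major coordinate pattern of grad: each value repeated
            # n_above times, the whole cycle repeated n_below times
            sparse_ind[i] = (torch.arange(grad.size(i))
                             .unsqueeze(1).expand(grad.size(i), n_above)
                             .reshape(-1).repeat(n_below))
        n_below *= grad.size(i)
    return torch.sparse_coo_tensor(sparse_ind, grad.reshape(-1), self.shape)
```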
return at::legacy::th::_th_nonzero(self);
}
-Tensor & gather_out(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index) {
+Tensor & gather_out(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
+  // sparse_grad only affects the autograd formula; the forward computation is unchanged
  return at::legacy::th::_th_gather_out(result, self, dim, index);
}
-Tensor gather(const Tensor & self, int64_t dim, const Tensor & index) {
+Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
  return at::legacy::th::_th_gather(self, dim, index);
}
matches_jit_signature: True
variants: method, function
-- func: gather(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
-- func: gather(Tensor self, int dim, Tensor index) -> Tensor
+- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
matches_jit_signature: True
variants: method, function
+- func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor
+
- func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
gradcheck(ctc_after_softmax, [x])
+    def _test_sparse_gather(self, size_x, size_ind, dim):
+        x = torch.randn(size_x, requires_grad=True)
+        if len(size_ind) > 0 and len(size_x) > 0:
+            ind = torch.randint(x.size(dim), size_ind)
+        else:
+            ind = torch.zeros(size_ind, dtype=torch.int64)
+        out = torch.gather(x, dim, ind, sparse_grad=False)
+        grad = torch.rand_like(out)
+        out.backward(grad)
+        grad_dense = x.grad.clone()
+        x.grad = None
+        out = torch.gather(x, dim, ind, sparse_grad=True)
+        out.backward(grad)
+        self.assertEqual(grad_dense, x.grad.to_dense())
+
+    def test_sparse_gather_dim0(self):
+        self._test_sparse_gather((10, 10), (5, 10), 0)
+
+    def test_sparse_gather_dim1(self):
+        self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1)
+
+    def test_sparse_gather_dim_neg(self):
+        self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1)
+
+    def test_sparse_gather_ind_scalar(self):
+        self._test_sparse_gather((10,), (), 0)
+
+    def test_sparse_gather_x_scalar(self):
+        self._test_sparse_gather((), (2,), 0)
+
+    def test_sparse_gather_both_scalar(self):
+        self._test_sparse_gather((), (), 0)
+
    def test_gc_in_destructor(self):
        """
        Previously, if a Function destructor triggered a garbage collection,
- name: frac(Tensor self)
self: grad
-- name: gather(Tensor self, int64_t dim, Tensor index)
- self: at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)
+- name: gather(Tensor self, int64_t dim, Tensor index, bool sparse_grad)
+ self: "sparse_grad ? at::_gather_sparse_backward(self, dim, index, grad) : at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)"
- name: ge_(Tensor self, Scalar other)
self: zeros_like(self)
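As a quick sanity check of the two branches of this formula (shapes here chosen arbitrarily), the sparse backward, once densified, should match the `scatter_add_`-based dense one:
```
import torch

x = torch.randn(10, 10, requires_grad=True)
ind = torch.randint(10, (5, 10))
grad_out = torch.ones(5, 10)

torch.gather(x, 0, ind, sparse_grad=True).backward(grad_out)
dense_ref = torch.zeros_like(x).scatter_add_(0, ind, grad_out)
assert torch.equal(x.grad.to_dense(), dense_ref)
```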
return grad.sparse_mask(at::SparseTensorRef(input));
}
+
// Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads])
Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
auto negated_pad = pad.vec();
add_docstr(torch.gather,
r"""
-gather(input, dim, index, out=None) -> Tensor
+gather(input, dim, index, out=None, sparse_grad=False) -> Tensor
Gathers values along an axis specified by `dim`.
dim (int): the axis along which to index
index (LongTensor): the indices of elements to gather
out (Tensor, optional): the destination tensor
+    sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.
Example::
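A minimal illustration of the new keyword (values here are made up for the example; the point is that the gradient comes back as a sparse COO tensor):
```
>>> x = torch.randn(3, 4, requires_grad=True)
>>> idx = torch.tensor([[0], [2], [1]])
>>> torch.gather(x, 1, idx, sparse_grad=True).sum().backward()
>>> x.grad.is_sparse
True
```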
} else {
  at::OptionalDeviceGuard device_guard(device_of(var));
  // ATen doesn't route sparse additions correctly...
+  // do dense + sparse in place if possible
  if (old_var.is_sparse()) {
-    buffer[pos] = var + old_var;
+    // storage use_count() == 1 is a big hammer, but any lighter check admits
+    // an adversarial case with an unexpected in-place modification
+    if (!var.is_sparse() && var.is_contiguous() && var.storage().use_count() == 1) {
+      buffer[pos] = var.add_(old_var);
+    } else {
+      buffer[pos] = var + old_var;
+    }
  } else {
-    buffer[pos] = old_var + var;
+    if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && old_var.storage().use_count() == 1) {
+      buffer[pos] = old_var.add_(var);
+    } else {
+      buffer[pos] = old_var + var;
+    }
  }
}
}
}
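The fast path above leans on the fact that a dense tensor accepts a sparse addend in place, which is also what makes the accumulation cheap; a minimal illustration:
```
import torch

dense = torch.zeros(4, 4)
sp = torch.sparse_coo_tensor([[0, 2], [1, 3]], [1.0, 2.0], (4, 4))
dense.add_(sp)  # accumulates the sparse values into the dense buffer in place
```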
} else if (
node->matches(
- "aten::gather(Tensor self, int dim, Tensor index) -> Tensor")) {
+ "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
auto type = input_type(0);
auto index_type = input_type(1);
// Gather has this annoying edge case where index always needs to match