From b4572668b48b746802580263dafa12762d202379 Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Wed, 27 Feb 2019 11:39:37 -0800
Subject: [PATCH] Add sparse gradient option to `gather` operation (#17182)

Summary:
This PR allows `gather` to optionally compute sparse gradients for its input,
as requested in #16329. It also allows the autograd engine to accumulate
sparse gradients in place when it is safe to do so.

I've commented out the size.size() check in `SparseTensor.cpp` that also
caused #17152; it does not seem to me that the check serves a useful purpose,
but please correct me if I'm wrong and a better fix is required.

Motivating example: for this commonly used label smoothing loss function

```
def label_smoothing_opt(x, target):
    padding_idx = 0
    smoothing = 0.1
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    pad_mask = (target == padding_idx)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse_grad=True).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss.sum()
```

backward goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse
gradients, for 9K tokens x 30K vocab. That is a single-digit percent
end-to-end improvement, and it also reduces the peak memory required.

Shout-out to core devs: adding Python-exposed functions with keyword
arguments through native_functions.yaml is very easy now!

cc gchanan apaszke

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17182

Differential Revision: D14158431

Pulled By: gchanan

fbshipit-source-id: c8b654611534198025daaf7a634482b3151fbade
---
 aten/src/ATen/core/Tensor.h                |  2 +-
 aten/src/ATen/core/TensorMethods.h         |  4 ++--
 aten/src/ATen/core/Type.h                  |  2 +-
 aten/src/ATen/native/Indexing.cpp          | 20 ++++++++++++++++++
 aten/src/ATen/native/LegacyDefinitions.cpp |  4 ++--
 aten/src/ATen/native/native_functions.yaml |  6 ++++--
 test/test_autograd.py                      | 33 ++++++++++++++++++++++++++++++
 tools/autograd/derivatives.yaml            |  4 ++--
 tools/autograd/templates/Functions.cpp     |  1 +
 torch/_torch_docs.py                       |  3 ++-
 torch/csrc/autograd/input_buffer.cpp       | 14 +++++++++++--
 torch/csrc/jit/passes/shape_analysis.cpp   |  2 +-
 12 files changed, 81 insertions(+), 14 deletions(-)
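As a usage sketch (a note for readers, not part of the diff below): assuming a
build that includes this patch, the new keyword behaves as follows. The sizes
here are arbitrary, and the sparse gradient densifies to exactly what the
default dense path produces.

```
import torch

x = torch.randn(4, 5, requires_grad=True)
idx = torch.randint(5, (4, 1))

# Default dense-gradient path.
torch.gather(x, 1, idx).sum().backward()
dense = x.grad.clone()
x.grad = None

# Sparse-gradient path added by this patch.
torch.gather(x, 1, idx, sparse_grad=True).sum().backward()
assert x.grad.is_sparse
assert torch.equal(dense, x.grad.to_dense())
```

The sparse gradient only materializes one entry per gathered element instead
of a full zeros-like of the input, which is where the backward-time and peak
memory savings quoted above come from.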
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index 0682f69..699ea46 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -664,7 +664,7 @@ class CAFFE2_API Tensor {
   Tensor index_select(int64_t dim, const Tensor & index) const;
   Tensor masked_select(const Tensor & mask) const;
   Tensor nonzero() const;
-  Tensor gather(int64_t dim, const Tensor & index) const;
+  Tensor gather(int64_t dim, const Tensor & index, bool sparse_grad=false) const;
   Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
   Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
   std::tuple<Tensor,Tensor> gels(const Tensor & A) const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 026b2c3..bf64977 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -1126,8 +1126,8 @@ inline Tensor Tensor::masked_select(const Tensor & mask) const {
 inline Tensor Tensor::nonzero() const {
     return type().nonzero(*this);
 }
-inline Tensor Tensor::gather(int64_t dim, const Tensor & index) const {
-    return type().gather(*this, dim, index);
+inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad) const {
+    return type().gather(*this, dim, index, sparse_grad);
 }
 inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const {
     return type().addcmul(*this, tensor1, tensor2, value);
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 1976eeb..1e9edbd 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -564,7 +564,7 @@ struct CAFFE2_API Type {
   virtual Tensor index_select(const Tensor & self, int64_t dim, const Tensor & index) const = 0;
   virtual Tensor masked_select(const Tensor & self, const Tensor & mask) const = 0;
   virtual Tensor nonzero(const Tensor & self) const = 0;
-  virtual Tensor gather(const Tensor & self, int64_t dim, const Tensor & index) const = 0;
+  virtual Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) const = 0;
   virtual Tensor addcmul(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
   virtual Tensor addcdiv(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
   virtual std::tuple<Tensor,Tensor> gels(const Tensor & self, const Tensor & A) const = 0;
diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp
index e00dd88..154d823 100644
--- a/aten/src/ATen/native/Indexing.cpp
+++ b/aten/src/ATen/native/Indexing.cpp
@@ -551,4 +551,24 @@ Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & sour
   return _self.clone().masked_fill_(mask, source);
 }
 
+Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
+// special case scalar input and/or index
+  if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes());
+  if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(index.view({1,1}), grad, self.sizes());
+  Tensor sparse_ind = at::empty({self.ndimension(), grad.numel()}, self.options().dtype(at::kLong));
+  int64_t n_above = grad.numel();
+  int64_t n_below = 1;
+  if (dim < 0) dim += self.ndimension();
+  for (int i=0; i<self.ndimension(); i++) {
+    n_above /= grad.size(i);
+    // along dim the coordinate is the gather index; elsewhere it is an arange
+    if (i == dim) {
+      sparse_ind[i] = index.reshape(-1);
+    } else {
+      sparse_ind[i] = at::arange(grad.size(i), self.options().dtype(at::kLong)).unsqueeze(1).expand({grad.size(i), n_above}).reshape(-1).repeat(n_below);
+    }
+    n_below *= grad.size(i);
+  }
+  return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes());
+}
 }} // namespace at::native
diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp
--- a/aten/src/ATen/native/LegacyDefinitions.cpp
+++ b/aten/src/ATen/native/LegacyDefinitions.cpp
@@ ... @@
-Tensor & gather_out(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index) {
+Tensor & gather_out(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
   return at::legacy::th::_th_gather_out(result, self, dim, index);
 }
 
-Tensor gather(const Tensor & self, int64_t dim, const Tensor & index) {
+Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
   return at::legacy::th::_th_gather(self, dim, index);
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ ... @@
-- func: gather(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: gather(Tensor self, int dim, Tensor index) -> Tensor
+- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
   matches_jit_signature: True
   variants: method, function
 
+- func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor
+
 - func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
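The index construction in `_gather_sparse_backward` above can be mirrored in a
few lines of Python. This is an illustration only (the helper name and the
stride loop are mine, not part of the patch): each gathered value flows back
to the coordinate it was read from, so the sparse gradient's COO indices are
simply grad's own coordinates with the component along `dim` replaced by the
gather index.

```
import torch

def gather_sparse_backward_sketch(self_sizes, dim, index, grad):
    # grad and index have the same shape; one COO entry per grad element.
    # dim is assumed to be already normalized to a non-negative value.
    nnz = grad.numel()
    coords = torch.empty(grad.dim(), nnz, dtype=torch.long)
    for d in range(grad.dim()):
        if d == dim:
            coords[d] = index.reshape(-1)   # the positions gathered from
        else:
            stride = 1                      # elements per step along dim d
            for s in grad.shape[d + 1:]:
                stride *= s
            coords[d] = (torch.arange(nnz) // stride) % grad.size(d)
    return torch.sparse_coo_tensor(coords, grad.reshape(-1), self_sizes)
```

Duplicate coordinates (repeated gather indices) are legal in an uncoalesced
COO tensor, so repeated reads of the same input element accumulate correctly
once the gradient is coalesced or added into a buffer.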
diff --git a/test/test_autograd.py b/test/test_autograd.py
index f747acd..799c9c0 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1565,6 +1565,39 @@ class TestAutograd(TestCase):
 
         gradcheck(ctc_after_softmax, [x])
 
+    def _test_sparse_gather(self, size_x, size_ind, dim):
+        x = torch.randn(size_x, requires_grad=True)
+        if len(size_ind) > 0 and len(size_x) > 0:
+            ind = torch.randint(x.size(dim), size_ind)
+        else:
+            ind = torch.zeros(size_ind, dtype=torch.int64)
+        out = torch.gather(x, dim, ind, sparse_grad=False)
+        grad = torch.rand_like(out)
+        out.backward(grad)
+        grad_dense = x.grad.clone()
+        x.grad = None
+        out = torch.gather(x, dim, ind, sparse_grad=True)
+        out.backward(grad)
+        self.assertEqual(grad_dense, x.grad.to_dense())
+
+    def test_sparse_gather_dim0(self):
+        self._test_sparse_gather((10, 10), (5, 10), 0)
+
+    def test_sparse_gather_dim1(self):
+        self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1)
+
+    def test_sparse_gather_dim_neg(self):
+        self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1)
+
+    def test_sparse_gather_ind_scalar(self):
+        self._test_sparse_gather((10,), (), 0)
+
+    def test_sparse_gather_x_scalar(self):
+        self._test_sparse_gather((), (2,), 0)
+
+    def test_sparse_gather_both_scalar(self):
+        self._test_sparse_gather((), (), 0)
+
     def test_gc_in_destructor(self):
         """
         Previously, if a Function destructor triggered a garbage collection,
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 7912d66..ee05b63 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -340,8 +340,8 @@
 - name: frac(Tensor self)
   self: grad
 
-- name: gather(Tensor self, int64_t dim, Tensor index)
-  self: at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)
+- name: gather(Tensor self, int64_t dim, Tensor index, bool sparse_grad)
+  self: "sparse_grad ? at::_gather_sparse_backward(self, dim, index, grad) : at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)"
 
 - name: ge_(Tensor self, Scalar other)
   self: zeros_like(self)
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index d44ec5a..a87ec8b 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -2075,6 +2075,7 @@ Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) {
   return grad.sparse_mask(at::SparseTensorRef(input));
 }
 
+// Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads])
 Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
   auto negated_pad = pad.vec();
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index f5906c1..0b48e70 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -1844,7 +1844,7 @@ Example::
 
 add_docstr(torch.gather,
            r"""
-gather(input, dim, index, out=None) -> Tensor
+gather(input, dim, index, out=None, sparse_grad=False) -> Tensor
 
 Gathers values along an axis specified by `dim`.
 
@@ -1865,6 +1865,7 @@
 Args:
     dim (int): the axis along which to index
     index (LongTensor): the indices of elements to gather
     out (Tensor, optional): the destination tensor
+    sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.
 
 Example::
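On the `input_buffer.cpp` hunk just below: when one incoming gradient is dense
and the other sparse, the engine now folds the sparse one into the dense
buffer in place, but only when the dense tensor is contiguous and its storage
use count is one, so no user-visible alias can observe the mutation. In Python
terms the accumulation itself is just an in-place add (a sketch with made-up
values, not part of the patch):

```
import torch

dense = torch.zeros(4, 5)
sparse = torch.sparse_coo_tensor([[0, 2], [1, 3]], [1.0, 2.0], (4, 5))

dense.add_(sparse)  # dense += sparse, without allocating a third tensor
```

The use_count() == 1 guard is deliberately conservative: as the comment in the
hunk says, anything lighter admits an adversarial case where another Variable
shares the storage and would see the unexpected in-place modification.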
diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp
index d98ad88..5322bee 100644
--- a/torch/csrc/autograd/input_buffer.cpp
+++ b/torch/csrc/autograd/input_buffer.cpp
@@ -22,10 +22,20 @@ void InputBuffer::add(size_t pos, Variable var) {
   } else {
     at::OptionalDeviceGuard device_guard(device_of(var));
     // ATen doesn't route sparse additions correctly...
+    // do dense + sparse in-place if possible
     if (old_var.is_sparse()) {
-      buffer[pos] = var + old_var;
+      // storage use_count is a big hammer, but for anything lighter there's an adversarial example with unexpected inplace modification
+      if (!var.is_sparse() && var.is_contiguous() && var.storage().use_count() == 1) {
+        buffer[pos] = var.add_(old_var);
+      } else {
+        buffer[pos] = var + old_var;
+      }
     } else {
-      buffer[pos] = old_var + var;
+      if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && old_var.storage().use_count() == 1) {
+        buffer[pos] = old_var.add_(var);
+      } else {
+        buffer[pos] = old_var + var;
+      }
     }
   }
 }
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index e4f6e59..1742128 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -1214,7 +1214,7 @@ class ShapePropagator {
       }
     } else if (
         node->matches(
-            "aten::gather(Tensor self, int dim, Tensor index) -> Tensor")) {
+            "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
       auto type = input_type(0);
       auto index_type = input_type(1);
       // Gather has this annoying edge case where index always needs to match
-- 
2.7.4