Add sparse gradient option to `gather` operation (#17182)
authorNatalia Gimelshein <ngimelshein@nvidia.com>
Wed, 27 Feb 2019 19:39:37 +0000 (11:39 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 27 Feb 2019 19:42:48 +0000 (11:42 -0800)
Summary:
This PR allows `gather` to optionally return sparse gradients, as requested in #16329. It also allows to autograd engine to accumulate sparse gradients in place when it is safe to do so.
I've commented out size.size() check in `SparseTensor.cpp` that also caused #17152, it does not seem to me that check serves a useful purpose, but please correct me if I'm wrong and a better fix is required.
Motivating example:
For this commonly used label smoothing loss function
```
def label_smoothing_opt(x, target):
    padding_idx = 0
    smoothing = 0.1
    logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32)
    pad_mask = (target == padding_idx)
    ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1), sparse = True).squeeze(1)
    smooth_loss = logprobs.mean(dim=-1)
    loss =  (smoothing - 1.0) * ll_loss - smoothing * smooth_loss
    loss.masked_fill_(pad_mask, 0)
    return loss.sum()
```
backward goes from 12.6 ms with dense gather gradients to 7.3 ms with sparse gradients, for 9K tokens x 30K vocab, which is some single percent end-to-end improvement, and also improvement in peak memory required.
Shout-out to core devs: adding python-exposed functions with keyword arguments through native_functions.yaml is very easy now!

cc gchanan apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17182

Differential Revision: D14158431

Pulled By: gchanan

fbshipit-source-id: c8b654611534198025daaf7a634482b3151fbade

12 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/Indexing.cpp
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/native_functions.yaml
test/test_autograd.py
tools/autograd/derivatives.yaml
tools/autograd/templates/Functions.cpp
torch/_torch_docs.py
torch/csrc/autograd/input_buffer.cpp
torch/csrc/jit/passes/shape_analysis.cpp

index 0682f69..699ea46 100644 (file)
@@ -664,7 +664,7 @@ class CAFFE2_API Tensor {
   Tensor index_select(int64_t dim, const Tensor & index) const;
   Tensor masked_select(const Tensor & mask) const;
   Tensor nonzero() const;
-  Tensor gather(int64_t dim, const Tensor & index) const;
+  Tensor gather(int64_t dim, const Tensor & index, bool sparse_grad=false) const;
   Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
   Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
   std::tuple<Tensor,Tensor> gels(const Tensor & A) const;
index 026b2c3..bf64977 100644 (file)
@@ -1126,8 +1126,8 @@ inline Tensor Tensor::masked_select(const Tensor & mask) const {
 inline Tensor Tensor::nonzero() const {
     return type().nonzero(*this);
 }
-inline Tensor Tensor::gather(int64_t dim, const Tensor & index) const {
-    return type().gather(*this, dim, index);
+inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad) const {
+    return type().gather(*this, dim, index, sparse_grad);
 }
 inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const {
     return type().addcmul(*this, tensor1, tensor2, value);
index 1976eeb..1e9edbd 100644 (file)
@@ -564,7 +564,7 @@ struct CAFFE2_API Type {
   virtual Tensor index_select(const Tensor & self, int64_t dim, const Tensor & index) const = 0;
   virtual Tensor masked_select(const Tensor & self, const Tensor & mask) const = 0;
   virtual Tensor nonzero(const Tensor & self) const = 0;
-  virtual Tensor gather(const Tensor & self, int64_t dim, const Tensor & index) const = 0;
+  virtual Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) const = 0;
   virtual Tensor addcmul(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
   virtual Tensor addcdiv(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
   virtual std::tuple<Tensor,Tensor> gels(const Tensor & self, const Tensor & A) const = 0;
index e00dd88..154d823 100644 (file)
@@ -551,4 +551,24 @@ Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & sour
   return _self.clone().masked_fill_(mask, source);
 }
 
+Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
+// special case scalar input and/or index
+    if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(at::empty({0,grad.numel()}, index.options()), grad, self.sizes());
+    if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe(index.view({1,1}), grad, self.sizes());
+    Tensor sparse_ind = at::empty({self.ndimension(), grad.numel()}, self.options().dtype(at::kLong));
+    int64_t n_above = grad.numel();
+    int64_t n_below = 1;
+    if (dim < 0) dim += self.ndimension();
+    for (int i=0; i<self.ndimension(); i++) {
+        n_above /= grad.size(i);
+        if (i == dim) {
+            sparse_ind[i] = index.reshape(-1);
+        } else {
+            sparse_ind[i] = at::arange(grad.size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand({grad.size(i), n_above}).reshape(-1).repeat(n_below);
+        }
+        n_below *= grad.size(i);
+    }
+    return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes());
+}
+
 }} // at::native
index 329680d..d19c31a 100644 (file)
@@ -396,11 +396,11 @@ Tensor nonzero(const Tensor & self) {
   return at::legacy::th::_th_nonzero(self);
 }
 
-Tensor & gather_out(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index) {
+Tensor & gather_out(Tensor & result, const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
   return at::legacy::th::_th_gather_out(result, self, dim, index);
 }
 
-Tensor gather(const Tensor & self, int64_t dim, const Tensor & index) {
+Tensor gather(const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
   return at::legacy::th::_th_gather(self, dim, index);
 }
 
index 6466d6a..5a7536c 100644 (file)
   matches_jit_signature: True
   variants: method, function
 
-- func: gather(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: gather(Tensor self, int dim, Tensor index) -> Tensor
+- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
   matches_jit_signature: True
   variants: method, function
 
+- func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor
+
 - func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
index f747acd..799c9c0 100644 (file)
@@ -1565,6 +1565,39 @@ class TestAutograd(TestCase):
 
             gradcheck(ctc_after_softmax, [x])
 
+    def _test_sparse_gather(self, size_x, size_ind, dim):
+        x = torch.randn(size_x, requires_grad=True)
+        if len(size_ind) > 0 and len(size_x) > 0:
+            ind = torch.randint(x.size(dim), size_ind)
+        else:
+            ind = torch.zeros(size_ind, dtype=torch.int64)
+        out = torch.gather(x, dim, ind, sparse_grad=False)
+        grad = torch.rand_like(out)
+        out.backward(grad)
+        grad_dense = x.grad.clone()
+        x.grad = None
+        out = torch.gather(x, dim, ind, sparse_grad=True)
+        out.backward(grad)
+        self.assertEqual(grad_dense, x.grad.to_dense())
+
+    def test_sparse_gather_dim0(self):
+        self._test_sparse_gather((10, 10), (5, 10), 0)
+
+    def test_sparse_gather_dim1(self):
+        self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1)
+
+    def test_sparse_gather_dim_neg(self):
+        self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1)
+
+    def test_sparse_gather_ind_scalar(self):
+        self._test_sparse_gather((10,), (), 0)
+
+    def test_sparse_gather_x_scalar(self):
+        self._test_sparse_gather((), (2,), 0)
+
+    def test_sparse_gather_both_scalar(self):
+        self._test_sparse_gather((), (), 0)
+
     def test_gc_in_destructor(self):
         """
         Previously, if a Function destructor triggered a garbage collection,
index 7912d66..ee05b63 100644 (file)
 - name: frac(Tensor self)
   self: grad
 
-- name: gather(Tensor self, int64_t dim, Tensor index)
-  self: at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)
+- name: gather(Tensor self, int64_t dim, Tensor index, bool sparse_grad)
+  self: "sparse_grad ? at::_gather_sparse_backward(self, dim, index, grad) : at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)"
 
 - name: ge_(Tensor self, Scalar other)
   self: zeros_like(self)
index d44ec5a..a87ec8b 100644 (file)
@@ -2075,6 +2075,7 @@ Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) {
   return grad.sparse_mask(at::SparseTensorRef(input));
 }
 
+
 // Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads])
 Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
   auto negated_pad = pad.vec();
index f5906c1..0b48e70 100644 (file)
@@ -1844,7 +1844,7 @@ Example::
 
 add_docstr(torch.gather,
            r"""
-gather(input, dim, index, out=None) -> Tensor
+gather(input, dim, index, out=None, sparse_grad=False) -> Tensor
 
 Gathers values along an axis specified by `dim`.
 
@@ -1865,6 +1865,7 @@ Args:
     dim (int): the axis along which to index
     index (LongTensor): the indices of elements to gather
     out (Tensor, optional): the destination tensor
+    sparse_grad(bool,optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.
 
 Example::
 
index d98ad88..5322bee 100644 (file)
@@ -22,10 +22,20 @@ void InputBuffer::add(size_t pos, Variable var) {
   } else {
     at::OptionalDeviceGuard device_guard(device_of(var));
     // ATen doesn't route sparse additions correctly...
+    // do dense + sparse in-place if possible
     if (old_var.is_sparse()) {
-      buffer[pos] = var + old_var;
+//storage use_count is a big hammer, but for anything lighter there's an adversarial example with unexpected inplace modification
+      if (!var.is_sparse() && var.is_contiguous() && var.storage().use_count() == 1) {
+          buffer[pos] = var.add_(old_var);
+      } else {
+          buffer[pos] = var + old_var;
+      }
     } else {
-      buffer[pos] = old_var + var;
+      if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && old_var.storage().use_count() == 1) {
+          buffer[pos] = old_var.add_(var);
+      } else {
+          buffer[pos] = old_var + var;
+      }
     }
   }
 }
index e4f6e59..1742128 100644 (file)
@@ -1214,7 +1214,7 @@ class ShapePropagator {
       }
     } else if (
         node->matches(
-            "aten::gather(Tensor self, int dim, Tensor index) -> Tensor")) {
+            "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
       auto type = input_type(0);
       auto index_type = input_type(1);
       // Gather has this annoying edge case where index always needs to match