don't attempt to multiply by a sparse matrix (#18737)
author Brennan Vincent <brennan@umanwizard.com>
Fri, 5 Apr 2019 00:18:11 +0000 (17:18 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 5 Apr 2019 00:24:53 +0000 (17:24 -0700)
Summary:
Tested by running the script in #16562, and there was no error.

Then, inspecting the gradient of `mat`:
```
>>> print(mat.grad)
tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]])
```

which is correct.
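For reference, here is a minimal sketch of the failing pattern (my reconstruction, not the exact script — that lives in #16562): multiplying a sparse matrix by a column-major dense matrix that requires grad used to raise an error in backward, because the gradient formula attempted mm(dense, sparse).

```
import torch

# Hypothetical reconstruction of the repro from #16562 (assumption: the
# exact script is in that issue). The dense factor is made column-major
# via an in-place transpose before enabling autograd.
S = torch.eye(3).to_sparse()
mat = torch.randn(3, 3).t_().requires_grad_(True)  # column-major dense factor

out = torch.sparse.mm(S, mat)
out.sum().backward()

# Before this patch, backward() raised an error here, because computing
# mat's gradient in column-major order attempted mm(dense, sparse).
print(mat.grad)
```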
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18737

Differential Revision: D14773078

Pulled By: umanwizard

fbshipit-source-id: 8aa36eb6f6aa104263a467d9ac91d61b3bfd05f5

test/test_sparse.py
tools/autograd/templates/Functions.cpp

diff --git a/test/test_sparse.py b/test/test_sparse.py
index e2a5dfd..5acf650 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -863,8 +863,12 @@ class TestSparse(TestCase):
         test_shape(7, 8, 9, 20, True)
 
     def test_sparse_mm(self):
-        def test_shape(d1, d2, d3, nnz):
-            D = torch.randn(d2, d3, device=self.device).requires_grad_(True)
+        def test_shape(d1, d2, d3, nnz, transposed):
+            if transposed:
+                D = torch.randn(d3, d2,
+                                device=self.device).t_().requires_grad_(True)
+            else:
+                D = torch.randn(d2, d3, device=self.device).requires_grad_(True)
             S = self._gen_sparse(2, nnz, [d1, d2])[0]
             S_dense = S.to_dense().requires_grad_(True)
             S.requires_grad_(True)
@@ -874,7 +878,8 @@ class TestSparse(TestCase):
                 return torch.sparse.mm(S, D)
             gradcheck(fn, (S, D), check_sparse_nnz=True)
 
-        test_shape(7, 8, 9, 20)
+        test_shape(7, 8, 9, 20, False)
+        test_shape(7, 8, 9, 20, True)
 
     @skipIfRocm
     def test_dsmm(self):
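The new `transposed` flag exercises the layout that used to break. As a quick illustration (my own sketch, not part of the patch), `randn(d3, d2).t_()` produces exactly the stride pattern that the column-major branch in `mm_mat2_backward` checks for:

```
import torch

# Illustration (not part of the patch): an in-place transpose of a fresh
# (d3, d2) tensor yields a (d2, d3) tensor in column-major layout, i.e.
# stride(0) == 1 and stride(1) == size(0) -- the condition checked by the
# column-major branch in mm_mat2_backward below.
D = torch.randn(9, 8).t_()      # shape (8, 9), column-major
print(D.shape, D.stride())      # torch.Size([8, 9]) (1, 8)
assert D.stride() == (1, D.size(0))
```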
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 68012b0..ba3ccbb 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -524,6 +524,17 @@ Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor &
 Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef sizes, IntArrayRef strides, const Scalar & alpha) {
   // if input was column-major, return grad as column-order for efficiency
   if (strides[0] == 1 && strides[1] == sizes[0]) {
+    if (mat1.is_sparse()) {
+      // Since mm(dense, sparse) doesn't exist,
+      // pass a transposed output matrix to the underlying "addmm"
+      // function directly.
+      int64_t out_rows = mat1.size(1);
+      int64_t out_cols = grad.size(1);
+      Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true);
+      Tensor r = at::empty({out_cols, out_rows}, grad.options()).t();
+      at::s_native_addmm_out(r, t, mat1.t(), grad, alpha, 1);
+      return r;
+    }
     return maybe_multiply(grad.t().mm(mat1).t(), alpha);
   } else {
     return maybe_multiply(mat1.t().mm(grad), alpha);
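In plain terms, the new branch computes the dense factor's gradient as mat1ᵀ · grad, but materializes it in column-major storage (via the transposed `empty` buffer and `s_native_addmm_out` with a zero-expanded add term) so the incoming matrix's layout is preserved. A rough Python analogue (my sketch, not the actual C++ path; it assumes only documented behavior of `torch.sparse.mm` and `Tensor.copy_`):

```
import torch

# Sketch of what the new is_sparse() branch computes. For C = S @ D with
# S sparse and D dense column-major, the gradient w.r.t. D is S.t() @ grad,
# written into column-major storage.
S = torch.randn(7, 8).to_sparse()        # mat1: sparse (d1, d2)
grad = torch.randn(7, 9)                 # incoming gradient dL/dC: (d1, d3)

out_rows, out_cols = S.size(1), grad.size(1)
r = torch.empty(out_cols, out_rows).t()  # column-major (out_rows, out_cols) buffer
r.copy_(torch.sparse.mm(S.t(), grad))    # dL/dD = S.t() @ grad
assert r.stride() == (1, out_rows)       # result keeps the column-major layout
```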