sparse.mm(), reland #14526 (#14661)
authorWei Yang <weiyang@fb.com>
Mon, 3 Dec 2018 18:26:02 +0000 (10:26 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 18:39:27 +0000 (10:39 -0800)
Summary:
- reland reverted PR #14526 with doc fixes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14661

Differential Revision: D13289047

Pulled By: weiyangfb

fbshipit-source-id: 5b843a11a58b56aeada3af2680a27cf89ecef4d8

aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/sparse/SparseTensorMath.cpp
docs/source/sparse.rst
test/test_sparse.py
tools/autograd/templates/Functions.cpp
torch/sparse/__init__.py

index e428f51..f938473 100644 (file)
 
 - func: mm_out(Tensor result, Tensor self, Tensor mat2) -> Tensor
 
+- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
+
 - func: mode(Tensor self, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor)
   variants: function, method
 
index 414e4d4..1737333 100644 (file)
@@ -593,10 +593,16 @@ Tensor _sparse_addmm(
   Scalar beta,
   Scalar alpha
 ) {
-  AT_CHECK(sparse.is_coalesced(), "_sparse_addmm doesn't support uncoalesced SparseTensor");
   return at::s_native_addmm(t, sparse, dense, beta, alpha);
 }
 
+Tensor _sparse_mm(
+  const SparseTensor& sparse,
+  const Tensor& dense
+) {
+  Tensor t = at::empty({sparse.size(0), dense.size(1)}, dense.options());
+  return at::_sparse_addmm(t, sparse, dense, 0, 1);
+}
 
 // --------------------------------------------------------------------
 // hspmm(SparseTensor mat1, Tensor mat2)
index 1e6afde..b746af7 100644 (file)
@@ -141,4 +141,5 @@ Functions
 ----------------------------------
 
 .. autofunction:: torch.sparse.addmm
+.. autofunction:: torch.sparse.mm
 .. autofunction:: torch.sparse.sum
index 60fc784..5011f5c 100644 (file)
@@ -818,10 +818,28 @@ class TestSparse(TestCase):
             y1.backward()
             y2.backward()
             mask = (S_dense == 0)
+            self.assertTrue(S.grad.is_coalesced())
             self.assertEqual(S.grad.to_dense(), S_dense.grad.masked_fill_(mask, 0))
 
-        if not self.is_uncoalesced:
-            test_shape(7, 8, 9, 20)
+        test_shape(7, 8, 9, 20)
+
+    @skipIfRocm
+    def test_sparse_mm(self):
+        def test_shape(d1, d2, d3, nnz):
+            D = torch.randn(d2, d3, device=self.device).requires_grad_(True)
+            S = self._gen_sparse(2, nnz, [d1, d2])[0]
+            S_dense = S.to_dense().requires_grad_(True)
+            S.requires_grad_(True)
+            self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))
+            y1 = torch.sparse.mm(S, D).sum()
+            y2 = torch.mm(S_dense, D).sum()
+            y1.backward()
+            y2.backward()
+            mask = (S_dense == 0)
+            self.assertTrue(S.grad.is_coalesced())
+            self.assertEqual(S.grad.to_dense(), S_dense.grad.masked_fill_(mask, 0))
+
+        test_shape(7, 8, 9, 20)
 
     @skipIfRocm
     def test_dsmm(self):
index 137f31d..577bc0c 100644 (file)
@@ -511,8 +511,9 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntList sizes,
   }
 }
 
-Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse, const Tensor& dense, const Scalar& alpha) {
-  AT_ASSERT(sparse.is_sparse());
+Tensor _sparse_addmm_sparse_backward(const Tensor& grad, const Tensor& sparse_, const Tensor& dense, const Scalar& alpha) {
+  AT_ASSERT(sparse_.is_sparse());
+  auto sparse = sparse_.coalesce();
   Tensor grad_sparse = maybe_multiply(grad.mm(dense.t()), alpha);
   return grad_sparse.sparse_mask(at::SparseTensorRef(sparse));
 }
index 07553e9..5edba10 100644 (file)
@@ -3,6 +3,7 @@ import torch
 
 __all__ = [
     'addmm',
+    'mm',
     'sum',
 ]
 
@@ -10,11 +11,13 @@ __all__ = [
 def addmm(mat, mat1, mat2, beta=1, alpha=1):
     r"""
     This function does exact same thing as :func:`torch.addmm` in the forward,
-    except that it supports backward for coalesced sparse matrix `mat1`.
+    except that it supports backward for sparse matrix :attr:`mat1`. :attr:`mat1`
+    need to have `sparse_dim = 2`. Note that the gradients of :attr:`mat1` is a
+    coalesced sparse tensor.
 
     Args:
         mat (Tensor): a dense matrix to be added
-        mat1 (Tensor): a sparse matrix to be multiplied
+        mat1 (SparseTensor): a sparse matrix to be multiplied
         mat2 (Tensor): a dense matrix be multiplied
         beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
         alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
@@ -22,6 +25,48 @@ def addmm(mat, mat1, mat2, beta=1, alpha=1):
     return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha)
 
 
+def mm(mat1, mat2):
+    r"""
+    Performs a matrix multiplication of the sparse matrix :attr:`mat1`
+    and dense matrix :attr:`mat2`. Similar to :func:`torch.mm`, If :attr:`mat1` is a
+    :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
+    :math:`(n \times p)` dense tensor. :attr:`mat1` need to have `sparse_dim = 2`.
+    This function also supports backward for both matrices. Note that the gradients of
+    :attr:`mat1` is a coalesced sparse tensor.
+
+    Args:
+        mat1 (SparseTensor): the first sparse matrix to be multiplied
+        mat2 (Tensor): the second dense matrix to be multiplied
+
+    Example::
+
+        >>> a = torch.randn(2, 3).to_sparse().requires_grad_(True)
+        >>> a
+        tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
+                               [0, 1, 2, 0, 1, 2]]),
+               values=tensor([ 1.5901,  0.0183, -0.6146,  1.8061, -0.0112,  0.6302]),
+               size=(2, 3), nnz=6, layout=torch.sparse_coo, requires_grad=True)
+
+        >>> b = torch.randn(3, 2, requires_grad=True)
+        >>> b
+        tensor([[-0.6479,  0.7874],
+                [-1.2056,  0.5641],
+                [-1.1716, -0.9923]], requires_grad=True)
+
+        >>> y = torch.sparse.mm(a, b)
+        >>> y
+        tensor([[-0.3323,  1.8723],
+                [-1.8951,  0.7904]], grad_fn=<SparseAddmmBackward>)
+        >>> y.sum().backward()
+        >>> a.grad
+        tensor(indices=tensor([[0, 0, 0, 1, 1, 1],
+                               [0, 1, 2, 0, 1, 2]]),
+               values=tensor([ 0.1394, -0.6415, -2.1639,  0.1394, -0.6415, -2.1639]),
+               size=(2, 3), nnz=6, layout=torch.sparse_coo)
+    """
+    return torch._sparse_mm(mat1, mat2)
+
+
 def sum(input, dim=None, dtype=None):
     r"""
     Returns the sum of each row of SparseTensor :attr:`input` in the given