Fix incorrect sparse add behavior when the sparse tensor has non-contiguous values...
authorWill Feng <willfeng@fb.com>
Sat, 23 Mar 2019 02:25:58 +0000 (19:25 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 23 Mar 2019 02:35:14 +0000 (19:35 -0700)
Summary:
Currently, this code gives incorrect result:
```python
import torch
indices=torch.tensor([[7, 1, 3]])
values=torch.tensor([[1., 1., 1.],
               [1., 1., 1.],
               [1., 1., 1.]])
x = torch.sparse_coo_tensor(indices, values, size=(10, 3))
values=torch.tensor(1.).expand(3, 3)
y = torch.sparse_coo_tensor(indices, values, size=(10, 3))
z = x + y

tensor(indices=tensor([[7, 1, 3]]),
       values=tensor([[2., 1., 1.],
                      [1., 1., 1.],
                      [1., 1., 1.]]),
       size=(10, 3), nnz=3, layout=torch.sparse_coo)
```

This PR fixes the bug by adding special handling for sparse tensors with non-contiguous values in the addition function (specifically, by cat'ing the indices and values together).

This PR closes https://github.com/pytorch/pytorch/issues/17950 and https://github.com/pytorch/pytorch/issues/17919.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18179

Reviewed By: ezyang

Differential Revision: D14569591

Pulled By: yf225

fbshipit-source-id: f5a14c4a31337fc95eab64596212066b4fb18b1a

aten/src/ATen/native/sparse/SparseTensorMath.cpp
test/test_nn.py
test/test_sparse.py

index 6e31b23..e7fc355 100644 (file)
@@ -211,77 +211,95 @@ SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const S
   Tensor t_values = t._values();
   LongTensor src_indices = src._indices();
   Tensor s_values = src._values();
-  LongTensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
-  Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
   r.resize_as_(src);
-  get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
 
-  int64_t blockSize = r_values.stride(0);
-  int64_t cmp, d;
-  int64_t r_i = 0, t_i = 0, s_i = 0;
+  if (s_values.is_contiguous() && t_values.is_contiguous()) {
+    LongTensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
+    Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
+    get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
 
-  // NB: relies on nnz tests above
-  auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
-  auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
-  auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
+    int64_t blockSize = r_values.stride(0);
+    int64_t cmp, d;
+    int64_t r_i = 0, t_i = 0, s_i = 0;
 
-  AT_DISPATCH_ALL_TYPES(
-      t_values.scalar_type(), "cadd_sparse", [&] {
-        scalar_t* t_values_ptr = t_values.data<scalar_t>();
-        scalar_t* s_values_ptr = s_values.data<scalar_t>();
-        scalar_t* r_values_ptr = r_values.data<scalar_t>();
-        scalar_t cast_value = value.to<scalar_t>();
-        while (t_i < t_nnz || s_i < s_nnz) {
-          if (t_i >= t_nnz) {
-            cmp = -1;
-          } else if (s_i >= s_nnz) {
-            cmp = 1;
-          } else {
-            cmp = 0;
-            for (d = 0; d < sparse_dim; d++) {
-              if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
-                cmp = 1;
-                break;
-              }
-              if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
-                cmp = -1;
-                break;
+    // NB: relies on nnz tests above
+    auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
+    auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
+    auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
+
+    AT_DISPATCH_ALL_TYPES(
+        t_values.scalar_type(), "cadd_sparse", [&] {
+          scalar_t* t_values_ptr = t_values.data<scalar_t>();
+          scalar_t* s_values_ptr = s_values.data<scalar_t>();
+          scalar_t* r_values_ptr = r_values.data<scalar_t>();
+          scalar_t cast_value = value.to<scalar_t>();
+          while (t_i < t_nnz || s_i < s_nnz) {
+            if (t_i >= t_nnz) {
+              cmp = -1;
+            } else if (s_i >= s_nnz) {
+              cmp = 1;
+            } else {
+              cmp = 0;
+              for (d = 0; d < sparse_dim; d++) {
+                if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
+                  cmp = 1;
+                  break;
+                }
+                if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
+                  cmp = -1;
+                  break;
+                }
               }
             }
-          }
-          if (cmp >= 0) {
-            for (d = 0; d < sparse_dim; d++) {
-              r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
-            }
-            if (t_values.numel() > 0) {  // We add all elements from t_values to r_values only if t_values is not an empty tensor
-              THBlas_axpy<scalar_t>(blockSize, 1,
-                t_values_ptr + t_i * blockSize, 1,
-                r_values_ptr + r_i * blockSize, 1);
-            }
-            t_i++;
-          }
-          if (cmp <= 0) {
-            for (d = 0; d < sparse_dim; d++) {
-              r_indices_accessor[d][r_i] = src_indices_accessor[d][s_i];
+            if (cmp >= 0) {
+              for (d = 0; d < sparse_dim; d++) {
+                r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
+              }
+              if (t_values.numel() > 0) {  // We add all elements from t_values to r_values only if t_values is not an empty tensor
+                THBlas_axpy<scalar_t>(blockSize, 1,
+                  t_values_ptr + t_i * blockSize, 1,
+                  r_values_ptr + r_i * blockSize, 1);
+              }
+              t_i++;
             }
-            if (s_values.numel() > 0) {  // We add all elements from s_values to r_values only if s_values is not an empty tensor
-              THBlas_axpy<scalar_t>(blockSize, cast_value,
-                s_values_ptr + s_i * blockSize, 1,
-                r_values_ptr + r_i * blockSize, 1);
+            if (cmp <= 0) {
+              for (d = 0; d < sparse_dim; d++) {
+                r_indices_accessor[d][r_i] = src_indices_accessor[d][s_i];
+              }
+              if (s_values.numel() > 0) {  // We add all elements from s_values to r_values only if s_values is not an empty tensor
+                THBlas_axpy<scalar_t>(blockSize, cast_value,
+                  s_values_ptr + s_i * blockSize, 1,
+                  r_values_ptr + r_i * blockSize, 1);
+              }
+              s_i++;
             }
-            s_i++;
+            r_i++;
           }
-          r_i++;
         }
-      }
-  );
+    );
 
-  get_sparse_impl(r)->set_nnz_and_narrow(r_i);
-  // TODO: I think it may be possible to track inside the loop and
-  // detect when we are uncoalesced (e.g., by observing that an
-  // index goes backwards) which may be more precise than using the
-  // coalesced flag here.  But this is easy.
-  return r._coalesced_(t_coalesced && s_coalesced);
+    get_sparse_impl(r)->set_nnz_and_narrow(r_i);
+    // TODO: I think it may be possible to track inside the loop and
+    // detect when we are uncoalesced (e.g., by observing that an
+    // index goes backwards) which may be more precise than using the
+    // coalesced flag here.  But this is easy.
+    return r._coalesced_(t_coalesced && s_coalesced);
+  } else {
+    // If `t` or `src` contains non-contiguous `values`, `THBlas_axpy` doesn't work
+    // and we concat the indices and values tensors instead.
+    AT_DISPATCH_ALL_TYPES(
+      s_values.scalar_type(), "add_out_sparse_cuda", [&] {
+          if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
+            s_values = s_values.mul(value);
+          }
+        });
+
+    LongTensor r_indices = at::cat({t_indices, src_indices}, 1);
+    Tensor r_values = at::cat({t_values, s_values}, 0);
+    alias_into_sparse(r, r_indices, r_values);
+
+    return r;
+  }
 }
 
 // --------------------------------------------------------------------
index d8651ab..897255c 100644 (file)
@@ -2052,6 +2052,25 @@ class TestNN(NNTestCase):
         self.assertTrue(embedding.weight.grad.is_sparse)
         self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)
 
+    def test_embedding_sparse_backward(self):
+        embedding = nn.Embedding(10, 3, sparse=True)
+        embedding.zero_grad()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3]]))
+        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(3, 3))
+
+        embedding.zero_grad()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 7, 1, 3]]))
+        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+
+        embedding.zero_grad()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        embedding(torch.LongTensor([8, 1, 3])).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]]))
+        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+
     def test_embedding_padding_idx(self):
         embedding = nn.Embedding(10, 20, padding_idx=0)
         input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]))
index 6f7a10b..fd1db27 100644 (file)
@@ -1167,6 +1167,15 @@ class TestSparse(TestCase):
         test_shape([3, 4], [1, 4], [4, 4, 4], [3, 4, 4])
         test_shape([3, 4, 0], [1, 4], [4, 4, 4, 0], [3, 4, 4, 0])
 
+    def test_add_noncontiguous(self):
+        indices = self.index_tensor([[1, 2], [0, 2]])
+        values = self.value_tensor([1.]).expand(2, 3, 4, 5)
+        x = self.sparse_tensor(indices, values)
+        assert not x._values().is_contiguous()
+        y = x + x
+        expected = self.safeToDense(x) + self.safeToDense(x)
+        self.assertEqual(self.safeToDense(y), expected)
+
     def _test_sparse_mask_shape(self, nnz_x1, nnz_x2, shape_i, shape_v=None):
         shape = shape_i + (shape_v or [])
         x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape)