From 7be05b822c01e8db5e23c8a213ba2e83d17b0bd2 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Fri, 22 Mar 2019 19:25:58 -0700
Subject: [PATCH] Fix incorrect sparse add behavior when the sparse tensor has non-contiguous values (#18179)

Summary:
Currently, this code gives an incorrect result:
```python
import torch

indices = torch.tensor([[7, 1, 3]])
values = torch.tensor([[1., 1., 1.],
                       [1., 1., 1.],
                       [1., 1., 1.]])
x = torch.sparse_coo_tensor(indices, values, size=(10, 3))

values = torch.tensor(1.).expand(3, 3)  # non-contiguous values
y = torch.sparse_coo_tensor(indices, values, size=(10, 3))

z = x + y
print(z)
# tensor(indices=tensor([[7, 1, 3]]),
#        values=tensor([[2., 1., 1.],
#                       [1., 1., 1.],
#                       [1., 1., 1.]]),
#        size=(10, 3), nnz=3, layout=torch.sparse_coo)
```
Every entry of `values` should be 2, since both operands contribute 1 at each index, but only the first element is accumulated.

This PR fixes the bug by adding special handling to the addition function for sparse tensors with non-contiguous values: instead of the BLAS-based accumulation, the indices and values of the two operands are concatenated, which yields an equivalent (uncoalesced) result.

This PR closes https://github.com/pytorch/pytorch/issues/17950 and https://github.com/pytorch/pytorch/issues/17919.
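
For reference, the concatenation strategy can be sketched in a few lines of Python. This snippet is illustrative only and is not part of the patch; the variable names are invented for the example. Concatenating the two operands' indices and values produces an uncoalesced tensor that represents the same values as `x + y`:
```python
import torch

indices = torch.tensor([[7, 1, 3]])
x = torch.sparse_coo_tensor(indices, torch.ones(3, 3), size=(10, 3))
y = torch.sparse_coo_tensor(indices, torch.tensor(1.).expand(3, 3), size=(10, 3))

# Concatenate indices along the sparse dimension and values along dim 0.
# The result has duplicate indices (i.e. it is uncoalesced), but it
# represents the same tensor as x + y.
cat_indices = torch.cat([x._indices(), y._indices()], dim=1)
cat_values = torch.cat([x._values(), y._values()], dim=0)
z = torch.sparse_coo_tensor(cat_indices, cat_values, size=(10, 3))

assert torch.equal(z.coalesce().to_dense(), x.to_dense() + y.to_dense())
```
This mirrors what the new `else` branch in `add_out_sparse_cpu` does with `at::cat` and `alias_into_sparse`, except that the C++ path also scales `src`'s values by the alpha scalar when it is not 1; the sum is simply left uncoalesced rather than merged entry by entry.
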
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18179

Reviewed By: ezyang

Differential Revision: D14569591

Pulled By: yf225

fbshipit-source-id: f5a14c4a31337fc95eab64596212066b4fb18b1a
---
 aten/src/ATen/native/sparse/SparseTensorMath.cpp | 138 +++++++++++++----------
 test/test_nn.py                                  |  19 ++++
 test/test_sparse.py                              |   9 ++
 3 files changed, 106 insertions(+), 60 deletions(-)

diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index 6e31b23..e7fc355 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -211,77 +211,95 @@ SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const S
   Tensor t_values = t._values();
   LongTensor src_indices = src._indices();
   Tensor s_values = src._values();
-  LongTensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
-  Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
   r.resize_as_(src);
-  get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
-  int64_t blockSize = r_values.stride(0);
-  int64_t cmp, d;
-  int64_t r_i = 0, t_i = 0, s_i = 0;
+  if (s_values.is_contiguous() && t_values.is_contiguous()) {
+    LongTensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
+    Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
+    get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);

-  // NB: relies on nnz tests above
-  auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
-  auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
-  auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
+    int64_t blockSize = r_values.stride(0);
+    int64_t cmp, d;
+    int64_t r_i = 0, t_i = 0, s_i = 0;

-  AT_DISPATCH_ALL_TYPES(
-      t_values.scalar_type(), "cadd_sparse", [&] {
-        scalar_t* t_values_ptr = t_values.data<scalar_t>();
-        scalar_t* s_values_ptr = s_values.data<scalar_t>();
-        scalar_t* r_values_ptr = r_values.data<scalar_t>();
-        scalar_t cast_value = value.to<scalar_t>();
-        while (t_i < t_nnz || s_i < s_nnz) {
-          if (t_i >= t_nnz) {
-            cmp = -1;
-          } else if (s_i >= s_nnz) {
-            cmp = 1;
-          } else {
-            cmp = 0;
-            for (d = 0; d < sparse_dim; d++) {
-              if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
-                cmp = 1;
-                break;
-              }
-              if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
-                cmp = -1;
-                break;
+    // NB: relies on nnz tests above
+    auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
+    auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
+    auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
+
+    AT_DISPATCH_ALL_TYPES(
+        t_values.scalar_type(), "cadd_sparse", [&] {
+          scalar_t* t_values_ptr = t_values.data<scalar_t>();
+          scalar_t* s_values_ptr = s_values.data<scalar_t>();
+          scalar_t* r_values_ptr = r_values.data<scalar_t>();
+          scalar_t cast_value = value.to<scalar_t>();
+          while (t_i < t_nnz || s_i < s_nnz) {
+            if (t_i >= t_nnz) {
+              cmp = -1;
+            } else if (s_i >= s_nnz) {
+              cmp = 1;
+            } else {
+              cmp = 0;
+              for (d = 0; d < sparse_dim; d++) {
+                if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
+                  cmp = 1;
+                  break;
+                }
+                if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
+                  cmp = -1;
+                  break;
+                }
               }
             }
-          }
-          if (cmp >= 0) {
-            for (d = 0; d < sparse_dim; d++) {
-              r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
-            }
-            if (t_values.numel() > 0) {  // We add all elements from t_values to r_values only if t_values is not an empty tensor
-              THBlas_axpy<scalar_t>(blockSize, 1,
-                t_values_ptr + t_i * blockSize, 1,
-                r_values_ptr + r_i * blockSize, 1);
-            }
-            t_i++;
-          }
-          if (cmp <= 0) {
-            for (d = 0; d < sparse_dim; d++) {
-              r_indices_accessor[d][r_i] = src_indices_accessor[d][s_i];
+            if (cmp >= 0) {
+              for (d = 0; d < sparse_dim; d++) {
+                r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
+              }
+              if (t_values.numel() > 0) {  // We add all elements from t_values to r_values only if t_values is not an empty tensor
+                THBlas_axpy<scalar_t>(blockSize, 1,
+                  t_values_ptr + t_i * blockSize, 1,
+                  r_values_ptr + r_i * blockSize, 1);
+              }
+              t_i++;
             }
-            if (s_values.numel() > 0) {  // We add all elements from s_values to r_values only if s_values is not an empty tensor
-              THBlas_axpy<scalar_t>(blockSize, cast_value,
-                s_values_ptr + s_i * blockSize, 1,
-                r_values_ptr + r_i * blockSize, 1);
+            if (cmp <= 0) {
+              for (d = 0; d < sparse_dim; d++) {
+                r_indices_accessor[d][r_i] = src_indices_accessor[d][s_i];
+              }
+              if (s_values.numel() > 0) {  // We add all elements from s_values to r_values only if s_values is not an empty tensor
+                THBlas_axpy<scalar_t>(blockSize, cast_value,
+                  s_values_ptr + s_i * blockSize, 1,
+                  r_values_ptr + r_i * blockSize, 1);
+              }
+              s_i++;
             }
-            s_i++;
+            r_i++;
           }
-          r_i++;
         }
-      }
-  );
+    );

-  get_sparse_impl(r)->set_nnz_and_narrow(r_i);
-  // TODO: I think it may be possible to track inside the loop and
-  // detect when we are uncoalesced (e.g., by observing that an
-  // index goes backwards) which may be more precise than using the
-  // coalesced flag here. But this is easy.
-  return r._coalesced_(t_coalesced && s_coalesced);
+    get_sparse_impl(r)->set_nnz_and_narrow(r_i);
+    // TODO: I think it may be possible to track inside the loop and
+    // detect when we are uncoalesced (e.g., by observing that an
+    // index goes backwards) which may be more precise than using the
+    // coalesced flag here. But this is easy.
+    return r._coalesced_(t_coalesced && s_coalesced);
+  } else {
+    // If `t` or `src` contains non-contiguous `values`, `THBlas_axpy` doesn't work
+    // and we concat the indices and values tensors instead.
+    AT_DISPATCH_ALL_TYPES(
+      s_values.scalar_type(), "add_out_sparse_cuda", [&] {
+          if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
+            s_values = s_values.mul(value);
+          }
+        });
+
+    LongTensor r_indices = at::cat({t_indices, src_indices}, 1);
+    Tensor r_values = at::cat({t_values, s_values}, 0);
+    alias_into_sparse(r, r_indices, r_values);
+
+    return r;
+  }
 }

 // --------------------------------------------------------------------
diff --git a/test/test_nn.py b/test/test_nn.py
index d8651ab..897255c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2052,6 +2052,25 @@ class TestNN(NNTestCase):
         self.assertTrue(embedding.weight.grad.is_sparse)
         self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape)

+    def test_embedding_sparse_backward(self):
+        embedding = nn.Embedding(10, 3, sparse=True)
+        embedding.zero_grad()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3]]))
+        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(3, 3))
+
+        embedding.zero_grad()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 7, 1, 3]]))
+        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+
+        embedding.zero_grad()
+        embedding(torch.LongTensor([7, 1, 3])).sum().backward()
+        embedding(torch.LongTensor([8, 1, 3])).sum().backward()
+        self.assertEqual(embedding.weight.grad._indices(), torch.LongTensor([[7, 1, 3, 8, 1, 3]]))
+        self.assertEqual(embedding.weight.grad._values(), torch.tensor(1.).expand(6, 3))
+
     def test_embedding_padding_idx(self):
         embedding = nn.Embedding(10, 20, padding_idx=0)
         input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]))
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 6f7a10b..fd1db27 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -1167,6 +1167,15 @@ class TestSparse(TestCase):
         test_shape([3, 4], [1, 4], [4, 4, 4], [3, 4, 4])
         test_shape([3, 4, 0], [1, 4], [4, 4, 4, 0], [3, 4, 4, 0])

+    def test_add_noncontiguous(self):
+        indices = self.index_tensor([[1, 2], [0, 2]])
+        values = self.value_tensor([1.]).expand(2, 3, 4, 5)
+        x = self.sparse_tensor(indices, values)
+        assert not x._values().is_contiguous()
+        y = x + x
+        expected = self.safeToDense(x) + self.safeToDense(x)
+        self.assertEqual(self.safeToDense(y), expected)
+
     def _test_sparse_mask_shape(self, nnz_x1, nnz_x2, shape_i, shape_v=None):
         shape = shape_i + (shape_v or [])
         x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape)
--
2.7.4