Fix nonzero for scalars on cuda, to_sparse for scalars on cpu/cuda. (#17406)
authorGregory Chanan <gchanan@fb.com>
Mon, 25 Feb 2019 16:08:15 +0000 (08:08 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Feb 2019 16:23:40 +0000 (08:23 -0800)
Summary:
I originally set out to fix to_sparse for scalars, which had some overly restrictive checking (sparse_dim > 0, which is impossible for a scalar).

This fix uncovered an issue with nonzero: it didn't properly return a size (z, 0) tensor for an input scalar, where z is the number of nonzero elements (i.e. 0 or 1).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17406

Differential Revision: D14185393

Pulled By: gchanan

fbshipit-source-id: f37a6e1e3773fd9cbf69eeca7fdebb3caa192a19

aten/src/ATen/native/sparse/SparseTensor.cpp
aten/src/THC/generic/THCTensorMath.cu
test/test_sparse.py
test/test_torch.py

index 2b35020..f89dce3 100644 (file)
@@ -284,14 +284,16 @@ SparseTensor dense_to_sparse(const Tensor& self){
 
 SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
   int64_t dims = self.dim();
-  AT_CHECK(sparse_dim > 0, "sparse_dim must be >0");
+  // TODO: it seems like sparse_dim == 0 could be supported even if self.dim() > 0,
+  // but this would take some work and doesn't seem particularly useful.
+  AT_CHECK(sparse_dim > 0 || self.dim() == 0, "sparse_dim must be >0 if dimensionality > 0");
   AT_CHECK(sparse_dim <= dims, 
     "sparse_dim must be less than or equal to self.dim()");
   at::TensorOptions sparse_options = self.options().layout(kSparse);
   std::vector<int64_t> sizes = self.sizes().vec();
 
   Tensor nz = self.nonzero().transpose(0, 1);
-  if (nz.numel() == 0) {
+  if (nz.size(1) == 0) {
     return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, sparse_options);
   }
   LongTensor indices;
@@ -303,8 +305,16 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
     indices = indices.contiguous();  // many sparse CUDA kernels require contiguity, see issue #12633
   }
 
-  std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
-  Tensor values = self.index(ix).squeeze(0).clone();
+  Tensor values;
+  if (self.dim() > 0) {
+    std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
+    values = self.index(ix).squeeze(0).clone();
+  } else {
+    AT_ASSERT(nz.sizes().equals({0, 1}));
+    // In this cases, indices is a clone of nz, which is a tensor of shape (0, 1).
+    // Given sparse tensor invariants, values should be shape (1,)
+    values = self.unsqueeze(0).clone();
+  }
 
   Tensor sparse = at::sparse_coo_tensor(indices, values, sizes, sparse_options);
   return sparse._coalesced_(true);
index ae7fe26..9976a4b 100644 (file)
@@ -268,10 +268,15 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
   self = THCTensor_(newContiguous)(state, self);
   thrust::device_ptr<scalar_t> self_data(THCTensor_(data)(state, self));
 
-  int num_dim = THCTensor_(nDimensionLegacyNoScalars)(state, self);
+  int num_dim = THCTensor_(nDimension)(state, self);
+  int num_dim_noscalars = std::max<int>(1, num_dim);
   int64_t N = THCTensor_(nElement)(state, self);
 
-  THCudaLongTensor_resize2d(state, tensor, N, num_dim);
+  // this is a little awkward for scalars because we run thrust to count the number of zeros
+  // (which are necessary to get the correct size), but thrust just has an array API, so
+  // we need to basically threat the scalar as a 1-dimensional tensor (array) for
+  // the counting part.
+  THCudaLongTensor_resize2d(state, tensor, N, num_dim_noscalars);
   tensor = THCudaLongTensor_newContiguous(state, tensor);
   thrust::device_ptr<int64_t> tensor_data(THCudaLongTensor_data(state, tensor));
 
@@ -280,7 +285,7 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
 
   typedef thrust::device_ptr<int64_t> Iter;
   strided_range<Iter> strided_tensor(tensor_data,
-                                     tensor_data+N*num_dim, num_dim);
+                                     tensor_data+N*num_dim_noscalars, num_dim_noscalars);
 
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
   cudaStream_t stream = THCState_getCurrentStream(state);
@@ -299,20 +304,22 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
 
   int64_t num_nonzeros = thrust::distance(strided_tensor.begin(), dend);
 
-  int64_t div = 1;
-  for (int dim = num_dim-1; dim >= 0; dim--) {
-    strided_range<Iter> stride_dim(tensor_data+dim,
-                                   tensor_data+N*num_dim, num_dim);
-    thrust::transform(
-#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
-      thrust::cuda::par(thrustAlloc).on(stream),
-#endif
-      strided_tensor.begin(),
-      strided_tensor.end(),
-      stride_dim.begin(),
-      idx_functor(div, THTensor_sizeLegacyNoScalars(self, dim))
-    );
-    div *= THTensor_sizeLegacyNoScalars(self, dim);
+  if (num_nonzeros > 0 && num_dim > 0) {
+    int64_t div = 1;
+    for (int dim = num_dim-1; dim >= 0; dim--) {
+      strided_range<Iter> stride_dim(tensor_data+dim,
+                                     tensor_data+N*num_dim, num_dim);
+      thrust::transform(
+  #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
+        thrust::cuda::par(thrustAlloc).on(stream),
+  #endif
+        strided_tensor.begin(),
+        strided_tensor.end(),
+        stride_dim.begin(),
+        idx_functor(div, THTensor_(size)(self, dim))
+      );
+      div *= THTensor_(size)(self, dim);
+    }
   }
 
   THCudaLongTensor_resize2d(state, tensor, num_nonzeros, num_dim);
index 49b93bf..ade1905 100644 (file)
@@ -289,6 +289,7 @@ class TestSparse(TestCase):
         a_coalesced = a.coalesce()
         self.assertTrue(a_coalesced.is_coalesced())
         self.assertEqual(self.value_tensor(12.3), a.to_dense())
+        self.assertEqual(a, a.to_dense().to_sparse())
 
         # tensor with multiple values
         a = self.sparse_tensor(self.index_tensor([]).unsqueeze(1).expand(0, 2), [12.3, 12.3], [])
@@ -297,6 +298,7 @@ class TestSparse(TestCase):
         a_coalesced = a.coalesce()
         self.assertTrue(a_coalesced.is_coalesced())
         self.assertEqual(self.value_tensor(12.3 * 2), a.to_dense())
+        self.assertEqual(a, a.to_dense().to_sparse())
 
         # tensor without value
         a = self.sparse_empty(())
@@ -305,6 +307,7 @@ class TestSparse(TestCase):
         a_coalesced = a.coalesce()
         self.assertTrue(a_coalesced.is_coalesced())
         self.assertEqual(self.value_tensor(0), a.to_dense())
+        self.assertEqual(a, a.to_dense().to_sparse())
 
     def test_shared(self):
         i = self.index_tensor([[2]])
index 058941d..6fc6c25 100644 (file)
@@ -8190,6 +8190,14 @@ class _TestTorchMixin(object):
             self.assertEqual(0, y.numel())
             self.assertEqual(torch.Size([0, 5]), y.shape)
 
+            x = torch.tensor(0.5, device=device)
+            y = torch.nonzero(x)
+            self.assertEqual(torch.Size([1, 0]), y.shape)
+
+            x = torch.zeros((), device=device)
+            y = torch.nonzero(x)
+            self.assertEqual(torch.Size([0, 0]), y.shape)
+
     def test_deepcopy(self):
         from copy import deepcopy
         a = torch.randn(5, 5)