From 15a55b86ed7407467d799e058883521a5ab7e7a4 Mon Sep 17 00:00:00 2001
From: Gregory Chanan
Date: Mon, 25 Feb 2019 08:08:15 -0800
Subject: [PATCH] Fix nonzero for scalars on cuda, to_sparse for scalars on cpu/cuda. (#17406)

Summary:
I originally set out to fix to_sparse for scalars, which had some overly
restrictive checking (sparse_dim > 0, which is impossible for a scalar).

This fix uncovered an issue with nonzero: it didn't properly return a
size (z, 0) tensor for an input scalar, where z is the number of nonzero
elements (i.e. 0 or 1).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17406

Differential Revision: D14185393

Pulled By: gchanan

fbshipit-source-id: f37a6e1e3773fd9cbf69eeca7fdebb3caa192a19
---
 aten/src/ATen/native/sparse/SparseTensor.cpp | 18 +++++++++---
 aten/src/THC/generic/THCTensorMath.cu        | 41 ++++++++++++++++------------
 test/test_sparse.py                          |  3 ++
 test/test_torch.py                           |  8 ++++++
 4 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp
index 2b35020..f89dce3 100644
--- a/aten/src/ATen/native/sparse/SparseTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -284,14 +284,16 @@ SparseTensor dense_to_sparse(const Tensor& self){
 
 SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
   int64_t dims = self.dim();
-  AT_CHECK(sparse_dim > 0, "sparse_dim must be >0");
+  // TODO: it seems like sparse_dim == 0 could be supported even if self.dim() > 0,
+  // but this would take some work and doesn't seem particularly useful.
+  AT_CHECK(sparse_dim > 0 || self.dim() == 0, "sparse_dim must be >0 if dimensionality > 0");
   AT_CHECK(sparse_dim <= dims, "sparse_dim must be less than or equal to self.dim()");
   at::TensorOptions sparse_options = self.options().layout(kSparse);
   std::vector<int64_t> sizes = self.sizes().vec();
 
   Tensor nz = self.nonzero().transpose(0, 1);
-  if (nz.numel() == 0) {
+  if (nz.size(1) == 0) {
     return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, sparse_options);
   }
   LongTensor indices;
 
@@ -303,8 +305,16 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
     indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
   }
 
-  std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
-  Tensor values = self.index(ix).squeeze(0).clone();
+  Tensor values;
+  if (self.dim() > 0) {
+    std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
+    values = self.index(ix).squeeze(0).clone();
+  } else {
+    AT_ASSERT(nz.sizes().equals({0, 1}));
+    // In this case, indices is a clone of nz, which is a tensor of shape (0, 1).
+    // Given sparse tensor invariants, values should be a tensor of shape (1,).
+    values = self.unsqueeze(0).clone();
+  }
 
   Tensor sparse = at::sparse_coo_tensor(indices, values, sizes, sparse_options);
   return sparse._coalesced_(true);
diff --git a/aten/src/THC/generic/THCTensorMath.cu b/aten/src/THC/generic/THCTensorMath.cu
index ae7fe26..9976a4b 100644
--- a/aten/src/THC/generic/THCTensorMath.cu
+++ b/aten/src/THC/generic/THCTensorMath.cu
@@ -268,10 +268,15 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
   self = THCTensor_(newContiguous)(state, self);
   thrust::device_ptr<scalar_t> self_data(THCTensor_(data)(state, self));
 
-  int num_dim = THCTensor_(nDimensionLegacyNoScalars)(state, self);
+  int num_dim = THCTensor_(nDimension)(state, self);
+  int num_dim_noscalars = std::max(1, num_dim);
   int64_t N = THCTensor_(nElement)(state, self);
 
-  THCudaLongTensor_resize2d(state, tensor, N, num_dim);
+  // this is a little awkward for scalars because we run thrust to count the number of non-zeros
+  // (which is necessary to get the correct size), but thrust just has an array API, so
+  // we need to basically treat the scalar as a 1-dimensional tensor (array) for
+  // the counting part.
+  THCudaLongTensor_resize2d(state, tensor, N, num_dim_noscalars);
   tensor = THCudaLongTensor_newContiguous(state, tensor);
   thrust::device_ptr<int64_t> tensor_data(THCudaLongTensor_data(state, tensor));
 
@@ -280,7 +285,7 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
 
   typedef thrust::device_ptr<int64_t> Iter;
   strided_range<Iter> strided_tensor(tensor_data,
-                                     tensor_data+N*num_dim, num_dim);
+                                     tensor_data+N*num_dim_noscalars, num_dim_noscalars);
 
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
   cudaStream_t stream = THCState_getCurrentStream(state);
@@ -299,20 +304,22 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
   int64_t num_nonzeros = thrust::distance(strided_tensor.begin(), dend);
 
-  int64_t div = 1;
-  for (int dim = num_dim-1; dim >= 0; dim--) {
-    strided_range<Iter> stride_dim(tensor_data+dim,
-                                   tensor_data+N*num_dim, num_dim);
-    thrust::transform(
-#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
-      thrust::cuda::par(thrustAlloc).on(stream),
-#endif
-      strided_tensor.begin(),
-      strided_tensor.end(),
-      stride_dim.begin(),
-      idx_functor(div, THTensor_sizeLegacyNoScalars(self, dim))
-    );
-    div *= THTensor_sizeLegacyNoScalars(self, dim);
+  if (num_nonzeros > 0 && num_dim > 0) {
+    int64_t div = 1;
+    for (int dim = num_dim-1; dim >= 0; dim--) {
+      strided_range<Iter> stride_dim(tensor_data+dim,
+                                     tensor_data+N*num_dim, num_dim);
+      thrust::transform(
+  #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
+        thrust::cuda::par(thrustAlloc).on(stream),
+  #endif
+        strided_tensor.begin(),
+        strided_tensor.end(),
+        stride_dim.begin(),
+        idx_functor(div, THTensor_(size)(self, dim))
+      );
+      div *= THTensor_(size)(self, dim);
+    }
   }
 
   THCudaLongTensor_resize2d(state, tensor, num_nonzeros, num_dim);
 
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 49b93bf..ade1905 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -289,6 +289,7 @@ class TestSparse(TestCase):
         a_coalesced = a.coalesce()
         self.assertTrue(a_coalesced.is_coalesced())
         self.assertEqual(self.value_tensor(12.3), a.to_dense())
+        self.assertEqual(a, a.to_dense().to_sparse())
 
         # tensor with multiple values
         a = self.sparse_tensor(self.index_tensor([]).unsqueeze(1).expand(0, 2), [12.3, 12.3], [])
@@ -297,6 +298,7 @@ class TestSparse(TestCase):
         a_coalesced = a.coalesce()
         self.assertTrue(a_coalesced.is_coalesced())
         self.assertEqual(self.value_tensor(12.3 * 2), a.to_dense())
+        self.assertEqual(a, a.to_dense().to_sparse())
 
         # tensor without value
         a = self.sparse_empty(())
@@ -305,6 +307,7 @@ class TestSparse(TestCase):
         a_coalesced = a.coalesce()
         self.assertTrue(a_coalesced.is_coalesced())
         self.assertEqual(self.value_tensor(0), a.to_dense())
+        self.assertEqual(a, a.to_dense().to_sparse())
 
     def test_shared(self):
         i = self.index_tensor([[2]])
diff --git a/test/test_torch.py b/test/test_torch.py
index 058941d..6fc6c25 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -8190,6 +8190,14 @@ class _TestTorchMixin(object):
             self.assertEqual(0, y.numel())
             self.assertEqual(torch.Size([0, 5]), y.shape)
 
+            x = torch.tensor(0.5, device=device)
+            y = torch.nonzero(x)
+            self.assertEqual(torch.Size([1, 0]), y.shape)
+
+            x = torch.zeros((), device=device)
+            y = torch.nonzero(x)
+            self.assertEqual(torch.Size([0, 0]), y.shape)
+
     def test_deepcopy(self):
         from copy import deepcopy
         a = torch.randn(5, 5)
-- 
2.7.4
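
A minimal sketch of the nonzero semantics the summary describes, mirroring
the new test_torch.py assertions (assumes a build that includes this fix).
For a 0-dim input, the result has shape (z, 0), where z is 0 or 1:

    import torch

    # One nonzero element in a 0-dim (scalar) tensor -> result of shape (1, 0).
    x = torch.tensor(0.5)
    assert torch.nonzero(x).shape == torch.Size([1, 0])

    # No nonzero elements in a 0-dim tensor -> result of shape (0, 0).
    assert torch.nonzero(torch.zeros(())).shape == torch.Size([0, 0])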
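Likewise for to_sparse on a scalar, which the dense_to_sparse change now
permits; the expected shapes follow the AT_ASSERT and comment in
SparseTensor.cpp, and the round-trip matches the new test_sparse.py
assertions. Again a sketch assuming a patched build:

    import torch

    x = torch.tensor(12.3)
    s = x.to_sparse()  # previously rejected by the sparse_dim > 0 check

    # Sparse invariants for a 0-dim tensor: indices has shape
    # (sparse_dim, nnz) = (0, 1); values holds the single element.
    assert s._indices().shape == torch.Size([0, 1])
    assert s._values().shape == torch.Size([1])
    assert s.to_dense() == x  # dense -> sparse -> dense round-trips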