SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
int64_t dims = self.dim();
- AT_CHECK(sparse_dim > 0, "sparse_dim must be >0");
+ // TODO: it seems like sparse_dim == 0 could be supported even if self.dim() > 0,
+ // but this would take some work and doesn't seem particularly useful.
+ AT_CHECK(sparse_dim > 0 || self.dim() == 0, "sparse_dim must be >0 if dimensionality > 0");
AT_CHECK(sparse_dim <= dims,
"sparse_dim must be less than or equal to self.dim()");
at::TensorOptions sparse_options = self.options().layout(kSparse);
std::vector<int64_t> sizes = self.sizes().vec();
Tensor nz = self.nonzero().transpose(0, 1);
- if (nz.numel() == 0) {
+ if (nz.size(1) == 0) {
return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, sparse_options);
}
LongTensor indices;
if (sparse_dim == dims) {
  indices = nz.clone();
} else {
  // keep only the leading sparse_dim index rows and merge duplicate columns
  Tensor i = nz.narrow(0, 0, sparse_dim);
  std::tie(indices, std::ignore) = _unique_dim(i, 1);
  indices = indices.contiguous(); // many sparse CUDA kernels require contiguity, see issue #12633
}
- std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
- Tensor values = self.index(ix).squeeze(0).clone();
+ Tensor values;
+ if (self.dim() > 0) {
+ std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
+ values = self.index(ix).squeeze(0).clone();
+ } else {
+ AT_ASSERT(nz.sizes().equals({0, 1}));
+ // In this case, indices is a clone of nz, which is a tensor of shape (0, 1).
+ // Given sparse tensor invariants, values should have shape (1,)
+ values = self.unsqueeze(0).clone();
+ }
Tensor sparse = at::sparse_coo_tensor(indices, values, sizes, sparse_options);
return sparse._coalesced_(true);
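
For reference, a rough Python sketch (not part of the patch) of what the scalar path above is meant to produce, based on the comments in dense_to_sparse and the tests further down: indices are a clone of nz with shape (0, 1), values have shape (1,), and the result round-trips through to_dense. The _indices()/_values() accessors are used here only for illustration.

import torch

x = torch.tensor(12.3)       # 0-dim dense tensor
s = x.to_sparse()            # hits the self.dim() == 0 branch above
print(s._indices().shape)    # expected: torch.Size([0, 1])
print(s._values().shape)     # expected: torch.Size([1])
print(s.to_dense())          # expected: tensor(12.3000), as the tests below assert
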
self = THCTensor_(newContiguous)(state, self);
thrust::device_ptr<scalar_t> self_data(THCTensor_(data)(state, self));
- int num_dim = THCTensor_(nDimensionLegacyNoScalars)(state, self);
+ int num_dim = THCTensor_(nDimension)(state, self);
+ int num_dim_noscalars = std::max<int>(1, num_dim);
int64_t N = THCTensor_(nElement)(state, self);
- THCudaLongTensor_resize2d(state, tensor, N, num_dim);
+ // This is a little awkward for scalars because we run thrust to count the number of
+ // non-zeros (which is necessary to get the correct output size), but thrust only has an
+ // array API, so we basically need to treat the scalar as a 1-dimensional tensor (array)
+ // for the counting part.
+ THCudaLongTensor_resize2d(state, tensor, N, num_dim_noscalars);
tensor = THCudaLongTensor_newContiguous(state, tensor);
thrust::device_ptr<int64_t> tensor_data(THCudaLongTensor_data(state, tensor));
typedef thrust::device_ptr<int64_t> Iter;
strided_range<Iter> strided_tensor(tensor_data,
- tensor_data+N*num_dim, num_dim);
+ tensor_data+N*num_dim_noscalars, num_dim_noscalars);
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
cudaStream_t stream = THCState_getCurrentStream(state);
int64_t num_nonzeros = thrust::distance(strided_tensor.begin(), dend);
- int64_t div = 1;
- for (int dim = num_dim-1; dim >= 0; dim--) {
- strided_range<Iter> stride_dim(tensor_data+dim,
- tensor_data+N*num_dim, num_dim);
- thrust::transform(
-#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
- thrust::cuda::par(thrustAlloc).on(stream),
-#endif
- strided_tensor.begin(),
- strided_tensor.end(),
- stride_dim.begin(),
- idx_functor(div, THTensor_sizeLegacyNoScalars(self, dim))
- );
- div *= THTensor_sizeLegacyNoScalars(self, dim);
+ if (num_nonzeros > 0 && num_dim > 0) {
+ int64_t div = 1;
+ for (int dim = num_dim-1; dim >= 0; dim--) {
+ strided_range<Iter> stride_dim(tensor_data+dim,
+ tensor_data+N*num_dim, num_dim);
+ thrust::transform(
+ #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
+ thrust::cuda::par(thrustAlloc).on(stream),
+ #endif
+ strided_tensor.begin(),
+ strided_tensor.end(),
+ stride_dim.begin(),
+ idx_functor(div, THTensor_(size)(self, dim))
+ );
+ div *= THTensor_(size)(self, dim);
+ }
}
THCudaLongTensor_resize2d(state, tensor, num_nonzeros, num_dim);
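
For intuition, a small Python sketch (not part of the patch) of the per-dimension rewrite the loop above performs, assuming idx_functor maps a linear offset to (offset / div) % size: scanning dimensions from last to first, each column of the index matrix becomes one coordinate while div accumulates the product of the trailing sizes, which is also why the loop can be skipped entirely when num_dim == 0.

def decompose(linear_offsets, sizes):
    # Rewrite each row-major linear offset into per-dimension coordinates,
    # mirroring the strided thrust::transform loop over dims.
    coords = [[0] * len(sizes) for _ in linear_offsets]
    div = 1
    for dim in range(len(sizes) - 1, -1, -1):
        for row, offset in enumerate(linear_offsets):
            coords[row][dim] = (offset // div) % sizes[dim]
        div *= sizes[dim]
    return coords

print(decompose([0, 4], [2, 3]))  # [[0, 0], [1, 1]]
print(decompose([0], []))         # [[]], a scalar has no coordinates to fill
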
a_coalesced = a.coalesce()
self.assertTrue(a_coalesced.is_coalesced())
self.assertEqual(self.value_tensor(12.3), a.to_dense())
+ self.assertEqual(a, a.to_dense().to_sparse())
# tensor with multiple values
a = self.sparse_tensor(self.index_tensor([]).unsqueeze(1).expand(0, 2), [12.3, 12.3], [])
a_coalesced = a.coalesce()
self.assertTrue(a_coalesced.is_coalesced())
self.assertEqual(self.value_tensor(12.3 * 2), a.to_dense())
+ self.assertEqual(a, a.to_dense().to_sparse())
# tensor without value
a = self.sparse_empty(())
a_coalesced = a.coalesce()
self.assertTrue(a_coalesced.is_coalesced())
self.assertEqual(self.value_tensor(0), a.to_dense())
+ self.assertEqual(a, a.to_dense().to_sparse())
def test_shared(self):
i = self.index_tensor([[2]])
self.assertEqual(0, y.numel())
self.assertEqual(torch.Size([0, 5]), y.shape)
+ x = torch.tensor(0.5, device=device)
+ y = torch.nonzero(x)
+ self.assertEqual(torch.Size([1, 0]), y.shape)
+
+ x = torch.zeros((), device=device)
+ y = torch.nonzero(x)
+ self.assertEqual(torch.Size([0, 0]), y.shape)
+
def test_deepcopy(self):
from copy import deepcopy
a = torch.randn(5, 5)