From be7c618fd7fadb9db58aae8bef16638912748cfa Mon Sep 17 00:00:00 2001
From: Wei Yang
Date: Wed, 28 Nov 2018 02:16:56 -0800
Subject: [PATCH] torch.sparse.sum() (#12430)
MIME-Version: 1.0
Content-Type: text/plain; charset=utf8
Content-Transfer-Encoding: 8bit

Summary:
- to fix #12241
- add `_sparse_sum()` to ATen and expose it as `torch.sparse.sum()`; `SparseTensor.sum()` is not supported for now
- this PR depends on #11253 and will need to be updated once it lands
- [x] implement forward
- [x] implement backward
- performance [benchmark script](https://gist.github.com/weiyangfb/f4c55c88b6092ef8f7e348f6b9ad8946#file-sparse_sum_benchmark-py):
  - summing all dims is the fastest path for a sparse tensor
  - when the input is sparse enough (nnz = 0.1%), summing a sparse tensor is faster than a dense one on CPU, but not necessarily on CUDA
  - the CUDA backward is comparable (<2x) between `sum several dims` and `sum all dims` in sparse
  - the CPU backward, which uses binary search, is still slow in sparse: `sum [0, 2, 3] dims` takes `5x` the time of `sum all dims`
    - optimize the CUDA backward for now: using thrust for sort and binary search, but the runtime did not improve
  - both the CPU and CUDA forward are slow in sparse (`sum several dims` vs `sum all dims`), at most `20x` slower on CPU and `10x` on CUDA
    - improve the CPU and CUDA forward kernels

(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA (sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 8.77 µs vs 72.9 µs | 42.5 µs vs 108 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 112 µs vs 4.47 ms | 484 µs vs 407 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 141 µs vs 148 µs | 647 µs vs 231 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 235 µs vs 1.23 ms | 781 µs vs 213 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 48.5 µs vs 360 µs | 160 µs vs 2.03 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 258 µs vs 1.22 ms | 798 µs vs 224 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 204 µs vs 882 µs | 443 µs vs 133 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 709 µs vs 1.15 ms | 893 µs vs 202 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 39.8 µs vs 81 µs | 42.4 µs vs 113 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 747 µs vs 4.7 ms | 2.4 ms vs 414 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 1.04 ms vs 126 µs | 5.03 ms vs 231 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.12 ms vs 1.24 ms | 5.99 ms vs 213 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 133 µs vs 366 µs | 463 µs vs 2.03 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.56 ms vs 1.22 ms | 6.11 ms vs 229 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.53 ms vs 799 µs | 824 µs vs 134 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 5.15 ms vs 1.09 ms | 7.02 ms vs 205 µs

- after improving the CPU and CUDA forward kernels:
  - in the `(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD)` forward, CPU takes ~~`171 µs`~~, of which `130 µs` is spent in `coalesce()`; for CUDA the total time is ~~`331 µs`~~, of which `141 µs` is spent in `coalesce()`. We need to reduce the time spent outside `coalesce()`.
  - after a few simple tweaks, the forward is now at most `10x` slower on CPU and `7x` on CUDA, and the time taken by `sum dense dims only [2, 3]` is `~2x` that of `sum all dims`.
    Speed of `sum all sparse dims [0, 1]` is on par with `sum all dims`.

(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA (sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 7 µs vs 69.5 µs | 31.5 µs vs 61.6 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 11.3 µs vs 4.72 ms | 35.2 µs vs 285 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 197 µs vs 124 µs | 857 µs vs 134 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 124 µs vs 833 µs | 796 µs vs 106 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 20.5 µs vs 213 µs | 39.4 µs vs 1.24 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 131 µs vs 830 µs | 881 µs vs 132 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 95.8 µs vs 409 µs | 246 µs vs 87.2 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 624 µs vs 820 µs | 953 µs vs 124 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 45.3 µs vs 72.9 µs | 33.9 µs vs 57.2 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 81.4 µs vs 4.49 ms | 39.7 µs vs 280 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 984 µs vs 111 µs | 6.41 ms vs 121 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.45 ms vs 828 µs | 6.77 ms vs 113 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 74.9 µs vs 209 µs | 37.7 µs vs 1.23 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.48 ms vs 845 µs | 6.96 ms vs 132 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.14 ms vs 411 µs | 252 µs vs 87.8 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 4.53 ms vs 851 µs | 7.12 ms vs 128 µs

- the time taken by the CUDA backward of sparse is very long, with large variance (for nnz = 10000 it normally takes 6-7 ms). To improve the backward of sparse ops, we will need to debug places other than the CUDA kernels. Here is a benchmark of `torch.copy_()`:
```
>>> d = [1000, 1000, 2, 2]
>>> nnz = 10000
>>> I = torch.cat([torch.randint(0, d[0], size=(nnz,)), torch.randint(0, d[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, d[2], d[3])
>>> size = torch.Size(d)
>>> S = torch.sparse_coo_tensor(I, V, size).coalesce().cuda()
>>> S2 = torch.sparse_coo_tensor(I, V, size).coalesce().cuda().requires_grad_()
>>> data = S2.clone()
>>> S.copy_(S2)
>>> y = S * 2
>>> torch.cuda.synchronize()
>>> %timeit y.backward(data, retain_graph=True); torch.cuda.synchronize()
7.07 ms ± 3.06 ms per loop (mean ± std. dev.
of 7 runs, 1000 loops each) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/12430 Differential Revision: D12878313 Pulled By: weiyangfb fbshipit-source-id: e16dc7681ba41fdabf4838cf05e491ca9108c6fe --- aten/src/ATen/SparseTensorUtils.h | 29 +++ aten/src/ATen/core/aten_interned_strings.h | 1 + aten/src/ATen/native/native_functions.yaml | 14 ++ aten/src/ATen/native/sparse/SparseTensorMath.cpp | 246 ++++++++++++++++++++- .../native/sparse/cuda/SparseCUDATensorMath.cu | 172 ++++++++++++++ docs/source/sparse.rst | 8 + test/test_sparse.py | 58 +++++ tools/autograd/derivatives.yaml | 3 + torch/sparse/__init__.py | 70 +++++- 9 files changed, 599 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/SparseTensorUtils.h b/aten/src/ATen/SparseTensorUtils.h index 2eed30b..2f12c80 100644 --- a/aten/src/ATen/SparseTensorUtils.h +++ b/aten/src/ATen/SparseTensorUtils.h @@ -109,4 +109,33 @@ inline LongTensor flatten_indices(const Tensor& indices, IntList full_size, bool } } +// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ], +// except this one allows partial flatten: only flatten on specified dims. Note that +// the flatten indices might be uncoalesced if dims_to_flatten.size() < sparse_dim. +// Also if input indices is already coalesced, the flattened indices will also be sorted. +// +// args: +// indices: sparse tensor indices +// sizes: sparse tensor sizes +// dims_to_flatten: a list of dim index to flatten +// +// Ex1: +// indices = [[2, 4, 0], +// [3, 1, 3]] +// sizes = [2, 12] +// dims_to_flatten = [0, 1] +// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3] +// +// Ex2: +// dims_to_flatten = [1] +// new_indices = [ 3, 1, 3 ] # uncoalesced +inline LongTensor flatten_indices_by_dims(const LongTensor& indices, const IntList& sizes, const IntList& dims_to_flatten){ + LongTensor new_indices = at::zeros({indices.size(1)}, indices.options()); + for (auto d : dims_to_flatten) { + new_indices.mul_(sizes[d]); + new_indices.add_(indices.select(0, d)); + } + return new_indices; +} + }} // namespace at::sparse diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 5f6896d..486d5cf 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -128,6 +128,7 @@ _(aten, _sparse_div_zerodim) \ _(aten, _sparse_mul) \ _(aten, _sparse_mul_scalar) \ _(aten, _sparse_mul_zerodim) \ +_(aten, _sparse_sum) \ _(aten, _sqrt) \ _(aten, _standard_gamma) \ _(aten, _standard_gamma_grad) \ diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index fe51e37..7e71f47 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1855,6 +1855,20 @@ SparseCPU: norm_sparse SparseCUDA: norm_sparse +# TODO: reduce signatures down to one when optional args is available +- func: _sparse_sum(Tensor self) -> Tensor + +- func: _sparse_sum(Tensor self, *, ScalarType dtype) -> Tensor + +- func: _sparse_sum(Tensor self, IntList[1] dim) -> Tensor + +- func: _sparse_sum(Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor + +- func: _sparse_sum_backward(Tensor grad, Tensor self, IntList dim) -> Tensor + dispatch: + SparseCPU: _sparse_sum_backward_cpu + SparseCUDA: _sparse_sum_backward_cuda + - func: norm(Tensor self, Scalar p=2) -> Tensor variants: function, method diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp 
b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index ee98ce2..414e4d4 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -4,13 +4,13 @@ #include #include #include +#include "ATen/WrapDimUtilsMulti.h" #include namespace at { namespace native { using namespace at::sparse; - // -------------------------------------------------------------------- // Utility functions // -------------------------------------------------------------------- @@ -814,4 +814,248 @@ Tensor sspaddmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2, return result; } +// -------------------------------------------------------------------- +// sparse.sum() +// +// This implementation calls coalesce() to do the sum reduction on +// sparse dims. Ideally in the future there should be unified reduction function +// for ops like sum, max, and min. +// -------------------------------------------------------------------- +Tensor _sparse_sum(const SparseTensor& input) { + return input.coalesce().values().sum(); +} + +Tensor _sparse_sum(const SparseTensor& input, ScalarType dtype) { + // don't have to do a conversion to the correct dtype first + // just need to setup the accumulator correctly + return input.coalesce().values().sum(dtype); +} + +Tensor _sparse_sum(const SparseTensor& input, IntList dims_to_sum, ScalarType dtype) { + return at::_sparse_sum(input.to(dtype), dims_to_sum); +} + +Tensor _sparse_sum(const SparseTensor& input, IntList dims_to_sum) { + AT_CHECK(input._nnz() > 0, "_sparse_sum: sparse tensor input._nnz() == 0, please call torch.sparse.sum(input) instead.") + + const int64_t input_dim = input.dim(); + auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim); + auto dims_to_sum_v = dims_to_sum.vec(); + maybe_wrap_dims(dims_to_sum_v, input_dim); + + LongTensor indices = input._indices(); + Tensor values = input._values(); + IntList sizes = input.sizes(); + const int64_t sparse_dim = input.sparse_dim(); + const int64_t dense_dim = input.dense_dim(); + + auto dims_to_keep_v = std::vector(); + auto dense_dims_to_sum_v = std::vector(); + for (int64_t d = 0; d < input_dim; d++) { + if (dims_to_sum_b[d]) { + if (d >= sparse_dim) dense_dims_to_sum_v.emplace_back(d + 1 - sparse_dim); + } + else { + dims_to_keep_v.emplace_back(d); + } + } + const int64_t sparse_dims_to_sum_size = dims_to_sum_v.size() - dense_dims_to_sum_v.size(); + const bool sum_all_sparse_dim = (sparse_dim == sparse_dims_to_sum_size); + const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0); + + // new values + Tensor new_values; + if (sum_dense_dim) { + new_values = values.sum(dense_dims_to_sum_v); + } + else { + new_values = values.clone(); + } + + if (sum_all_sparse_dim) { + // return a dense tensor if sum over all sparse dims + new_values = new_values.sum(0); + return new_values; + } + else { // !sum_all_sparse_dim + // new indices + LongTensor new_indices; + if (sparse_dims_to_sum_size == 0) { + new_indices = indices.clone(); + } + else { + new_indices = at::empty({sparse_dim - sparse_dims_to_sum_size, input._nnz()}, indices.options()); + for (int64_t i = 0; i < dims_to_keep_v.size(); i++) { + int64_t d = dims_to_keep_v[i]; + if (d < sparse_dim) new_indices[i].copy_(indices[d]); + else break; + } + } + + // new size + int64_t new_sparse_dim = new_indices.size(0); + int64_t new_dense_dim = new_values.dim() - 1; // exclude nnz dim + std::vector new_sizes; + for (auto d : dims_to_keep_v) new_sizes.emplace_back(sizes[d]); + if (sum_all_sparse_dim) 
new_sizes.emplace(new_sizes.begin(), 1); + + // use coalesce() to do sum reduction + SparseTensor new_sparse = at::_sparse_coo_tensor_with_dims_and_tensors(new_sparse_dim, new_dense_dim, new_sizes, new_indices, new_values, input.options()); + new_sparse = new_sparse.coalesce(); + return new_sparse; + } + +} + +// -------------------------------------------------------------------- +// NOTE [ sparse.sum() backward ] +// +// When sum over sparse_dim, backward scatters gradients from grad tensor to input tensor. +// Grad and input need to align indices over sparse_dim that are not summed (given +// input.spares_dim >= grad.sparse_dim). Implementation here compares each pair of +// indices between grad and input. When a matching indices pair (input_i, grad_i) is found, +// copy grad.values[grad_i] -> input_grad.values[input_i]. E.g., +// +// input.sparse_dim = [5, 5] +// input.indices = [[0, 0, 1, 2, 2, 3, 4, 4], +// [1, 4, 4, 0, 1, 3, 2, 4]] +// input.values = [0, 1, 2, 3, 4, 5, 6, 7] +// ... +// sparse.sum(input, [0]) +// backward(...) +// ... +// grad.indices = [[0, 1, 2, 3]] +// grad.values = [1, 2, 0, 4] +// +// # after indices matching +// input grad +// [[0, 1], -> [1] +// [0, 4], -> [ ] +// [1, 4], -> [ ] +// [2, 0], -> [0] +// [2, 1], -> [1] +// [3, 3], -> [3] +// [4, 2], -> [2] +// [4, 4]]) -> [ ] +// +// input_grad.indices = [[0, 0, 1, 2, 2, 3, 4, 4], +// [1, 4, 4, 0, 1, 3, 2, 4]] +// input_grad.values = [2, 0, 0, 1, 2, 4, 0, 0] +// +// Note that we allow input to be uncoalesced in the forward, +// we have to coalesce input at the backward, because grad-of-input +// take the same indices as input, if input is not coalesced, then +// coalescing grad-of-input may add up grad values for a duplicate indices, +// and hence generates a wrong grad-of-input. 
+// +// Other edge cases: +// - assign zero values to input gradients if cannot find matched indices at grad +// - grad.values might have zeros +// -------------------------------------------------------------------- +Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_, IntList dims_to_sum) { + AT_CHECK(!grad_.is_cuda(), "_sparse_sum_backward_cpu: expected 'grad_' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!input_.is_cuda(), "_sparse_sum_backward_cpu: expected 'input_' to be CPU tensor, but got CUDA tensor"); + + auto input = input_.coalesce(); + const int64_t input_dim = input.dim(); + auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim); + auto dims_to_sum_v = dims_to_sum.vec(); + maybe_wrap_dims(dims_to_sum_v, input_dim); + + LongTensor input_indices = input._indices(); + Tensor input_values = input._values(); + IntList input_sizes = input.sizes(); + const int64_t input_sparse_dim = input.sparse_dim(); + const int64_t input_dense_dim = input.dense_dim(); + const int64_t input_nnz = input._nnz(); + + int64_t sparse_dims_to_sum_size = 0; + auto sparse_dims_to_keep_v = std::vector(); + auto dense_dims_to_sum_v = std::vector(); + for (int64_t d = 0; d < input_dim; d++) { + if (dims_to_sum_b[d]) { + if (d < input_sparse_dim) sparse_dims_to_sum_size ++; + else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim); + } + else { + if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d); + } + } + + const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size); + const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0); + const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0); + + if (sum_all_sparse_dim) { + AT_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be dense since all sparse dims are summed"); + auto grad_input_values = grad_; + auto expand_size = input_values.sizes().vec(); + if (sum_dense_dim) { + auto dense_expand_size = std::vector(expand_size); + dense_expand_size.erase(dense_expand_size.begin()); + AT_ASSERT(dense_expand_size.size() == (input_values.dim() - 1)); + for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1); // -1 since grad has no nnz dim + grad_input_values = grad_input_values.expand(dense_expand_size); + } + grad_input_values = grad_input_values.expand(expand_size).clone(); + return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, input.options().dtype(grad_.dtype())); // convert to grad dtype + } + else { + AT_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense"); + auto grad = grad_.coalesce(); + LongTensor grad_indices = grad._indices(); + Tensor grad_values = grad._values(); + const int64_t grad_sparse_dim = grad.sparse_dim(); + const int64_t grad_nnz = grad._nnz(); + + Tensor grad_values_expand = grad_values; + if (sum_dense_dim) { + auto expand_size = input_values.sizes().vec(); + if (sum_sparse_dim) expand_size[0] = grad_values.size(0); + for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d); + grad_values_expand = grad_values_expand.expand(expand_size).clone(); + } + + Tensor grad_input_values; + if (sum_sparse_dim) { + // see NOTE [ sparse.sum() backward ] + grad_input_values = at::zeros_like(input_values, grad_values.options()); + + // get flatten indices for grad and input + auto grad_sparse_dim_to_keep_v = std::vector(grad_sparse_dim); + 
std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0); + + auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices on all sparse_dim of grad, output indices is coalesced and sorted + auto grad_indices_1D_accessor = grad_indices_1D.accessor(); + auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v); + auto input_indices_1D_accessor = input_indices_1D.accessor(); + + // binary search to find matching indices + int64_t i; + #pragma omp parallel for private(i) + for (i = 0; i < input_nnz; i++) { + int64_t input_idx = input_indices_1D_accessor[i]; + int64_t l = 0, r = grad_nnz - 1; + while (l <= r) { + int64_t m = l + (r - l) / 2; + if (grad_indices_1D_accessor[m] == input_idx) { + grad_input_values[i].copy_(grad_values_expand[m]); + break; + } + if (grad_indices_1D_accessor[m] < input_idx) { + l = m + 1; + } + else { + r = m - 1; + } + } + } + } + else { + grad_input_values = grad_values_expand; + } + return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, grad.options()); + } +} + }} // namespace at::native diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index eef8400..8113728 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -6,19 +6,27 @@ #include #include #include +#include "ATen/WrapDimUtilsMulti.h" #include #include + #include #include +#include +#include #include +#include + #define I_INFO(tensor) cuda::detail::getTensorInfo(tensor) #define V_INFO(tensor) cuda::detail::getTensorInfo(tensor) namespace at { namespace native { using namespace at::sparse; +using at::cuda::detail::TensorInfo; +using at::cuda::detail::getTensorInfo; // -------------------------------------------------------------------- // Utility functions @@ -466,4 +474,168 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons return r_._coalesced_(true); } +// -------------------------------------------------------------------- +// sparse.sum() backward +// +// see NOTE [ sparse.sum() backward ] +// -------------------------------------------------------------------- +template +__global__ void _sparse_sum_backward_cuda_kernel( + int64_t total_threads, + const TensorInfo grad_indices_ti, + const TensorInfo input_indices_ti, + const TensorInfo input_indices_pos_ti, + const TensorInfo grad_values_expand_ti, + TensorInfo grad_input_values_ti +) { + const int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= total_threads) return; + const int64_t j = input_indices_pos_ti.data[i]; + + bool has_match = false; + if (grad_indices_ti.data[j] == input_indices_ti.data[i]) { + has_match = true; + } + + int64_t grad_input_values_stride0 = grad_input_values_ti.strides[0]; + int64_t out_start = i * grad_input_values_stride0; + int64_t out_end = (i + 1) * grad_input_values_stride0; + int64_t in_start = j * grad_values_expand_ti.strides[0]; + + if (has_match) { + for (int64_t out_i = out_start, in_i = in_start; out_i < out_end; out_i++, in_i++) { + grad_input_values_ti.data[out_i] = grad_values_expand_ti.data[in_i]; + } + } + else { + for (int64_t out_i = out_start; out_i < out_end; out_i++) { + grad_input_values_ti.data[out_i] = scalar_t(0); + } + } +} + +Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const 
SparseTensor& input_, IntList dims_to_sum) { + AT_CHECK(grad_.is_cuda(), "_sparse_sum_backward_cuda: expected 'grad_' to be CUDA tensor, but got CPU tensor"); + AT_CHECK(input_.is_cuda(), "_sparse_sum_backward_cuda: expected 'input_' to be CUDA tensor, but got CPU tensor"); + + auto input = input_.coalesce(); + const int64_t input_dim = input.dim(); + auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim); + auto dims_to_sum_v = dims_to_sum.vec(); + maybe_wrap_dims(dims_to_sum_v, input_dim); + + LongTensor input_indices = input._indices(); + Tensor input_values = input._values(); + IntList input_sizes = input.sizes(); + const int64_t input_sparse_dim = input.sparse_dim(); + const int64_t input_dense_dim = input.dense_dim(); + const int64_t input_nnz = input._nnz(); + + int64_t sparse_dims_to_sum_size = 0; + auto sparse_dims_to_keep_v = std::vector(); + auto dense_dims_to_sum_v = std::vector(); + for (int64_t d = 0; d < input_dim; d++) { + if (dims_to_sum_b[d]) { + if (d < input_sparse_dim) sparse_dims_to_sum_size ++; + else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim); + } + else { + if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d); + } + } + + const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size); + const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0); + const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0); + + if (sum_all_sparse_dim) { + AT_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad Tensor to be dense since all sparse dims are summed"); + auto grad_input_values = grad_; + auto expand_size = input_values.sizes().vec(); + if (sum_dense_dim) { + auto dense_expand_size = std::vector(expand_size); + dense_expand_size.erase(dense_expand_size.begin()); // remove nnz dim + for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1); // -1 since grad has no nnz dim + grad_input_values = grad_input_values.expand(dense_expand_size); + } + grad_input_values = grad_input_values.expand(expand_size).clone(); + return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, input.options().dtype(grad_.dtype())); // convert to grad dtype + } + else { + AT_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad_ Tensor to be sparse, but got dense"); + auto grad = grad_.coalesce(); + LongTensor grad_indices = grad._indices(); + Tensor grad_values = grad._values(); + const int64_t grad_sparse_dim = grad.sparse_dim(); + const int64_t grad_nnz = grad._nnz(); + + Tensor grad_values_expand = grad_values; + if (sum_dense_dim) { + auto expand_size = input_values.sizes().vec(); + if (sum_sparse_dim) expand_size[0] = grad_values.size(0); // update nnz + for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d); + grad_values_expand = grad_values_expand.expand(expand_size).clone(); + } + + Tensor grad_input_values; + if (!sum_sparse_dim) { + grad_input_values = grad_values_expand; + } + else { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + typedef thrust::device_ptr thrust_ptr; + + grad_input_values = at::empty_like(input_values, grad_values.options()); + AT_ASSERT(grad_input_values.is_cuda()); + + // get 1D indices + auto grad_sparse_dim_to_keep_v = 
std::vector(grad_sparse_dim); + std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0); + + auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices on all sparse_dim of grad, output indices is coalesced and sorted + auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v); + thrust_ptr grad_indices_iter(grad_indices_1D.data()); + thrust_ptr input_indices_iter(input_indices_1D.data()); + + // store lower_bound of input indices at grad indices + LongTensor input_indices_pos = at::empty_like(input_indices_1D); + thrust_ptr input_indices_pos_iter(input_indices_pos.data()); + thrust::lower_bound(policy, + grad_indices_iter, grad_indices_iter + grad_nnz, + input_indices_iter, input_indices_iter + input_nnz, + input_indices_pos_iter); + + // config to run cuda kernel + int64_t total_threads = input_nnz; + const dim3 block = dim3(std::min(static_cast(cuda::getApplyBlock().x), total_threads)); + dim3 grid; + AT_CHECK(cuda::getApplyGrid(total_threads, grid, curDevice), "_sparse_sum_backward_cuda: input too large or too many dimensions"); + + auto grad_indices_ti = getTensorInfo(grad_indices_1D); + auto input_indices_ti = getTensorInfo(input_indices_1D); + auto input_indices_pos_ti = getTensorInfo(input_indices_pos); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_values.type(), "_sparse_sum_backward_cuda", [&] { + auto grad_values_expand_ti = getTensorInfo(grad_values_expand); + auto grad_input_values_ti = getTensorInfo(grad_input_values); + + _sparse_sum_backward_cuda_kernel<<>>( + total_threads, + grad_indices_ti, + input_indices_ti, + input_indices_pos_ti, + grad_values_expand_ti, + grad_input_values_ti + ); + }); + } + + return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, grad.options()); + } +} + }} // namespace at::native diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index 4b8d5bf..1e6afde 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -60,6 +60,13 @@ An empty sparse tensor can be constructed by specifying its size: and values: [torch.FloatTensor with no dimension] +SparseTensor has the following invariants: + 1. sparse_dim + dense_dim = len(SparseTensor.shape) + 2. SparseTensor._indices().shape = (sparse_dim, nnz) + 3. SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:]) +Since SparseTensor._indices() is always a 2D tensor, the smallest sparse_dim = 1. +Therefore, representation of a SparseTensor of sparse_dim = 0 is simply a dense tensor. + .. note:: Our sparse tensor format permits *uncoalesced* sparse tensors, where @@ -134,3 +141,4 @@ Functions ---------------------------------- .. autofunction:: torch.sparse.addmm +.. 
autofunction:: torch.sparse.sum diff --git a/test/test_sparse.py b/test/test_sparse.py index 39c7e88..60fc784 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -935,6 +935,64 @@ class TestSparse(TestCase): test_shape(4, 10, [100, 100, 100, 5, 5, 5, 0]) test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0]) + @skipIfRocm + def test_sparse_sum(self): + + def run_tests(S, td=None): + D = S.coalesce().to_dense().detach().requires_grad_(True) + mask = (D == 0) + if td is None: + S_sum = torch.sparse.sum(S) + D_sum = D.sum() + self.assertEqual(S_sum, D_sum) + S_sum.backward() + D_sum.backward() + D_grad = D.grad.masked_fill_(mask, 0) + self.assertEqual(S.grad.to_dense(), D_grad) + else: + S_sum = torch.sparse.sum(S, td) + D_sum = D.sum(td) + self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum) + S_sum.backward(S_sum.detach()) + S_grad = S.grad + data = S_sum.to_dense().detach() if S_sum.is_sparse else S_sum.detach() + D_sum.backward(data) + D_grad = D.grad.masked_fill_(mask, 0) + S_grad_dense = S_grad.coalesce().to_dense() if S_grad.is_sparse else S_grad + self.assertEqual(S_grad_dense, D_grad) + + nnz = 10 + sparse_dims = 2 + with_size = [5, 5, 1, 4] # use a dense dim = 1 to test for squeeze + test_dims = [] + for i in range(1, 5): + test_dims += itertools.combinations(range(len(with_size)), i) + + # not support SparseTensor.sum() + S = self._gen_sparse(sparse_dims, nnz, with_size)[0] + self.assertRaises(RuntimeError, lambda: S.sum()) + + # raise for incorrect input dims + self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, 5)) + self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, [0, 0])) + + # sum an empty tensor + empty_S = torch.sparse_coo_tensor(size=with_size) + self.assertRaises(RuntimeError, lambda: torch.sparse.sum(empty_S, [0])) + self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0,)) + empty_S.requires_grad_(True) + empty_S_sum = torch.sparse.sum(empty_S) + empty_S_sum.backward() + self.assertEqual(empty_S.grad.to_dense(), empty_S.clone().detach().to_dense()) + + # test values().sum() + S = self._gen_sparse(sparse_dims, nnz, with_size)[0] + run_tests(S.requires_grad_(True)) + + for test_dim in test_dims: + S = self._gen_sparse(sparse_dims, nnz, with_size)[0] + run_tests(S.requires_grad_(True), test_dim) + def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v=None): shape = shape_i + (shape_v or []) x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 0736a46..30432a2 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -858,6 +858,9 @@ - name: _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, IntList size, Tensor indices, Tensor values, TensorOptions options) values: sparse_constructor_values_backward(grad, indices, values.sizes()) +- name: _sparse_sum(Tensor self, IntList dim) + self: at::_sparse_sum_backward(grad, self, dim) + - name: _standard_gamma(Tensor self, Generator generator) self: grad * _standard_gamma_grad(self, result) diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index e2ab34f..07553e9 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -3,12 +3,13 @@ import torch __all__ = [ 'addmm', + 'sum', ] def addmm(mat, mat1, mat2, beta=1, alpha=1): r""" - This function does exact same thing as :meth:`~Torch.addmm` in the forward, + This function does exact same thing as :func:`torch.addmm` in the forward, except that it supports backward 
for coalesced sparse matrix `mat1`. Args: @@ -19,3 +20,70 @@ def addmm(mat, mat1, mat2, beta=1, alpha=1): alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) """ return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha) + + +def sum(input, dim=None, dtype=None): + r""" + Returns the sum of each row of SparseTensor :attr:`input` in the given + dimensions :attr:`dim`. If :attr::`dim` is a list of dimensions, + reduce over all of them. When sum over all ``sparse_dim``, this method + returns a Tensor instead of SparseTensor. + + All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output + tensor having :attr::`dim` fewer dimensions than :attr:`input`. + + During backward, only gradients at ``nnz`` locations of :attr:`input` + will propagate back. Note that the gradients of :attr:`input` is coalesced. + + Args: + input (Tensor): the input SparseTensor + dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce + over all dims. + dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. + Default: dtype of :attr:`input`. + + Example:: + + >>> nnz = 3 + >>> dims = [5, 5, 2, 3] + >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)), + torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz) + >>> V = torch.randn(nnz, dims[2], dims[3]) + >>> size = torch.Size(dims) + >>> S = torch.sparse_coo_tensor(I, V, size) + >>> S + tensor(indices=tensor([[2, 0, 3], + [2, 4, 1]]), + values=tensor([[[-0.6438, -1.6467, 1.4004], + [ 0.3411, 0.0918, -0.2312]], + + [[ 0.5348, 0.0634, -2.0494], + [-0.7125, -1.0646, 2.1844]], + + [[ 0.1276, 0.1874, -0.6334], + [-1.9682, -0.5340, 0.7483]]]), + size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo) + + # when sum over only part of sparse_dims, return a SparseTensor + >>> torch.sparse.sum(S, [1, 3]) + tensor(indices=tensor([[0, 2, 3]]), + values=tensor([[-1.4512, 0.4073], + [-0.8901, 0.2017], + [-0.3183, -1.7539]]), + size=(5, 2), nnz=3, layout=torch.sparse_coo) + + # when sum over all sparse dim, return a dense Tensor + # with summed dims squeezed + >>> torch.sparse.sum(S, [0, 1, 3]) + tensor([-2.6596, -1.1450]) + """ + if dtype is None: + if dim: + return torch._sparse_sum(input, dim) + else: + return torch._sparse_sum(input) + else: + if dim: + return torch._sparse_sum(input, dim, dtype=dtype) + else: + return torch._sparse_sum(input, dtype=dtype) -- 2.7.4
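
The partial index flattening added in `SparseTensorUtils.h` (`flatten_indices_by_dims`) is the usual row-major linearization restricted to the chosen dims. Below is a small standalone Python sketch of the same arithmetic, reusing the numbers from the header comment; it is an illustration, not the ATen code itself.

```python
import torch

def flatten_indices_by_dims(indices, sizes, dims_to_flatten):
    # new_indices = new_indices * sizes[d] + indices[d], accumulated over the
    # selected dims only; the result may be uncoalesced when only a subset of
    # the sparse dims is flattened.
    new_indices = torch.zeros(indices.size(1), dtype=torch.long)
    for d in dims_to_flatten:
        new_indices = new_indices * sizes[d] + indices[d]
    return new_indices

indices = torch.tensor([[2, 4, 0],
                        [3, 1, 3]])
sizes = [2, 12]  # sizes as given in the header comment's example

print(flatten_indices_by_dims(indices, sizes, [0, 1]))  # tensor([27, 49,  3])
print(flatten_indices_by_dims(indices, sizes, [1]))     # tensor([3, 1, 3]), uncoalesced
```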
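For the forward, `_sparse_sum` in `SparseTensorMath.cpp` reduces the summed dense dims on `values()`, drops the summed sparse dims from `indices()`, and relies on `coalesce()` to add up values that now share an index. Here is a rough Python equivalent of that flow, with simplified dim bookkeeping; `sparse_sum_sketch` is a hypothetical helper used only to cross-check against `torch.sparse.sum` on a build that includes this patch.

```python
import torch

def sparse_sum_sketch(S, dims_to_sum):
    # Illustrative only: the ATen kernel also handles uncoalesced input,
    # a dtype argument, and wrapping of negative dims.
    S = S.coalesce()
    sparse_dim = S.sparse_dim()
    indices, values, sizes = S._indices(), S._values(), list(S.shape)

    sparse_sum = [d for d in dims_to_sum if d < sparse_dim]
    dense_sum = [d - sparse_dim + 1 for d in dims_to_sum if d >= sparse_dim]  # +1: values carry an nnz dim

    if dense_sum:
        values = values.sum(dense_sum)          # reduce dense dims on the values tensor
    if len(sparse_sum) == sparse_dim:
        return values.sum(0)                    # all sparse dims summed -> dense result

    keep = [d for d in range(sparse_dim) if d not in sparse_sum]
    new_indices = indices[keep]                 # duplicates appear where dropped dims differed
    new_sizes = ([sizes[d] for d in keep] +
                 [sizes[d] for d in range(sparse_dim, len(sizes)) if d not in dims_to_sum])
    # coalesce() performs the actual reduction by summing values with equal indices
    return torch.sparse_coo_tensor(new_indices, values, new_sizes).coalesce()

# quick cross-check (assumes torch.sparse.sum is available)
S = torch.sparse_coo_tensor(torch.tensor([[0, 0, 2], [1, 1, 3]]),
                            torch.randn(3, 4), (3, 5, 4))
print(torch.allclose(sparse_sum_sketch(S, [0, 2]).to_dense(),
                     torch.sparse.sum(S, [0, 2]).to_dense()))  # True
```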
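For the backward, NOTE [ sparse.sum() backward ] matches each input index against the coalesced (hence sorted) grad indices: the CUDA path uses `thrust::lower_bound`, the CPU path a hand-written binary search. The same matching can be sketched in Python with `torch.searchsorted` (available in recent PyTorch and used here purely to mimic the lower_bound step); it reproduces the worked example from the note.

```python
import torch

# Worked example from NOTE [ sparse.sum() backward ]: input has sparse sizes
# [5, 5] and sparse.sum(input, [0]) was taken, so grad lives on the remaining
# sparse dim only.
input_indices = torch.tensor([[0, 0, 1, 2, 2, 3, 4, 4],
                              [1, 4, 4, 0, 1, 3, 2, 4]])
input_values = torch.tensor([0., 1., 2., 3., 4., 5., 6., 7.])
grad_indices = torch.tensor([[0, 1, 2, 3]])
grad_values = torch.tensor([1., 2., 0., 4.])

# Keys over the kept sparse dim: dim 1 of the (coalesced) input and the single
# sparse dim of grad. grad is coalesced, so grad_keys is sorted.
input_keys = input_indices[1]
grad_keys = grad_indices[0]

pos = torch.searchsorted(grad_keys, input_keys)   # lower_bound of each input key in grad keys
pos = pos.clamp(max=grad_keys.numel() - 1)        # guard out-of-range positions
match = grad_keys[pos] == input_keys              # exact index match?

grad_input_values = torch.zeros_like(input_values)  # unmatched rows get zero grad
grad_input_values[match] = grad_values[pos[match]]
print(grad_input_values)  # tensor([2., 0., 0., 1., 2., 4., 0., 0.]), as in the note
```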
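Finally, an end-to-end usage sketch of the new API, assuming a build that contains this patch; it mirrors the docstring and `test_sparse_sum`: summing a subset of dims returns a SparseTensor, summing all sparse dims returns a dense Tensor, and gradients flow back only to the `nnz` locations of the input.

```python
import torch

nnz, dims = 3, [5, 5, 2, 3]
I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
               torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
V = torch.randn(nnz, dims[2], dims[3])
S = torch.sparse_coo_tensor(I, V, torch.Size(dims)).requires_grad_(True)

print(torch.sparse.sum(S, [1, 3]))     # SparseTensor of size (5, 2): sparse dim 1 and dense dim 3 summed
print(torch.sparse.sum(S, [0, 1, 3]))  # dense Tensor of size (2,): all sparse dims summed
total = torch.sparse.sum(S)            # scalar sum over all values

total.backward()
print(S.grad)                          # sparse gradient; only the nnz locations of S receive grad
```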