Summary:
- fixes #12241
- adds `_sparse_sum()` to ATen and exposes it as `torch.sparse.sum()`; `SparseTensor.sum()` is not supported for now (a short usage sketch follows the benchmarks below)
- this PR depends on #11253 and will need to be updated once it lands
- [x] implement forward
- [x] implement backward
- performance [benchmark script](https://gist.github.com/weiyangfb/f4c55c88b6092ef8f7e348f6b9ad8946#file-sparse_sum_benchmark-py):
- summing all dims is fastest for sparse tensors
- when the input is sparse enough (nnz = 0.1%), summing a sparse tensor is faster than summing a dense one on CPU, but not necessarily on CUDA
- for sparse, CUDA backward is comparable (<2x) between `sum several dims` and `sum all dims`
- CPU backward uses binary search and is still slow for sparse: `sum [0, 2, 3] dims` takes `5x` the time of `sum all dims`
- optimize CUDA backward for now
- using thrust for sort and binary search did not improve the runtime
- both CPU and CUDA forward are slow for sparse (`sum several dims` vs `sum all dims`): at most `20x` slower on CPU and `10x` on CUDA
- improve CPU and CUDA forward kernels (results in the tables below)
(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA (sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 8.77 µs vs 72.9 µs | 42.5 µs vs 108 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 112 µs vs 4.47 ms | 484 µs vs 407 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 141 µs vs 148 µs | 647 µs vs 231 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 235 µs vs 1.23 ms | 781 µs vs 213 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 48.5 µs vs 360 µs | 160 µs vs 2.03 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 258 µs vs 1.22 ms | 798 µs vs 224 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 204 µs vs 882 µs | 443 µs vs 133 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 709 µs vs 1.15 ms | 893 µs vs 202 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 39.8 µs vs 81 µs | 42.4 µs vs 113 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 747 µs vs 4.7 ms | 2.4 ms vs 414 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 1.04 ms vs 126 µs | 5.03 ms vs 231 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.12 ms vs 1.24 ms | 5.99 ms vs 213 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 133 µs vs 366 µs | 463 µs vs 2.03 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.56 ms vs 1.22 ms | 6.11 ms vs 229 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.53 ms vs 799 µs | 824 µs vs 134 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 5.15 ms vs 1.09 ms | 7.02 ms vs 205 µs
- after improving the CPU and CUDA forward kernels
- in the `(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD)` forward, CPU takes ~~`171 µs`~~, of which `130 µs` is spent in `coalesce()`; on CUDA, the total time is ~~`331 µs`~~, of which `141 µs` is spent in `coalesce()`. We need to reduce the time spent outside of `coalesce()`.
- after a few simple tweaks, the forward is now at most `10x` slower on CPU and `7x` on CUDA. The time taken by `sum dense dims only [2, 3]` is `~2x` that of `sum all dims`, and `sum all sparse dims [0, 1]` is on par with `sum all dims`
(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA (sparse vs dense)
-- | -- | --
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 7 µs vs 69.5 µs | 31.5 µs vs 61.6 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 11.3 µs vs 4.72 ms | 35.2 µs vs 285 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 197 µs vs 124 µs | 857 µs vs 134 µs
(1000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 124 µs vs 833 µs | 796 µs vs 106 µs
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 20.5 µs vs 213 µs | 39.4 µs vs 1.24 ms
(1000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 131 µs vs 830 µs | 881 µs vs 132 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 95.8 µs vs 409 µs | 246 µs vs 87.2 µs
(1000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 624 µs vs 820 µs | 953 µs vs 124 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll) | 45.3 µs vs 72.9 µs | 33.9 µs vs 57.2 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD) | 81.4 µs vs 4.49 ms | 39.7 µs vs 280 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 984 µs vs 111 µs | 6.41 ms vs 121 µs
(10000, [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.45 ms vs 828 µs | 6.77 ms vs 113 µs
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD) | 74.9 µs vs 209 µs | 37.7 µs vs 1.23 ms
(10000, [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.48 ms vs 845 µs | 6.96 ms vs 132 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.14 ms vs 411 µs | 252 µs vs 87.8 µs
(10000, [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 4.53 ms vs 851 µs | 7.12 ms vs 128 µs
- the CUDA backward for sparse takes very long and has large variance (with nnz=10000, it normally takes 6-7 ms). To improve the backward of sparse ops, we will need to debug places other than the CUDA kernels. Here is a benchmark of `torch.copy_()`:
```
>>> d = [1000, 1000, 2, 2]
>>> nnz = 10000
>>> I = torch.cat([torch.randint(0, d[0], size=(nnz,)),
torch.randint(0, d[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, d[2], d[3])
>>> size = torch.Size(d)
>>> S = torch.sparse_coo_tensor(I, V, size).coalesce().cuda()
>>> S2 = torch.sparse_coo_tensor(I, V, size).coalesce().cuda().requires_grad_()
>>> data = S2.clone()
>>> S.copy_(S2)
>>> y = S * 2
>>> torch.cuda.synchronize()
>>> %timeit y.backward(data, retain_graph=True); torch.cuda.synchronize()
7.07 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
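
For reference, a short usage sketch of the API added here (tensor contents are illustrative; the calls map onto the four `_sparse_sum` overloads listed further down):
```python
import torch

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.randn(3, 2)                         # one dense dim
S = torch.sparse_coo_tensor(i, v, (3, 4, 2))  # sparse_dim=2, dense_dim=1

torch.sparse.sum(S)                           # sum over all dims -> 0-dim dense Tensor
torch.sparse.sum(S, dtype=torch.float64)      # accumulate in a different dtype
torch.sparse.sum(S, [0])                      # sum part of the sparse dims -> SparseTensor
torch.sparse.sum(S, [0, 1])                   # sum all sparse dims -> dense Tensor of shape (2,)
```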
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12430
Differential Revision: D12878313
Pulled By: weiyangfb
fbshipit-source-id: e16dc7681ba41fdabf4838cf05e491ca9108c6fe
}
}
+// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
+// except that this one allows a partial flatten: only the specified dims are flattened. Note that
+// the flattened indices might be uncoalesced if dims_to_flatten.size() < sparse_dim.
+// Also, if the input indices are already coalesced, the flattened indices will be sorted.
+//
+// args:
+// indices: sparse tensor indices
+// sizes: sparse tensor sizes
+// dims_to_flatten: a list of dim index to flatten
+//
+// Ex1:
+// indices = [[2, 4, 0],
+// [3, 1, 3]]
+// sizes = [2, 12]
+// dims_to_flatten = [0, 1]
+// new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
+//
+// Ex2:
+// dims_to_flatten = [1]
+// new_indices = [ 3, 1, 3 ] # uncoalesced
+inline LongTensor flatten_indices_by_dims(const LongTensor& indices, const IntList& sizes, const IntList& dims_to_flatten){
+ LongTensor new_indices = at::zeros({indices.size(1)}, indices.options());
+ for (auto d : dims_to_flatten) {
+ new_indices.mul_(sizes[d]);
+ new_indices.add_(indices.select(0, d));
+ }
+ return new_indices;
+}
+
}} // namespace at::sparse
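
For illustration, the flattening above can be mirrored in a few lines of Python; this is a minimal sketch of the same row-major accumulation, not the ATen helper itself:
```python
import torch

def flatten_indices_by_dims(indices, sizes, dims_to_flatten):
    # same row-major accumulation as the ATen helper above
    new_indices = torch.zeros(indices.size(1), dtype=torch.long)
    for d in dims_to_flatten:
        new_indices.mul_(sizes[d])
        new_indices.add_(indices[d])
    return new_indices

indices = torch.tensor([[2, 4, 0],
                        [3, 1, 3]])
print(flatten_indices_by_dims(indices, [2, 12], [0, 1]))  # tensor([27, 49,  3])
print(flatten_indices_by_dims(indices, [2, 12], [1]))     # tensor([3, 1, 3]), possibly uncoalesced
```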
_(aten, _sparse_mul) \
_(aten, _sparse_mul_scalar) \
_(aten, _sparse_mul_zerodim) \
+_(aten, _sparse_sum) \
_(aten, _sqrt) \
_(aten, _standard_gamma) \
_(aten, _standard_gamma_grad) \
SparseCPU: norm_sparse
SparseCUDA: norm_sparse
+# TODO: reduce signatures down to one when optional args are available
+- func: _sparse_sum(Tensor self) -> Tensor
+
+- func: _sparse_sum(Tensor self, *, ScalarType dtype) -> Tensor
+
+- func: _sparse_sum(Tensor self, IntList[1] dim) -> Tensor
+
+- func: _sparse_sum(Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor
+
+- func: _sparse_sum_backward(Tensor grad, Tensor self, IntList dim) -> Tensor
+ dispatch:
+ SparseCPU: _sparse_sum_backward_cpu
+ SparseCUDA: _sparse_sum_backward_cuda
+
- func: norm(Tensor self, Scalar p=2) -> Tensor
variants: function, method
#include <ATen/NativeFunctions.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseTensorUtils.h>
+#include "ATen/WrapDimUtilsMulti.h"
#include <TH/THBlasUtils.h>
namespace at { namespace native {
using namespace at::sparse;
-
// --------------------------------------------------------------------
// Utility functions
// --------------------------------------------------------------------
return result;
}
+// --------------------------------------------------------------------
+// sparse.sum()
+//
+// This implementation calls coalesce() to do the sum reduction on
+// sparse dims. Ideally, in the future there should be a unified reduction
+// function for ops like sum, max, and min.
+// --------------------------------------------------------------------
+Tensor _sparse_sum(const SparseTensor& input) {
+ return input.coalesce().values().sum();
+}
+
+Tensor _sparse_sum(const SparseTensor& input, ScalarType dtype) {
+  // we don't have to convert to the correct dtype first,
+  // we just need to set up the accumulator correctly
+ return input.coalesce().values().sum(dtype);
+}
+
+Tensor _sparse_sum(const SparseTensor& input, IntList dims_to_sum, ScalarType dtype) {
+ return at::_sparse_sum(input.to(dtype), dims_to_sum);
+}
+
+Tensor _sparse_sum(const SparseTensor& input, IntList dims_to_sum) {
+ AT_CHECK(input._nnz() > 0, "_sparse_sum: sparse tensor input._nnz() == 0, please call torch.sparse.sum(input) instead.")
+
+ const int64_t input_dim = input.dim();
+ auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
+ auto dims_to_sum_v = dims_to_sum.vec();
+ maybe_wrap_dims(dims_to_sum_v, input_dim);
+
+ LongTensor indices = input._indices();
+ Tensor values = input._values();
+ IntList sizes = input.sizes();
+ const int64_t sparse_dim = input.sparse_dim();
+ const int64_t dense_dim = input.dense_dim();
+
+ auto dims_to_keep_v = std::vector<int64_t>();
+ auto dense_dims_to_sum_v = std::vector<int64_t>();
+ for (int64_t d = 0; d < input_dim; d++) {
+ if (dims_to_sum_b[d]) {
+ if (d >= sparse_dim) dense_dims_to_sum_v.emplace_back(d + 1 - sparse_dim);
+ }
+ else {
+ dims_to_keep_v.emplace_back(d);
+ }
+ }
+ const int64_t sparse_dims_to_sum_size = dims_to_sum_v.size() - dense_dims_to_sum_v.size();
+ const bool sum_all_sparse_dim = (sparse_dim == sparse_dims_to_sum_size);
+ const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
+
+ // new values
+ Tensor new_values;
+ if (sum_dense_dim) {
+ new_values = values.sum(dense_dims_to_sum_v);
+ }
+ else {
+ new_values = values.clone();
+ }
+
+ if (sum_all_sparse_dim) {
+ // return a dense tensor if sum over all sparse dims
+ new_values = new_values.sum(0);
+ return new_values;
+ }
+ else { // !sum_all_sparse_dim
+ // new indices
+ LongTensor new_indices;
+ if (sparse_dims_to_sum_size == 0) {
+ new_indices = indices.clone();
+ }
+ else {
+ new_indices = at::empty({sparse_dim - sparse_dims_to_sum_size, input._nnz()}, indices.options());
+ for (int64_t i = 0; i < dims_to_keep_v.size(); i++) {
+ int64_t d = dims_to_keep_v[i];
+ if (d < sparse_dim) new_indices[i].copy_(indices[d]);
+ else break;
+ }
+ }
+
+ // new size
+ int64_t new_sparse_dim = new_indices.size(0);
+ int64_t new_dense_dim = new_values.dim() - 1; // exclude nnz dim
+ std::vector<int64_t> new_sizes;
+ for (auto d : dims_to_keep_v) new_sizes.emplace_back(sizes[d]);
+ if (sum_all_sparse_dim) new_sizes.emplace(new_sizes.begin(), 1);
+
+ // use coalesce() to do sum reduction
+ SparseTensor new_sparse = at::_sparse_coo_tensor_with_dims_and_tensors(new_sparse_dim, new_dense_dim, new_sizes, new_indices, new_values, input.options());
+ new_sparse = new_sparse.coalesce();
+ return new_sparse;
+ }
+
+}
+
+// --------------------------------------------------------------------
+// NOTE [ sparse.sum() backward ]
+//
+// When summing over sparse_dim, the backward scatters gradients from the grad tensor
+// into the input tensor. Grad and input need to align indices over the sparse_dim that
+// are not summed (given input.sparse_dim >= grad.sparse_dim). The implementation here
+// compares each pair of indices between grad and input. When a matching index pair
+// (input_i, grad_i) is found, it copies grad.values[grad_i] -> input_grad.values[input_i]. E.g.,
+//
+// input.sizes = [5, 5]
+// input.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
+// [1, 4, 4, 0, 1, 3, 2, 4]]
+// input.values = [0, 1, 2, 3, 4, 5, 6, 7]
+// ...
+// sparse.sum(input, [0])
+// backward(...)
+// ...
+// grad.indices = [[0, 1, 2, 3]]
+// grad.values = [1, 2, 0, 4]
+//
+// # after indices matching
+// input grad
+// [[0, 1], -> [1]
+// [0, 4], -> [ ]
+// [1, 4], -> [ ]
+// [2, 0], -> [0]
+// [2, 1], -> [1]
+// [3, 3], -> [3]
+// [4, 2], -> [2]
+// [4, 4]]) -> [ ]
+//
+// input_grad.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
+// [1, 4, 4, 0, 1, 3, 2, 4]]
+// input_grad.values = [2, 0, 0, 1, 2, 4, 0, 0]
+//
+// Note that since we allow the input to be uncoalesced in the forward,
+// we have to coalesce the input in the backward: grad-of-input takes the
+// same indices as input, and if the input were not coalesced, coalescing
+// grad-of-input could add up grad values for duplicate indices and hence
+// produce a wrong grad-of-input.
+//
+// Other edge cases:
+// - assign zero values to input gradients if no matching indices are found in grad
+// - grad.values might have zeros
+// --------------------------------------------------------------------
+Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_, IntList dims_to_sum) {
+ AT_CHECK(!grad_.is_cuda(), "_sparse_sum_backward_cpu: expected 'grad_' to be CPU tensor, but got CUDA tensor");
+ AT_CHECK(!input_.is_cuda(), "_sparse_sum_backward_cpu: expected 'input_' to be CPU tensor, but got CUDA tensor");
+
+ auto input = input_.coalesce();
+ const int64_t input_dim = input.dim();
+ auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
+ auto dims_to_sum_v = dims_to_sum.vec();
+ maybe_wrap_dims(dims_to_sum_v, input_dim);
+
+ LongTensor input_indices = input._indices();
+ Tensor input_values = input._values();
+ IntList input_sizes = input.sizes();
+ const int64_t input_sparse_dim = input.sparse_dim();
+ const int64_t input_dense_dim = input.dense_dim();
+ const int64_t input_nnz = input._nnz();
+
+ int64_t sparse_dims_to_sum_size = 0;
+ auto sparse_dims_to_keep_v = std::vector<int64_t>();
+ auto dense_dims_to_sum_v = std::vector<int64_t>();
+ for (int64_t d = 0; d < input_dim; d++) {
+ if (dims_to_sum_b[d]) {
+ if (d < input_sparse_dim) sparse_dims_to_sum_size ++;
+ else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim);
+ }
+ else {
+ if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d);
+ }
+ }
+
+ const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
+ const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
+ const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);
+
+ if (sum_all_sparse_dim) {
+ AT_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be dense since all sparse dims are summed");
+ auto grad_input_values = grad_;
+ auto expand_size = input_values.sizes().vec();
+ if (sum_dense_dim) {
+ auto dense_expand_size = std::vector<int64_t>(expand_size);
+ dense_expand_size.erase(dense_expand_size.begin());
+ AT_ASSERT(dense_expand_size.size() == (input_values.dim() - 1));
+ for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1); // -1 since grad has no nnz dim
+ grad_input_values = grad_input_values.expand(dense_expand_size);
+ }
+ grad_input_values = grad_input_values.expand(expand_size).clone();
+ return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, input.options().dtype(grad_.dtype())); // convert to grad dtype
+ }
+ else {
+ AT_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense");
+ auto grad = grad_.coalesce();
+ LongTensor grad_indices = grad._indices();
+ Tensor grad_values = grad._values();
+ const int64_t grad_sparse_dim = grad.sparse_dim();
+ const int64_t grad_nnz = grad._nnz();
+
+ Tensor grad_values_expand = grad_values;
+ if (sum_dense_dim) {
+ auto expand_size = input_values.sizes().vec();
+ if (sum_sparse_dim) expand_size[0] = grad_values.size(0);
+ for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d);
+ grad_values_expand = grad_values_expand.expand(expand_size).clone();
+ }
+
+ Tensor grad_input_values;
+ if (sum_sparse_dim) {
+ // see NOTE [ sparse.sum() backward ]
+ grad_input_values = at::zeros_like(input_values, grad_values.options());
+
+ // get flatten indices for grad and input
+ auto grad_sparse_dim_to_keep_v = std::vector<int64_t>(grad_sparse_dim);
+ std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0);
+
+ auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices on all sparse_dim of grad, output indices is coalesced and sorted
+ auto grad_indices_1D_accessor = grad_indices_1D.accessor<int64_t, 1>();
+ auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v);
+ auto input_indices_1D_accessor = input_indices_1D.accessor<int64_t, 1>();
+
+ // binary search to find matching indices
+ int64_t i;
+ #pragma omp parallel for private(i)
+ for (i = 0; i < input_nnz; i++) {
+ int64_t input_idx = input_indices_1D_accessor[i];
+ int64_t l = 0, r = grad_nnz - 1;
+ while (l <= r) {
+ int64_t m = l + (r - l) / 2;
+ if (grad_indices_1D_accessor[m] == input_idx) {
+ grad_input_values[i].copy_(grad_values_expand[m]);
+ break;
+ }
+ if (grad_indices_1D_accessor[m] < input_idx) {
+ l = m + 1;
+ }
+ else {
+ r = m - 1;
+ }
+ }
+ }
+ }
+ else {
+ grad_input_values = grad_values_expand;
+ }
+ return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, grad.options());
+ }
+}
+
}} // namespace at::native
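
The worked example in NOTE [ sparse.sum() backward ] can be reproduced through the Python API; a minimal sketch (it assumes a sparse grad tensor can be passed directly to `backward`, as the tests below do):
```python
import torch

# indices/values taken from the note above
i = torch.tensor([[0, 0, 1, 2, 2, 3, 4, 4],
                  [1, 4, 4, 0, 1, 3, 2, 4]])
v = torch.arange(8, dtype=torch.float)
S = torch.sparse_coo_tensor(i, v, (5, 5)).coalesce().requires_grad_(True)

out = torch.sparse.sum(S, [0])                      # sum over sparse dim 0, keep dim 1
grad_out = torch.sparse_coo_tensor(torch.tensor([[0, 1, 2, 3]]),
                                   torch.tensor([1., 2., 0., 4.]), (5,))
out.backward(grad_out)

# gradients land only at the input's nnz locations, zero where no grad index matches
print(S.grad._values())   # expected per the note: tensor([2., 0., 0., 1., 2., 4., 0., 0.])
```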
#include <ATen/native/sparse/cuda/SparseCUDABlas.cuh>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
+#include "ATen/WrapDimUtilsMulti.h"
#include <THC/THCTensorMathPointwise.cuh>
#include <THC/THCThrustAllocator.cuh>
+
#include <thrust/device_ptr.h>
#include <thrust/sequence.h>
+#include <thrust/binary_search.h>
+#include <thrust/sort.h>
#include <thrust/system/cuda/execution_policy.h>
+#include <bitset>
+
#define I_INFO(tensor) cuda::detail::getTensorInfo<int64_t, uint64_t>(tensor)
#define V_INFO(tensor) cuda::detail::getTensorInfo<scalar_t, uint64_t>(tensor)
namespace at { namespace native {
using namespace at::sparse;
+using at::cuda::detail::TensorInfo;
+using at::cuda::detail::getTensorInfo;
// --------------------------------------------------------------------
// Utility functions
return r_._coalesced_(true);
}
+// --------------------------------------------------------------------
+// sparse.sum() backward
+//
+// see NOTE [ sparse.sum() backward ]
+// --------------------------------------------------------------------
+template <typename scalar_t>
+__global__ void _sparse_sum_backward_cuda_kernel(
+ int64_t total_threads,
+ const TensorInfo<int64_t, int64_t> grad_indices_ti,
+ const TensorInfo<int64_t, int64_t> input_indices_ti,
+ const TensorInfo<int64_t, int64_t> input_indices_pos_ti,
+ const TensorInfo<scalar_t, int64_t> grad_values_expand_ti,
+ TensorInfo<scalar_t, int64_t> grad_input_values_ti
+) {
+ const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
+ if (i >= total_threads) return;
+ const int64_t j = input_indices_pos_ti.data[i];
+
+ bool has_match = false;
+ if (grad_indices_ti.data[j] == input_indices_ti.data[i]) {
+ has_match = true;
+ }
+
+ int64_t grad_input_values_stride0 = grad_input_values_ti.strides[0];
+ int64_t out_start = i * grad_input_values_stride0;
+ int64_t out_end = (i + 1) * grad_input_values_stride0;
+ int64_t in_start = j * grad_values_expand_ti.strides[0];
+
+ if (has_match) {
+ for (int64_t out_i = out_start, in_i = in_start; out_i < out_end; out_i++, in_i++) {
+ grad_input_values_ti.data[out_i] = grad_values_expand_ti.data[in_i];
+ }
+ }
+ else {
+ for (int64_t out_i = out_start; out_i < out_end; out_i++) {
+ grad_input_values_ti.data[out_i] = scalar_t(0);
+ }
+ }
+}
+
+Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_, IntList dims_to_sum) {
+ AT_CHECK(grad_.is_cuda(), "_sparse_sum_backward_cuda: expected 'grad_' to be CUDA tensor, but got CPU tensor");
+ AT_CHECK(input_.is_cuda(), "_sparse_sum_backward_cuda: expected 'input_' to be CUDA tensor, but got CPU tensor");
+
+ auto input = input_.coalesce();
+ const int64_t input_dim = input.dim();
+ auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
+ auto dims_to_sum_v = dims_to_sum.vec();
+ maybe_wrap_dims(dims_to_sum_v, input_dim);
+
+ LongTensor input_indices = input._indices();
+ Tensor input_values = input._values();
+ IntList input_sizes = input.sizes();
+ const int64_t input_sparse_dim = input.sparse_dim();
+ const int64_t input_dense_dim = input.dense_dim();
+ const int64_t input_nnz = input._nnz();
+
+ int64_t sparse_dims_to_sum_size = 0;
+ auto sparse_dims_to_keep_v = std::vector<int64_t>();
+ auto dense_dims_to_sum_v = std::vector<int64_t>();
+ for (int64_t d = 0; d < input_dim; d++) {
+ if (dims_to_sum_b[d]) {
+ if (d < input_sparse_dim) sparse_dims_to_sum_size ++;
+ else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim);
+ }
+ else {
+ if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d);
+ }
+ }
+
+ const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
+ const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
+ const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);
+
+ if (sum_all_sparse_dim) {
+ AT_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad Tensor to be dense since all sparse dims are summed");
+ auto grad_input_values = grad_;
+ auto expand_size = input_values.sizes().vec();
+ if (sum_dense_dim) {
+ auto dense_expand_size = std::vector<int64_t>(expand_size);
+ dense_expand_size.erase(dense_expand_size.begin()); // remove nnz dim
+ for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1); // -1 since grad has no nnz dim
+ grad_input_values = grad_input_values.expand(dense_expand_size);
+ }
+ grad_input_values = grad_input_values.expand(expand_size).clone();
+ return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, input.options().dtype(grad_.dtype())); // convert to grad dtype
+ }
+ else {
+ AT_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad_ Tensor to be sparse, but got dense");
+ auto grad = grad_.coalesce();
+ LongTensor grad_indices = grad._indices();
+ Tensor grad_values = grad._values();
+ const int64_t grad_sparse_dim = grad.sparse_dim();
+ const int64_t grad_nnz = grad._nnz();
+
+ Tensor grad_values_expand = grad_values;
+ if (sum_dense_dim) {
+ auto expand_size = input_values.sizes().vec();
+ if (sum_sparse_dim) expand_size[0] = grad_values.size(0); // update nnz
+ for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d);
+ grad_values_expand = grad_values_expand.expand(expand_size).clone();
+ }
+
+ Tensor grad_input_values;
+ if (!sum_sparse_dim) {
+ grad_input_values = grad_values_expand;
+ }
+ else {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+ auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+ auto policy = thrust::cuda::par(allocator).on(stream);
+ typedef thrust::device_ptr<int64_t> thrust_ptr;
+
+ grad_input_values = at::empty_like(input_values, grad_values.options());
+ AT_ASSERT(grad_input_values.is_cuda());
+
+ // get 1D indices
+ auto grad_sparse_dim_to_keep_v = std::vector<int64_t>(grad_sparse_dim);
+ std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0);
+
+ auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices on all sparse_dim of grad, output indices is coalesced and sorted
+ auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v);
+ thrust_ptr grad_indices_iter(grad_indices_1D.data<int64_t>());
+ thrust_ptr input_indices_iter(input_indices_1D.data<int64_t>());
+
+ // store lower_bound of input indices at grad indices
+ LongTensor input_indices_pos = at::empty_like(input_indices_1D);
+ thrust_ptr input_indices_pos_iter(input_indices_pos.data<int64_t>());
+ thrust::lower_bound(policy,
+ grad_indices_iter, grad_indices_iter + grad_nnz,
+ input_indices_iter, input_indices_iter + input_nnz,
+ input_indices_pos_iter);
+
+ // config to run cuda kernel
+ int64_t total_threads = input_nnz;
+ const dim3 block = dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), total_threads));
+ dim3 grid;
+ AT_CHECK(cuda::getApplyGrid(total_threads, grid, curDevice), "_sparse_sum_backward_cuda: input too large or too many dimensions");
+
+ auto grad_indices_ti = getTensorInfo<int64_t, int64_t>(grad_indices_1D);
+ auto input_indices_ti = getTensorInfo<int64_t, int64_t>(input_indices_1D);
+ auto input_indices_pos_ti = getTensorInfo<int64_t, int64_t>(input_indices_pos);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_values.type(), "_sparse_sum_backward_cuda", [&] {
+ auto grad_values_expand_ti = getTensorInfo<scalar_t, int64_t>(grad_values_expand);
+ auto grad_input_values_ti = getTensorInfo<scalar_t, int64_t>(grad_input_values);
+
+ _sparse_sum_backward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
+ total_threads,
+ grad_indices_ti,
+ input_indices_ti,
+ input_indices_pos_ti,
+ grad_values_expand_ti,
+ grad_input_values_ti
+ );
+ });
+ }
+
+ return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, grad.options());
+ }
+}
+
}} // namespace at::native
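
The CUDA path replaces the per-element binary search of the CPU version with a single batched `thrust::lower_bound` over all flattened input indices, followed by an equality check inside the kernel. The same matching scheme, sketched with Python's `bisect` purely for illustration (indices and values taken from the note in the CPU file):
```python
from bisect import bisect_left

# flattened, sorted (coalesced) grad indices, and the flattened input indices
# on the kept sparse dims
grad_idx  = [0, 1, 2, 3]
grad_val  = [1., 2., 0., 4.]
input_idx = [1, 4, 4, 0, 1, 3, 2, 4]

grad_input_val = []
for x in input_idx:
    j = bisect_left(grad_idx, x)                  # lower_bound position
    if j < len(grad_idx) and grad_idx[j] == x:    # exact match -> copy grad value
        grad_input_val.append(grad_val[j])
    else:                                         # no match -> zero gradient
        grad_input_val.append(0.)

print(grad_input_val)   # [2.0, 0.0, 0.0, 1.0, 2.0, 4.0, 0.0, 0.0]
```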
and values:
[torch.FloatTensor with no dimension]
+SparseTensor has the following invariants:
+ 1. sparse_dim + dense_dim = len(SparseTensor.shape)
+ 2. SparseTensor._indices().shape = (sparse_dim, nnz)
+ 3. SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:])
+Since SparseTensor._indices() is always a 2D tensor, the smallest sparse_dim is 1.
+Therefore, the representation of a SparseTensor with sparse_dim = 0 is simply a dense tensor.
+
.. note::
Our sparse tensor format permits *uncoalesced* sparse tensors, where
----------------------------------
.. autofunction:: torch.sparse.addmm
+.. autofunction:: torch.sparse.sum
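
A quick interactive check of the invariants documented above (shapes chosen arbitrarily):
```python
>>> import torch
>>> S = torch.sparse_coo_tensor(torch.tensor([[0, 1]]), torch.randn(2, 3), (4, 3))
>>> S.sparse_dim() + S.dense_dim() == len(S.shape)   # invariant 1
True
>>> S._indices().shape                               # (sparse_dim, nnz)
torch.Size([1, 2])
>>> S._values().shape                                # (nnz, *S.shape[sparse_dim:])
torch.Size([2, 3])
```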
test_shape(4, 10, [100, 100, 100, 5, 5, 5, 0])
test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0])
+ @skipIfRocm
+ def test_sparse_sum(self):
+
+ def run_tests(S, td=None):
+ D = S.coalesce().to_dense().detach().requires_grad_(True)
+ mask = (D == 0)
+ if td is None:
+ S_sum = torch.sparse.sum(S)
+ D_sum = D.sum()
+ self.assertEqual(S_sum, D_sum)
+ S_sum.backward()
+ D_sum.backward()
+ D_grad = D.grad.masked_fill_(mask, 0)
+ self.assertEqual(S.grad.to_dense(), D_grad)
+ else:
+ S_sum = torch.sparse.sum(S, td)
+ D_sum = D.sum(td)
+ self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum)
+ S_sum.backward(S_sum.detach())
+ S_grad = S.grad
+ data = S_sum.to_dense().detach() if S_sum.is_sparse else S_sum.detach()
+ D_sum.backward(data)
+ D_grad = D.grad.masked_fill_(mask, 0)
+ S_grad_dense = S_grad.coalesce().to_dense() if S_grad.is_sparse else S_grad
+ self.assertEqual(S_grad_dense, D_grad)
+
+ nnz = 10
+ sparse_dims = 2
+ with_size = [5, 5, 1, 4] # use a dense dim = 1 to test for squeeze
+ test_dims = []
+ for i in range(1, 5):
+ test_dims += itertools.combinations(range(len(with_size)), i)
+
+        # SparseTensor.sum() is not supported
+ S = self._gen_sparse(sparse_dims, nnz, with_size)[0]
+ self.assertRaises(RuntimeError, lambda: S.sum())
+
+ # raise for incorrect input dims
+ self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, 5))
+ self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, [0, 0]))
+
+ # sum an empty tensor
+ empty_S = torch.sparse_coo_tensor(size=with_size)
+ self.assertRaises(RuntimeError, lambda: torch.sparse.sum(empty_S, [0]))
+        self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0.))
+ empty_S.requires_grad_(True)
+ empty_S_sum = torch.sparse.sum(empty_S)
+ empty_S_sum.backward()
+ self.assertEqual(empty_S.grad.to_dense(), empty_S.clone().detach().to_dense())
+
+ # test values().sum()
+ S = self._gen_sparse(sparse_dims, nnz, with_size)[0]
+ run_tests(S.requires_grad_(True))
+
+ for test_dim in test_dims:
+ S = self._gen_sparse(sparse_dims, nnz, with_size)[0]
+ run_tests(S.requires_grad_(True), test_dim)
+
def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v=None):
shape = shape_i + (shape_v or [])
x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape)
- name: _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, IntList size, Tensor indices, Tensor values, TensorOptions options)
values: sparse_constructor_values_backward(grad, indices, values.sizes())
+- name: _sparse_sum(Tensor self, IntList dim)
+ self: at::_sparse_sum_backward(grad, self, dim)
+
- name: _standard_gamma(Tensor self, Generator generator)
self: grad * _standard_gamma_grad(self, result)
__all__ = [
'addmm',
+ 'sum',
]
def addmm(mat, mat1, mat2, beta=1, alpha=1):
r"""
- This function does exact same thing as :meth:`~Torch.addmm` in the forward,
+    This function does the exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for coalesced sparse matrix `mat1`.
Args:
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
"""
return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha)
+
+
+def sum(input, dim=None, dtype=None):
+ r"""
+    Returns the sum of each row of the SparseTensor :attr:`input` in the given
+    dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
+    reduce over all of them. When summing over all ``sparse_dim``, this method
+    returns a dense Tensor instead of a SparseTensor.
+
+    All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting in an
+    output tensor having ``len(dim)`` fewer dimensions than :attr:`input`.
+
+    During backward, only gradients at the ``nnz`` locations of :attr:`input`
+    will propagate back. Note that the gradient of :attr:`input` is coalesced.
+
+ Args:
+ input (Tensor): the input SparseTensor
+ dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
+ over all dims.
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
+ Default: dtype of :attr:`input`.
+
+ Example::
+
+ >>> nnz = 3
+ >>> dims = [5, 5, 2, 3]
+ >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
+ torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
+ >>> V = torch.randn(nnz, dims[2], dims[3])
+ >>> size = torch.Size(dims)
+ >>> S = torch.sparse_coo_tensor(I, V, size)
+ >>> S
+ tensor(indices=tensor([[2, 0, 3],
+ [2, 4, 1]]),
+ values=tensor([[[-0.6438, -1.6467, 1.4004],
+ [ 0.3411, 0.0918, -0.2312]],
+
+ [[ 0.5348, 0.0634, -2.0494],
+ [-0.7125, -1.0646, 2.1844]],
+
+ [[ 0.1276, 0.1874, -0.6334],
+ [-1.9682, -0.5340, 0.7483]]]),
+ size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
+
+ # when sum over only part of sparse_dims, return a SparseTensor
+ >>> torch.sparse.sum(S, [1, 3])
+ tensor(indices=tensor([[0, 2, 3]]),
+ values=tensor([[-1.4512, 0.4073],
+ [-0.8901, 0.2017],
+ [-0.3183, -1.7539]]),
+ size=(5, 2), nnz=3, layout=torch.sparse_coo)
+
+ # when sum over all sparse dim, return a dense Tensor
+ # with summed dims squeezed
+ >>> torch.sparse.sum(S, [0, 1, 3])
+ tensor([-2.6596, -1.1450])
+ """
+    if dtype is None:
+        if dim is not None:
+            return torch._sparse_sum(input, dim)
+        else:
+            return torch._sparse_sum(input)
+    else:
+        if dim is not None:
+            return torch._sparse_sum(input, dim, dtype=dtype)
+        else:
+            return torch._sparse_sum(input, dtype=dtype)