torch.sparse.sum() (#12430)
authorWei Yang <weiyang@fb.com>
Wed, 28 Nov 2018 10:16:56 +0000 (02:16 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 28 Nov 2018 10:19:12 +0000 (02:19 -0800)
Summary:
- to fix #12241
- add `_sparse_sum()` to ATen, and expose as `torch.sparse.sum()`, not support `SparseTensor.sum()` currently
- this PR depends on #11253, and will need to be updated upon it lands
- [x] implement forward
- [x] implement backward
- performance [benchmark script](https://gist.github.com/weiyangfb/f4c55c88b6092ef8f7e348f6b9ad8946#file-sparse_sum_benchmark-py):
  - sum all dims is fastest for sparse tensor
  - when input is sparse enough nnz = 0.1%, sum of sparse tensor is faster than dense in CPU, but not necessary in CUDA
  - CUDA backward is comparable (<2x) between `sum several dims` vs `sum all dims` in sparse
  - CPU backward uses binary search is still slow in sparse, takes `5x` time in `sum [0, 2, 3] dims` vs `sum all dims`
    - optimize CUDA backward for now
      - using thrust for sort and binary search, but runtime not improved
  - both of CPU and CUDA forward are slow in sparse (`sum several dims` vs `sum all dims`), at most `20x` slower in CPU, and `10x` in CUDA
    - improve CPU and CUDA forward kernels

(nnz, sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA(sparse vs dense)
-- | -- | --
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 8.77 µs vs 72.9 µs | 42.5 µs vs 108 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 112 µs vs 4.47 ms | 484 µs vs 407 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 141 µs vs 148 µs | 647 µs vs 231 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 235 µs vs 1.23 ms | 781 µs vs 213 µs
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 48.5 µs vs 360 µs | 160 µs vs 2.03 ms
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 258 µs vs 1.22 ms | 798 µs vs 224 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 204 µs vs 882 µs | 443 µs vs 133 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 709 µs vs 1.15 ms | 893 µs vs 202 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 39.8 µs vs 81 µs | 42.4 µs vs 113 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 747 µs vs 4.7 ms | 2.4 ms vs 414 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 1.04 ms vs 126 µs | 5.03 ms vs 231 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.12 ms vs 1.24 ms | 5.99 ms vs 213 µs
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 133 µs vs 366 µs | 463 µs vs 2.03 ms
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.56 ms vs 1.22 ms | 6.11 ms vs 229 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.53 ms vs 799 µs | 824 µs vs 134 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 5.15 ms vs 1.09 ms | 7.02 ms vs 205 µs

- after improving CPU and CUDA forward kernels
  - in `(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD)` forward, CPU takes ~~`171 µs`~~, in which `130 µs` is spent on `coalesce()`, for CUDA, total time is ~~`331 µs`~~, in which `141 µs` is spent on `coalesce()`, we need to reduce time at other places outside `coalesce()`.
  - after a few simple tweaks, now in the forward, it is at most `10x` slower in CPU, and `7x` in CUDA. And time takes in `sum dense dims only [2, 3]` is `~2x` of `sum all dims`. Speed of `sum all sparse dims [0, 1]` is on bar with `sum all dims`

(nnz,   sizes, sum_dims, keepdim, sum all or dims, bk=backward) | CPU (sparse vs dense) | CUDA(sparse vs dense)
-- | -- | --
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 7 µs vs 69.5 µs | 31.5 µs vs 61.6 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 11.3 µs vs 4.72 ms | 35.2 µs vs 285 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 197 µs vs 124 µs | 857 µs vs 134 µs
(1000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 124 µs vs 833 µs | 796 µs vs 106 µs
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 20.5 µs vs 213 µs | 39.4 µs vs 1.24 ms
(1000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 131 µs vs 830 µs | 881 µs vs 132 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 95.8 µs vs 409 µs | 246 µs vs 87.2 µs
(1000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 624 µs vs 820 µs | 953 µs vs 124 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll) | 45.3 µs vs 72.9 µs | 33.9 µs vs 57.2 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD) | 81.4 µs vs 4.49 ms | 39.7 µs vs 280 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumAll, bk) | 984 µs vs 111 µs | 6.41 ms vs 121 µs
(10000,   [1000, 1000, 2, 2], [0, 1], False, sumD, bk) | 1.45 ms vs 828 µs | 6.77 ms vs 113 µs
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD) | 74.9 µs vs 209 µs | 37.7 µs vs 1.23 ms
(10000,   [1000, 1000, 2, 2], [2, 3], False, sumD, bk) | 1.48 ms vs 845 µs | 6.96 ms vs 132 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD) | 1.14 ms vs 411 µs | 252 µs vs 87.8 µs
(10000,   [1000, 1000, 2, 2], [0, 2, 3], False, sumD, bk) | 4.53 ms vs 851 µs | 7.12 ms vs 128 µs

- time takes in CUDA backward of sparse is super long with large variance (in case of nnz=10000, it normally takes 6-7ms). To improve backward of sparse ops, we will need to debug at places other than CUDA kernels. here is a benchmark of `torch.copy_()`:
```
>>> d = [1000, 1000, 2, 2]
>>> nnz = 10000
>>> I = torch.cat([torch.randint(0, d[0], size=(nnz,)),
               torch.randint(0, d[1], size=(nnz,))], 0).reshape(2, nnz)
>>> V = torch.randn(nnz, d[2], d[3])
>>> size = torch.Size(d)
>>> S = torch.sparse_coo_tensor(I, V, size).coalesce().cuda()
>>> S2 = torch.sparse_coo_tensor(I, V, size).coalesce().cuda().requires_grad_()
>>> data = S2.clone()
>>> S.copy_(S2)
>>> y = S * 2
>>> torch.cuda.synchronize()
>>> %timeit y.backward(data, retain_graph=True); torch.cuda.synchronize()
7.07 ms ± 3.06 ms per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12430

Differential Revision: D12878313

Pulled By: weiyangfb

fbshipit-source-id: e16dc7681ba41fdabf4838cf05e491ca9108c6fe

aten/src/ATen/SparseTensorUtils.h
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/sparse/SparseTensorMath.cpp
aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
docs/source/sparse.rst
test/test_sparse.py
tools/autograd/derivatives.yaml
torch/sparse/__init__.py

index 2eed30b..2f12c80 100644 (file)
@@ -109,4 +109,33 @@ inline LongTensor flatten_indices(const Tensor& indices, IntList full_size, bool
   }
 }
 
+// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
+// except this one allows partial flatten: only flatten on specified dims. Note that
+// the flatten indices might be uncoalesced if dims_to_flatten.size() < sparse_dim.
+// Also if input indices is already coalesced, the flattened indices will also be sorted.
+//
+// args:
+//    indices: sparse tensor indices
+//    sizes: sparse tensor sizes
+//    dims_to_flatten: a list of dim index to flatten
+//
+// Ex1:
+//   indices = [[2, 4, 0],
+//             [3, 1, 3]]
+//   sizes = [2, 12]
+//   dims_to_flatten = [0, 1]
+//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
+//
+// Ex2:
+//   dims_to_flatten = [1]
+//   new_indices = [ 3, 1, 3 ]  # uncoalesced
+inline LongTensor flatten_indices_by_dims(const LongTensor& indices, const IntList& sizes, const IntList& dims_to_flatten){
+  LongTensor new_indices = at::zeros({indices.size(1)}, indices.options());
+  for (auto d : dims_to_flatten) {
+    new_indices.mul_(sizes[d]);
+    new_indices.add_(indices.select(0, d));
+  }
+  return new_indices;
+}
+
 }} // namespace at::sparse
index 5f6896d..486d5cf 100644 (file)
@@ -128,6 +128,7 @@ _(aten, _sparse_div_zerodim) \
 _(aten, _sparse_mul) \
 _(aten, _sparse_mul_scalar) \
 _(aten, _sparse_mul_zerodim) \
+_(aten, _sparse_sum) \
 _(aten, _sqrt) \
 _(aten, _standard_gamma) \
 _(aten, _standard_gamma_grad) \
index fe51e37..7e71f47 100644 (file)
     SparseCPU: norm_sparse
     SparseCUDA: norm_sparse
 
+# TODO: reduce signatures down to one when optional args is available
+- func: _sparse_sum(Tensor self) -> Tensor
+
+- func: _sparse_sum(Tensor self, *, ScalarType dtype) -> Tensor
+
+- func: _sparse_sum(Tensor self, IntList[1] dim) -> Tensor
+
+- func: _sparse_sum(Tensor self, IntList[1] dim, *, ScalarType dtype) -> Tensor
+
+- func: _sparse_sum_backward(Tensor grad, Tensor self, IntList dim) -> Tensor
+  dispatch:
+      SparseCPU: _sparse_sum_backward_cpu
+      SparseCUDA: _sparse_sum_backward_cuda
+
 - func: norm(Tensor self, Scalar p=2) -> Tensor
   variants: function, method
 
index ee98ce2..414e4d4 100644 (file)
@@ -4,13 +4,13 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/InitialTensorOptions.h>
 #include <ATen/SparseTensorUtils.h>
+#include "ATen/WrapDimUtilsMulti.h"
 
 #include <TH/THBlasUtils.h>
 
 namespace at { namespace native {
 
 using namespace at::sparse;
-
 // --------------------------------------------------------------------
 // Utility functions
 // --------------------------------------------------------------------
@@ -814,4 +814,248 @@ Tensor sspaddmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2,
   return result;
 }
 
+// --------------------------------------------------------------------
+// sparse.sum()
+//
+// This implementation calls coalesce() to do the sum reduction on
+// sparse dims. Ideally in the future there should be unified reduction function
+// for ops like sum, max, and min.
+// --------------------------------------------------------------------
+Tensor _sparse_sum(const SparseTensor& input) {
+  return input.coalesce().values().sum();
+}
+
+Tensor _sparse_sum(const SparseTensor& input, ScalarType dtype) {
+  // don't have to do a conversion to the correct dtype first
+  // just need to setup the accumulator correctly
+  return input.coalesce().values().sum(dtype);
+}
+
+Tensor _sparse_sum(const SparseTensor& input, IntList dims_to_sum, ScalarType dtype) {
+  return at::_sparse_sum(input.to(dtype), dims_to_sum);
+}
+
+Tensor _sparse_sum(const SparseTensor& input, IntList dims_to_sum) {
+  AT_CHECK(input._nnz() > 0, "_sparse_sum: sparse tensor input._nnz() == 0, please call torch.sparse.sum(input) instead.")
+
+  const int64_t input_dim = input.dim();
+  auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
+  auto dims_to_sum_v = dims_to_sum.vec();
+  maybe_wrap_dims(dims_to_sum_v, input_dim);
+
+  LongTensor indices = input._indices();
+  Tensor values = input._values();
+  IntList sizes = input.sizes();
+  const int64_t sparse_dim = input.sparse_dim();
+  const int64_t dense_dim = input.dense_dim();
+
+  auto dims_to_keep_v = std::vector<int64_t>();
+  auto dense_dims_to_sum_v = std::vector<int64_t>();
+  for (int64_t d = 0; d < input_dim; d++) {
+    if (dims_to_sum_b[d]) {
+      if (d >= sparse_dim) dense_dims_to_sum_v.emplace_back(d + 1 - sparse_dim);
+    }
+    else {
+      dims_to_keep_v.emplace_back(d);
+    }
+  }
+  const int64_t sparse_dims_to_sum_size = dims_to_sum_v.size() - dense_dims_to_sum_v.size();
+  const bool sum_all_sparse_dim = (sparse_dim == sparse_dims_to_sum_size);
+  const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
+
+  // new values
+  Tensor new_values;
+  if (sum_dense_dim) {
+    new_values = values.sum(dense_dims_to_sum_v);
+  }
+  else {
+    new_values = values.clone();
+  }
+
+  if (sum_all_sparse_dim) {
+    // return a dense tensor if sum over all sparse dims
+    new_values = new_values.sum(0);
+    return new_values;
+  }
+  else { // !sum_all_sparse_dim
+    // new indices
+    LongTensor new_indices;
+    if (sparse_dims_to_sum_size == 0) {
+      new_indices = indices.clone();
+    }
+    else {
+      new_indices = at::empty({sparse_dim - sparse_dims_to_sum_size, input._nnz()}, indices.options());
+      for (int64_t i = 0; i < dims_to_keep_v.size(); i++) {
+        int64_t d = dims_to_keep_v[i];
+        if (d < sparse_dim) new_indices[i].copy_(indices[d]);
+        else break;
+      }
+    }
+
+    // new size
+    int64_t new_sparse_dim = new_indices.size(0);
+    int64_t new_dense_dim = new_values.dim() - 1; // exclude nnz dim
+    std::vector<int64_t> new_sizes;
+    for (auto d : dims_to_keep_v) new_sizes.emplace_back(sizes[d]);
+    if (sum_all_sparse_dim) new_sizes.emplace(new_sizes.begin(), 1);
+
+    // use coalesce() to do sum reduction
+    SparseTensor new_sparse = at::_sparse_coo_tensor_with_dims_and_tensors(new_sparse_dim, new_dense_dim, new_sizes, new_indices, new_values, input.options());
+    new_sparse = new_sparse.coalesce();
+    return new_sparse;
+  }
+
+}
+
+// --------------------------------------------------------------------
+// NOTE [ sparse.sum() backward ]
+//
+// When sum over sparse_dim, backward scatters gradients from grad tensor to input tensor.
+// Grad and input need to align indices over sparse_dim that are not summed (given
+// input.spares_dim >= grad.sparse_dim). Implementation here compares each pair of
+// indices between grad and input. When a matching indices pair (input_i, grad_i) is found,
+// copy grad.values[grad_i] -> input_grad.values[input_i]. E.g.,
+//
+//  input.sparse_dim = [5, 5]
+//  input.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
+//                   [1, 4, 4, 0, 1, 3, 2, 4]]
+//  input.values =   [0, 1, 2, 3, 4, 5, 6, 7]
+//  ...
+//  sparse.sum(input, [0])
+//  backward(...)
+//  ...
+//  grad.indices = [[0, 1, 2, 3]]
+//  grad.values =   [1, 2, 0, 4]
+//
+// # after indices matching
+//         input         grad
+//        [[0, 1],   ->  [1]
+//         [0, 4],   ->  [ ]
+//         [1, 4],   ->  [ ]
+//         [2, 0],   ->  [0]
+//         [2, 1],   ->  [1]
+//         [3, 3],   ->  [3]
+//         [4, 2],   ->  [2]
+//         [4, 4]])  ->  [ ]
+//
+// input_grad.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
+//                       [1, 4, 4, 0, 1, 3, 2, 4]]
+// input_grad.values =   [2, 0, 0, 1, 2, 4, 0, 0]
+//
+// Note that we allow input to be uncoalesced in the forward,
+// we have to coalesce input at the backward, because grad-of-input
+// take the same indices as input, if input is not coalesced, then
+// coalescing grad-of-input may add up grad values for a duplicate indices,
+// and hence generates a wrong grad-of-input.
+//
+// Other edge cases:
+// - assign zero values to input gradients if cannot find matched indices at grad
+// - grad.values might have zeros
+// --------------------------------------------------------------------
+Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_, IntList dims_to_sum) {
+  AT_CHECK(!grad_.is_cuda(), "_sparse_sum_backward_cpu: expected 'grad_' to be CPU tensor, but got CUDA tensor");
+  AT_CHECK(!input_.is_cuda(), "_sparse_sum_backward_cpu: expected 'input_' to be CPU tensor, but got CUDA tensor");
+
+  auto input = input_.coalesce();
+  const int64_t input_dim = input.dim();
+  auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
+  auto dims_to_sum_v = dims_to_sum.vec();
+  maybe_wrap_dims(dims_to_sum_v, input_dim);
+
+  LongTensor input_indices = input._indices();
+  Tensor input_values = input._values();
+  IntList input_sizes = input.sizes();
+  const int64_t input_sparse_dim = input.sparse_dim();
+  const int64_t input_dense_dim = input.dense_dim();
+  const int64_t input_nnz = input._nnz();
+
+  int64_t sparse_dims_to_sum_size = 0;
+  auto sparse_dims_to_keep_v = std::vector<int64_t>();
+  auto dense_dims_to_sum_v = std::vector<int64_t>();
+  for (int64_t d = 0; d < input_dim; d++) {
+    if (dims_to_sum_b[d]) {
+      if (d < input_sparse_dim) sparse_dims_to_sum_size ++;
+      else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim);
+    }
+    else {
+      if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d);
+    }
+  }
+
+  const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
+  const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
+  const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);
+
+  if (sum_all_sparse_dim) {
+    AT_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be dense since all sparse dims are summed");
+    auto grad_input_values = grad_;
+    auto expand_size = input_values.sizes().vec();
+    if (sum_dense_dim) {
+      auto dense_expand_size = std::vector<int64_t>(expand_size);
+      dense_expand_size.erase(dense_expand_size.begin());
+      AT_ASSERT(dense_expand_size.size() == (input_values.dim() - 1));
+      for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1);  // -1 since grad has no nnz dim
+      grad_input_values = grad_input_values.expand(dense_expand_size);
+    }
+    grad_input_values = grad_input_values.expand(expand_size).clone();
+    return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, input.options().dtype(grad_.dtype())); // convert to grad dtype
+  }
+  else {
+    AT_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense");
+    auto grad = grad_.coalesce();
+    LongTensor grad_indices = grad._indices();
+    Tensor grad_values = grad._values();
+    const int64_t grad_sparse_dim = grad.sparse_dim();
+    const int64_t grad_nnz = grad._nnz();
+
+    Tensor grad_values_expand = grad_values;
+    if (sum_dense_dim) {
+      auto expand_size = input_values.sizes().vec();
+      if (sum_sparse_dim) expand_size[0] = grad_values.size(0);
+      for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d);
+      grad_values_expand = grad_values_expand.expand(expand_size).clone();
+    }
+
+    Tensor grad_input_values;
+    if (sum_sparse_dim) {
+      // see NOTE [ sparse.sum() backward ]
+      grad_input_values = at::zeros_like(input_values, grad_values.options());
+
+      // get flatten indices for grad and input
+      auto grad_sparse_dim_to_keep_v = std::vector<int64_t>(grad_sparse_dim);
+      std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0);
+
+      auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices on all sparse_dim of grad, output indices is coalesced and sorted
+      auto grad_indices_1D_accessor = grad_indices_1D.accessor<int64_t, 1>();
+      auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v);
+      auto input_indices_1D_accessor = input_indices_1D.accessor<int64_t, 1>();
+
+      // binary search to find matching indices
+      int64_t i;
+      #pragma omp parallel for private(i)
+      for (i = 0; i < input_nnz; i++) {
+        int64_t input_idx = input_indices_1D_accessor[i];
+        int64_t l = 0, r = grad_nnz - 1;
+        while (l <= r) {
+          int64_t m = l + (r - l) / 2;
+          if (grad_indices_1D_accessor[m] == input_idx) {
+            grad_input_values[i].copy_(grad_values_expand[m]);
+            break;
+          }
+          if (grad_indices_1D_accessor[m] < input_idx) {
+            l = m + 1;
+          }
+          else {
+            r = m - 1;
+          }
+        }
+      }
+    }
+    else {
+      grad_input_values = grad_values_expand;
+    }
+    return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, grad.options());
+  }
+}
+
 }} // namespace at::native
index eef8400..8113728 100644 (file)
@@ -6,19 +6,27 @@
 #include <ATen/native/sparse/cuda/SparseCUDABlas.cuh>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/detail/IndexUtils.cuh>
+#include "ATen/WrapDimUtilsMulti.h"
 
 #include <THC/THCTensorMathPointwise.cuh>
 #include <THC/THCThrustAllocator.cuh>
+
 #include <thrust/device_ptr.h>
 #include <thrust/sequence.h>
+#include <thrust/binary_search.h>
+#include <thrust/sort.h>
 #include <thrust/system/cuda/execution_policy.h>
 
+#include <bitset>
+
 #define I_INFO(tensor) cuda::detail::getTensorInfo<int64_t, uint64_t>(tensor)
 #define V_INFO(tensor) cuda::detail::getTensorInfo<scalar_t, uint64_t>(tensor)
 
 namespace at { namespace native {
 
 using namespace at::sparse;
+using at::cuda::detail::TensorInfo;
+using at::cuda::detail::getTensorInfo;
 
 // --------------------------------------------------------------------
 // Utility functions
@@ -466,4 +474,168 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons
   return r_._coalesced_(true);
 }
 
+// --------------------------------------------------------------------
+// sparse.sum() backward
+//
+// see NOTE [ sparse.sum() backward ]
+// --------------------------------------------------------------------
+template <typename scalar_t>
+__global__ void _sparse_sum_backward_cuda_kernel(
+  int64_t total_threads,
+  const TensorInfo<int64_t, int64_t> grad_indices_ti,
+  const TensorInfo<int64_t, int64_t> input_indices_ti,
+  const TensorInfo<int64_t, int64_t> input_indices_pos_ti,
+  const TensorInfo<scalar_t, int64_t> grad_values_expand_ti,
+  TensorInfo<scalar_t, int64_t> grad_input_values_ti
+) {
+  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  if (i >= total_threads) return;
+  const int64_t j = input_indices_pos_ti.data[i];
+
+  bool has_match = false;
+  if (grad_indices_ti.data[j] == input_indices_ti.data[i]) {
+    has_match = true;
+  }
+
+  int64_t grad_input_values_stride0 = grad_input_values_ti.strides[0];
+  int64_t out_start = i * grad_input_values_stride0;
+  int64_t out_end = (i + 1) * grad_input_values_stride0;
+  int64_t in_start = j * grad_values_expand_ti.strides[0];
+
+  if (has_match) {
+    for (int64_t out_i = out_start, in_i = in_start; out_i < out_end; out_i++, in_i++) {
+      grad_input_values_ti.data[out_i] = grad_values_expand_ti.data[in_i];
+    }
+  }
+  else {
+    for (int64_t out_i = out_start; out_i < out_end; out_i++) {
+      grad_input_values_ti.data[out_i] = scalar_t(0);
+    }
+  }
+}
+
+Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_, IntList dims_to_sum) {
+  AT_CHECK(grad_.is_cuda(), "_sparse_sum_backward_cuda: expected 'grad_' to be CUDA tensor, but got CPU tensor");
+  AT_CHECK(input_.is_cuda(), "_sparse_sum_backward_cuda: expected 'input_' to be CUDA tensor, but got CPU tensor");
+
+  auto input = input_.coalesce();
+  const int64_t input_dim = input.dim();
+  auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
+  auto dims_to_sum_v = dims_to_sum.vec();
+  maybe_wrap_dims(dims_to_sum_v, input_dim);
+
+  LongTensor input_indices = input._indices();
+  Tensor input_values = input._values();
+  IntList input_sizes = input.sizes();
+  const int64_t input_sparse_dim = input.sparse_dim();
+  const int64_t input_dense_dim = input.dense_dim();
+  const int64_t input_nnz = input._nnz();
+
+  int64_t sparse_dims_to_sum_size = 0;
+  auto sparse_dims_to_keep_v = std::vector<int64_t>();
+  auto dense_dims_to_sum_v = std::vector<int64_t>();
+  for (int64_t d = 0; d < input_dim; d++) {
+    if (dims_to_sum_b[d]) {
+      if (d < input_sparse_dim) sparse_dims_to_sum_size ++;
+      else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim);
+    }
+    else {
+      if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d);
+    }
+  }
+
+  const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
+  const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
+  const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);
+
+  if (sum_all_sparse_dim) {
+    AT_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad Tensor to be dense since all sparse dims are summed");
+    auto grad_input_values = grad_;
+    auto expand_size = input_values.sizes().vec();
+    if (sum_dense_dim) {
+      auto dense_expand_size = std::vector<int64_t>(expand_size);
+      dense_expand_size.erase(dense_expand_size.begin()); // remove nnz dim
+      for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1); // -1 since grad has no nnz dim
+      grad_input_values = grad_input_values.expand(dense_expand_size);
+    }
+    grad_input_values = grad_input_values.expand(expand_size).clone();
+    return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values,  input.options().dtype(grad_.dtype())); // convert to grad dtype
+  }
+  else {
+    AT_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cuda: expected grad_ Tensor to be sparse, but got dense");
+    auto grad = grad_.coalesce();
+    LongTensor grad_indices = grad._indices();
+    Tensor grad_values = grad._values();
+    const int64_t grad_sparse_dim = grad.sparse_dim();
+    const int64_t grad_nnz = grad._nnz();
+
+    Tensor grad_values_expand = grad_values;
+    if (sum_dense_dim) {
+      auto expand_size = input_values.sizes().vec();
+      if (sum_sparse_dim) expand_size[0] = grad_values.size(0); // update nnz
+      for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d);
+      grad_values_expand = grad_values_expand.expand(expand_size).clone();
+    }
+
+    Tensor grad_input_values;
+    if (!sum_sparse_dim) {
+      grad_input_values = grad_values_expand;
+    }
+    else {
+      int curDevice = -1;
+      cudaGetDevice(&curDevice);
+      cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+      auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+      auto policy = thrust::cuda::par(allocator).on(stream);
+      typedef thrust::device_ptr<int64_t> thrust_ptr;
+
+      grad_input_values = at::empty_like(input_values, grad_values.options());
+      AT_ASSERT(grad_input_values.is_cuda());
+
+      // get 1D indices
+      auto grad_sparse_dim_to_keep_v = std::vector<int64_t>(grad_sparse_dim);
+      std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0);
+
+      auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices on all sparse_dim of grad, output indices is coalesced and sorted
+      auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v);
+      thrust_ptr grad_indices_iter(grad_indices_1D.data<int64_t>());
+      thrust_ptr input_indices_iter(input_indices_1D.data<int64_t>());
+
+      // store lower_bound of input indices at grad indices
+      LongTensor input_indices_pos = at::empty_like(input_indices_1D);
+      thrust_ptr input_indices_pos_iter(input_indices_pos.data<int64_t>());
+      thrust::lower_bound(policy,
+                          grad_indices_iter, grad_indices_iter + grad_nnz,
+                          input_indices_iter, input_indices_iter + input_nnz,
+                          input_indices_pos_iter);
+
+      // config to run cuda kernel
+      int64_t total_threads = input_nnz;
+      const dim3 block = dim3(std::min(static_cast<int64_t>(cuda::getApplyBlock().x), total_threads));
+      dim3 grid;
+      AT_CHECK(cuda::getApplyGrid(total_threads, grid, curDevice), "_sparse_sum_backward_cuda: input too large or too many dimensions");
+
+      auto grad_indices_ti = getTensorInfo<int64_t, int64_t>(grad_indices_1D);
+      auto input_indices_ti = getTensorInfo<int64_t, int64_t>(input_indices_1D);
+      auto input_indices_pos_ti = getTensorInfo<int64_t, int64_t>(input_indices_pos);
+
+      AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_values.type(), "_sparse_sum_backward_cuda", [&] {
+        auto grad_values_expand_ti = getTensorInfo<scalar_t, int64_t>(grad_values_expand);
+        auto grad_input_values_ti = getTensorInfo<scalar_t, int64_t>(grad_input_values);
+
+        _sparse_sum_backward_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
+          total_threads,
+          grad_indices_ti,
+          input_indices_ti,
+          input_indices_pos_ti,
+          grad_values_expand_ti,
+          grad_input_values_ti
+        );
+      });
+    }
+
+    return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(), grad_input_values, grad.options());
+  }
+}
+
 }} // namespace at::native
index 4b8d5bf..1e6afde 100644 (file)
@@ -60,6 +60,13 @@ An empty sparse tensor can be constructed by specifying its size:
     and values:
     [torch.FloatTensor with no dimension]
 
+SparseTensor has the following invariants:
+  1. sparse_dim + dense_dim = len(SparseTensor.shape)
+  2. SparseTensor._indices().shape = (sparse_dim, nnz)
+  3. SparseTensor._values().shape = (nnz, SparseTensor.shape[sparse_dim:])
+Since SparseTensor._indices() is always a 2D tensor, the smallest sparse_dim = 1.
+Therefore, representation of a SparseTensor of sparse_dim = 0 is simply a dense tensor.
+
 .. note::
 
     Our sparse tensor format permits *uncoalesced* sparse tensors, where
@@ -134,3 +141,4 @@ Functions
 ----------------------------------
 
 .. autofunction:: torch.sparse.addmm
+.. autofunction:: torch.sparse.sum
index 39c7e88..60fc784 100644 (file)
@@ -935,6 +935,64 @@ class TestSparse(TestCase):
         test_shape(4, 10, [100, 100, 100, 5, 5, 5, 0])
         test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0])
 
+    @skipIfRocm
+    def test_sparse_sum(self):
+
+        def run_tests(S, td=None):
+            D = S.coalesce().to_dense().detach().requires_grad_(True)
+            mask = (D == 0)
+            if td is None:
+                S_sum = torch.sparse.sum(S)
+                D_sum = D.sum()
+                self.assertEqual(S_sum, D_sum)
+                S_sum.backward()
+                D_sum.backward()
+                D_grad = D.grad.masked_fill_(mask, 0)
+                self.assertEqual(S.grad.to_dense(), D_grad)
+            else:
+                S_sum = torch.sparse.sum(S, td)
+                D_sum = D.sum(td)
+                self.assertEqual(S_sum.to_dense() if S_sum.is_sparse else S_sum, D_sum)
+                S_sum.backward(S_sum.detach())
+                S_grad = S.grad
+                data = S_sum.to_dense().detach() if S_sum.is_sparse else S_sum.detach()
+                D_sum.backward(data)
+                D_grad = D.grad.masked_fill_(mask, 0)
+                S_grad_dense = S_grad.coalesce().to_dense() if S_grad.is_sparse else S_grad
+                self.assertEqual(S_grad_dense, D_grad)
+
+        nnz = 10
+        sparse_dims = 2
+        with_size = [5, 5, 1, 4]  # use a dense dim = 1 to test for squeeze
+        test_dims = []
+        for i in range(1, 5):
+            test_dims += itertools.combinations(range(len(with_size)), i)
+
+        # not support SparseTensor.sum()
+        S = self._gen_sparse(sparse_dims, nnz, with_size)[0]
+        self.assertRaises(RuntimeError, lambda: S.sum())
+
+        # raise for incorrect input dims
+        self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, 5))
+        self.assertRaises(RuntimeError, lambda: torch.sparse.sum(S, [0, 0]))
+
+        # sum an empty tensor
+        empty_S = torch.sparse_coo_tensor(size=with_size)
+        self.assertRaises(RuntimeError, lambda: torch.sparse.sum(empty_S, [0]))
+        self.assertEqual(torch.sparse.sum(empty_S), torch.tensor(0,))
+        empty_S.requires_grad_(True)
+        empty_S_sum = torch.sparse.sum(empty_S)
+        empty_S_sum.backward()
+        self.assertEqual(empty_S.grad.to_dense(), empty_S.clone().detach().to_dense())
+
+        # test values().sum()
+        S = self._gen_sparse(sparse_dims, nnz, with_size)[0]
+        run_tests(S.requires_grad_(True))
+
+        for test_dim in test_dims:
+            S = self._gen_sparse(sparse_dims, nnz, with_size)[0]
+            run_tests(S.requires_grad_(True), test_dim)
+
     def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v=None):
         shape = shape_i + (shape_v or [])
         x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape)
index 0736a46..30432a2 100644 (file)
 - name: _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, IntList size, Tensor indices, Tensor values, TensorOptions options)
   values: sparse_constructor_values_backward(grad, indices, values.sizes())
 
+- name: _sparse_sum(Tensor self, IntList dim)
+  self: at::_sparse_sum_backward(grad, self, dim)
+
 - name: _standard_gamma(Tensor self, Generator generator)
   self: grad * _standard_gamma_grad(self, result)
 
index e2ab34f..07553e9 100644 (file)
@@ -3,12 +3,13 @@ import torch
 
 __all__ = [
     'addmm',
+    'sum',
 ]
 
 
 def addmm(mat, mat1, mat2, beta=1, alpha=1):
     r"""
-    This function does exact same thing as :meth:`~Torch.addmm` in the forward,
+    This function does exact same thing as :func:`torch.addmm` in the forward,
     except that it supports backward for coalesced sparse matrix `mat1`.
 
     Args:
@@ -19,3 +20,70 @@ def addmm(mat, mat1, mat2, beta=1, alpha=1):
         alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
     """
     return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha)
+
+
+def sum(input, dim=None, dtype=None):
+    r"""
+    Returns the sum of each row of SparseTensor :attr:`input` in the given
+    dimensions :attr:`dim`. If :attr::`dim` is a list of dimensions,
+    reduce over all of them. When sum over all ``sparse_dim``, this method
+    returns a Tensor instead of SparseTensor.
+
+    All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output
+    tensor having :attr::`dim` fewer dimensions than :attr:`input`.
+
+    During backward, only gradients at ``nnz`` locations of :attr:`input`
+    will propagate back. Note that the gradients of :attr:`input` is coalesced.
+
+    Args:
+        input (Tensor): the input SparseTensor
+        dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
+            over all dims.
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
+            Default: dtype of :attr:`input`.
+
+    Example::
+
+        >>> nnz = 3
+        >>> dims = [5, 5, 2, 3]
+        >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
+                           torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
+        >>> V = torch.randn(nnz, dims[2], dims[3])
+        >>> size = torch.Size(dims)
+        >>> S = torch.sparse_coo_tensor(I, V, size)
+        >>> S
+        tensor(indices=tensor([[2, 0, 3],
+                               [2, 4, 1]]),
+               values=tensor([[[-0.6438, -1.6467,  1.4004],
+                               [ 0.3411,  0.0918, -0.2312]],
+
+                              [[ 0.5348,  0.0634, -2.0494],
+                               [-0.7125, -1.0646,  2.1844]],
+
+                              [[ 0.1276,  0.1874, -0.6334],
+                               [-1.9682, -0.5340,  0.7483]]]),
+               size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
+
+        # when sum over only part of sparse_dims, return a SparseTensor
+        >>> torch.sparse.sum(S, [1, 3])
+        tensor(indices=tensor([[0, 2, 3]]),
+               values=tensor([[-1.4512,  0.4073],
+                              [-0.8901,  0.2017],
+                              [-0.3183, -1.7539]]),
+               size=(5, 2), nnz=3, layout=torch.sparse_coo)
+
+        # when sum over all sparse dim, return a dense Tensor
+        # with summed dims squeezed
+        >>> torch.sparse.sum(S, [0, 1, 3])
+        tensor([-2.6596, -1.1450])
+    """
+    if dtype is None:
+        if dim:
+            return torch._sparse_sum(input, dim)
+        else:
+            return torch._sparse_sum(input)
+    else:
+        if dim:
+            return torch._sparse_sum(input, dim, dtype=dtype)
+        else:
+            return torch._sparse_sum(input, dtype=dtype)