Compute cuda reduction buffer size in elements (#63969)
author     Natalia Gimelshein <ngimel@fb.com>
           Thu, 26 Aug 2021 01:17:10 +0000 (18:17 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Thu, 26 Aug 2021 01:18:37 +0000 (18:18 -0700)
Summary:
Resubmit of https://github.com/pytorch/pytorch/issues/63885

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63969

Reviewed By: mruberry

Differential Revision: D30549423

Pulled By: ngimel

fbshipit-source-id: b16d25030d44ced789c125a333d72b02a8f45067

aten/src/ATen/native/cuda/Reduce.cuh
test/test_reductions.py

diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh
index 8c42306..b460045 100644
--- a/aten/src/ATen/native/cuda/Reduce.cuh
+++ b/aten/src/ATen/native/cuda/Reduce.cuh
@@ -919,10 +919,11 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
     // acc_buf_ptr holds buffer used for accumulation among multiple sub_iter
     // when accumulation in output is not possible.
     if (!can_accumulate_in_output && !can_use_32bit_indexing) {
-      int64_t output_memory_size = 1;
+      int64_t output_memory_size = iter.element_size(0);
       for (int dim = 0; dim < iter.ndim(); dim++) {
         output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
       }
+      output_memory_size /= iter.element_size(0); // iter.strides is in bytes, so convert back to elements
       owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
                                                  sizeof(out_scalar_t),
                                                  (char*) iter.data_ptr(0),
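
For context: iter.strides(0) reports strides in bytes, so the running max of shape[dim] * strides[dim] yields the output extent in bytes, whereas (per the commit title) the accumulation buffer is sized in elements. Seeding the max with iter.element_size(0) instead of 1 keeps the final division from rounding a small output down to zero, and that division converts the byte count back to elements. Below is a minimal standalone sketch of the same computation; the function name and the example shape/stride values are hypothetical stand-ins for the TensorIterator calls, not PyTorch API.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the buffer sizing in the hunk above, decoupled from TensorIterator.
// Assumption (matches the hunk): strides are expressed in bytes.
int64_t output_size_in_elements(const std::vector<int64_t>& shape,
                                const std::vector<int64_t>& strides_bytes,
                                int64_t element_size) {
  // Seed with one element's worth of bytes so the division below
  // never returns less than 1 (e.g. for a zero-dim output).
  int64_t size_bytes = element_size;
  for (size_t dim = 0; dim < shape.size(); ++dim) {
    size_bytes = std::max(size_bytes, shape[dim] * strides_bytes[dim]);
  }
  return size_bytes / element_size;  // strides were in bytes
}

int main() {
  // Hypothetical contiguous float32 output of shape {64, 61}:
  // strides in bytes are {61 * 4, 4} = {244, 4}.
  std::cout << output_size_in_elements({64, 61}, {244, 4}, 4) << "\n";  // 3904 (= 64 * 61)
}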
diff --git a/test/test_reductions.py b/test/test_reductions.py
index 1497ed6..c1da0f0 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -1788,7 +1788,7 @@ class TestReductions(TestCase):
         run_test(torch.zeros(64, 61, dtype=dtype, device=device))
         run_test(torch.zeros(64, 1, dtype=dtype, device=device))
 
-    @slowTest
+    @onlyCUDA
     def test_argminmax_large_axis(self, device):
         # Regression test for gh-32863
         x = torch.zeros(2**31, device=device, dtype=torch.int8)
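
(The decorator change narrows this regression test to CUDA devices: with 2**31 elements the reduction cannot use 32-bit indexing, so on CUDA it exercises exactly the accumulation-buffer sizing path fixed in Reduce.cuh above.)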