From a8ffe81b2c3123926354b4ec2001693b38daa80d Mon Sep 17 00:00:00 2001
From: Xiang Gao
Date: Mon, 30 Aug 2021 12:25:29 -0700
Subject: [PATCH] Bring back old algorithm for sorting on small number of segments (#64127)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/63456

The code was copy-pasted from the previous commit without modification.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64127

Reviewed By: mruberry

Differential Revision: D30632090

Pulled By: ngimel

fbshipit-source-id: 58bbdd9b0423f01d4e65e2ec925ad9a3f88efc9b
---
 aten/src/ATen/native/cuda/Sort.cu | 95 ++++++++++++++++++++++++++++++++++++---
 1 file changed, 90 insertions(+), 5 deletions(-)

diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu
index 83fce65..9cb32bc 100644
--- a/aten/src/ATen/native/cuda/Sort.cu
+++ b/aten/src/ATen/native/cuda/Sort.cu
@@ -207,6 +207,87 @@ struct offset_t {
 }
 
 
+namespace {
+
+// Segmented sort by full sort algorithm:
+// Say we are sorting a (2, 3) tensor. We have in flattened form:
+// values       0.4 1.2 5.3 6.2 1.3 2.3
+// indices        0   1   2   0   1   2
+// segment_id     0   0   0   1   1   1
+
+// First we sort by values, globally:
+// values       6.2 5.3 2.3 1.3 1.2 0.4
+// indices        0   2   2   1   1   0
+// segment_id     1   0   1   1   0   0
+
+// Then we stable sort by segment id:
+// values       5.3 1.2 0.4 6.2 2.3 1.3
+// indices        2   1   0   0   2   1
+// segment_id     0   0   0   1   1   1
+
+// This method can only work if the slice we are sorting (`dim`) is
+// innermost, and both values and indices are contiguous. We do this
+// by re-arranging the input into this form as needed, which will
+// unfortunately allocate memory if the request is not in this form.
+// Vectorized sort is slower than iterated sort if the number of
+// slices is small (since we're sorting twice, instead of invoking a
+// smaller sort `numSlices` times), but the cub sort
+// implementation here is a catch-all, so we're not looking for
+// efficiency, but instead correctness.
+
+template <typename scalar_t>
+__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) {
+  CUDA_KERNEL_LOOP(i, nsegments * nsort) {
+    int segment = i / nsort;
+    int j = i % nsort;
+
+    int offset = segment * nsort;
+    const scalar_t *in_ = in + offset;
+    scalar_t *out_ = out + offset;
+    int64_t *index_ = index + offset;
+    const int2 *i_s_ptr_ = i_s_ptr + offset;
+
+    int idx = i_s_ptr_[j].y;
+    index_[j] = idx;
+    out_[j] = in_[idx];
+  }
+}
+
+template <typename scalar_t>
+inline void segmented_sort_pairs_by_full_sort(
+  int64_t nsegments, int64_t nsort, int64_t n, bool descending, const Tensor &indices,
+  const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr
+) {
+  int64_t segment_bits = std::max<int64_t>(1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
+
+  auto int_options = indices.options().dtype(kInt);
+  auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options);
+  indices_and_segment.select(-1, 0).copy_(  // segment id
+    at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort}));
+  indices_and_segment.select(-1, 1).copy_(  // reverse indices
+    at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort}));
+
+  auto i_s_ptr = reinterpret_cast<int2 *>(indices_and_segment.data_ptr<int>());
+  auto indices_and_segment2 = at::empty_like(indices_and_segment);
+  auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());
+
+  at::cuda::cub::sort_pairs<scalar_t, int2>(
+    self_ptr, nullptr, i_s_ptr, i_s_ptr2,
+    n, descending);
+
+  TORCH_INTERNAL_ASSERT(segment_bits <= 32);
+
+  // sort on lower 32bits, i.e. segment index
+  at::cuda::cub::sort_keys<int64_t>(
+    reinterpret_cast<int64_t *>(i_s_ptr2), reinterpret_cast<int64_t *>(i_s_ptr),
+    n, false, 0, segment_bits);
+
+  sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
+    self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
+}
+
+}  // namespace
+
 // We perform a segmented sort in cub with inputs that have
 // more than 1024/2048 elements along the selected dimension.
 // Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
@@ -349,11 +430,15 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
       int64_t n = std::min(remaining, nbatch);
       int64_t nsegments = n / nsort;
 
-      auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
-
-      at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
-        reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
-        offset_t<int>{(int)nsort, 0}, offset_t<int>{(int)nsort, 1}, descending);
+      if (nsegments < 128) {
+        segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending,
+          indices, self_ptr, values_ptr, indices_ptr);
+      } else {
+        auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
+        at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
+          reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
+          offset_t<int>{(int)nsort, 0}, offset_t<int>{(int)nsort, 1}, descending);
+      }
 
       remaining -= n;
       self_ptr += n;
-- 
2.7.4
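
Editor's note on the approach: the comment block added by this patch walks through the two-pass trick (sort everything globally by value, then stable sort by segment id, relying on stability to keep each segment internally sorted). Below is a minimal host-side C++ sketch of that same idea on the (2, 3) example from the comment, using std::sort and std::stable_sort. It is only an illustration, not the CUDA/cub code in the patch; the Entry struct and variable names are invented for the example.

// Host-side sketch of "segmented sort by full sort" (illustration only; the
// patch itself does this on the GPU with cub sort_pairs/sort_keys).
#include <algorithm>
#include <cstdio>
#include <vector>

struct Entry {
  float value;
  int index;       // position of the element inside its segment
  int segment_id;  // which slice (segment) the element came from
};

int main() {
  // Flattened (2, 3) tensor from the comment:
  // segment 0 = {0.4, 1.2, 5.3}, segment 1 = {6.2, 1.3, 2.3}.
  std::vector<Entry> e = {
      {0.4f, 0, 0}, {1.2f, 1, 0}, {5.3f, 2, 0},
      {6.2f, 0, 1}, {1.3f, 1, 1}, {2.3f, 2, 1}};

  // Pass 1: sort everything globally by value (descending here).
  std::sort(e.begin(), e.end(),
            [](const Entry &a, const Entry &b) { return a.value > b.value; });

  // Pass 2: stable sort by segment id. Stability preserves the per-segment
  // value order from pass 1, so each segment ends up internally sorted.
  std::stable_sort(e.begin(), e.end(),
                   [](const Entry &a, const Entry &b) {
                     return a.segment_id < b.segment_id;
                   });

  // Expected output: segment 0 -> 5.3 1.2 0.4 (indices 2 1 0),
  //                  segment 1 -> 6.2 2.3 1.3 (indices 0 2 1).
  for (const Entry &x : e) {
    std::printf("segment=%d value=%.1f index=%d\n", x.segment_id, x.value, x.index);
  }
  return 0;
}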