}
+namespace {
+
+// Segmented sort by full sort algorithm:
+// Say we are sorting a (2, 3) tensor. We have in flattened form:
+// values       0.4 1.2 5.3 6.2 1.3 2.3
+// indices        0   1   2   0   1   2
+// segment_id     0   0   0   1   1   1
+
+// First we sort by values, globally (a descending sort here):
+// values       6.2 5.3 2.3 1.3 1.2 0.4
+// indices        0   2   2   1   1   0
+// segment_id     1   0   1   1   0   0
+
+// Then we stable sort by segment id; stability keeps the values
+// within each segment in sorted order:
+// values       5.3 1.2 0.4 6.2 2.3 1.3
+// indices        2   1   0   0   2   1
+// segment_id     0   0   0   1   1   1
+
+// This method only works if the slice we are sorting (`dim`) is
+// innermost and both values and indices are contiguous. We arrange
+// the input into this form as needed, which unfortunately allocates
+// memory if it is not already in this form. Vectorized sort is slower
+// than an iterated per-slice sort when the number of slices is small
+// (since we sort all n elements twice, instead of invoking a smaller
+// sort `numSlices` times), but the cub sort implementation here is a
+// catch-all, so we are aiming for correctness rather than peak
+// efficiency.
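+//
+// As a rough sketch, the same two passes could be written with
+// high-level ATen ops (illustrative only: `self` stands for the
+// (nsegments, nsort) slab being sorted, and this assumes a
+// stable-sort overload of `sort` is available):
+//
+//   Tensor flat = self.reshape({-1});
+//   Tensor v, i;
+//   std::tie(v, i) = flat.sort(/*dim=*/0, descending);
+//   Tensor seg = i.div(nsort, /*rounding_mode=*/"floor");
+//   Tensor order = std::get<1>(seg.sort(/*stable=*/true, /*dim=*/0, /*descending=*/false));
+//   Tensor values = v.index_select(0, order).view_as(self);
+//   Tensor indices = (i - seg * nsort).index_select(0, order).view_as(self);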
+
+// Unpacks the (segment id, original index) pairs produced by the two
+// sorts and gathers the sorted values from the unsorted input.
+template<typename scalar_t>
+__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) {
+ CUDA_KERNEL_LOOP(i, nsegments * nsort) {
+ int segment = i / nsort; // which slice this element belongs to
+ int j = i % nsort; // position within the slice
+
+ int offset = segment * nsort;
+ const scalar_t *in_ = in + offset;
+ scalar_t *out_ = out + offset;
+ int64_t *index_ = index + offset;
+ const int2 *i_s_ptr_ = i_s_ptr + offset;
+
+ int idx = i_s_ptr_[j].y; // original index within this slice
+ index_[j] = idx;
+ out_[j] = in_[idx]; // gather the value from the unsorted input
+ }
+}
+
+template<typename scalar_t>
+inline void segmented_sort_pairs_by_full_sort(
+ int64_t nsegments, int64_t nsort, int64_t n, bool descending, const Tensor &indices,
+ const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr
+) {
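+ // Number of bits needed to encode a segment id (at least 1); the
+ // second radix sort below only needs to examine these bits.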
+ int64_t segment_bits = std::max<int64_t>(1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
+
+ auto int_options = indices.options().dtype(kInt);
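+ // Build one (segment id, within-slice index) int32 pair per element;
+ // the pairs travel through the value sort as an int2 payload.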
+ auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options);
+ indices_and_segment.select(-1, 0).copy_( // segment id
+ at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort}));
+ indices_and_segment.select(-1, 1).copy_( // reverse indices
+ at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort}));
+
+ auto i_s_ptr = reinterpret_cast<int2 *>(indices_and_segment.data_ptr<int>());
+ auto indices_and_segment2 = at::empty_like(indices_and_segment);
+ auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());
+
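+ // Pass 1: sort globally by value, carrying the packed pairs along.
+ // The sorted keys themselves are not materialized (keys_out is
+ // nullptr); values are gathered later by the postprocess kernel.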
+ at::cuda::cub::sort_pairs<scalar_t, int2>(
+ self_ptr, nullptr, i_s_ptr, i_s_ptr2,
+ n, descending);
+
+ TORCH_INTERNAL_ASSERT(segment_bits <= 32);
+
+ // Pass 2: stable sort on the lower 32 bits of each packed pair, i.e.
+ // the segment id (int2.x, which occupies the low-order bits on
+ // little-endian). Stability preserves the value order per segment.
+ at::cuda::cub::sort_keys<int64_t>(
+ reinterpret_cast<int64_t *>(i_s_ptr2), reinterpret_cast<int64_t *>(i_s_ptr),
+ n, false, 0, segment_bits);
+
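+ // One thread per element, 512 per block; CUDA_KERNEL_LOOP guards the tail.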
+ sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
+ self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
+}
+
+} // namespace
+
// We perform a segmented sort in cub with inputs that have
// more than 1024/2048 elements along the selected dimension.
// Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
int64_t n = std::min(remaining, nbatch);
int64_t nsegments = n / nsort;
- auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
-
- at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
- reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
- offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
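+ // With only a few segments, cub's segmented sort exposes little
+ // parallelism, so take the full-sort path instead; 128 is a heuristic
+ // cutoff.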
+ if (nsegments < 128) {
+ segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending,
+ indices, self_ptr, values_ptr, indices_ptr);
+ } else {
+ auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
+ at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
+ reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
+ offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
+ }
remaining -= n;
self_ptr += n;