Bring back old algorithm for sorting on small number of segments (#64127)
authorXiang Gao <qasdfgtyuiop@gmail.com>
Mon, 30 Aug 2021 19:25:29 +0000 (12:25 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 30 Aug 2021 19:30:50 +0000 (12:30 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63456
The code was copy-pasted from the previous commit without modification.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64127

Reviewed By: mruberry

Differential Revision: D30632090

Pulled By: ngimel

fbshipit-source-id: 58bbdd9b0423f01d4e65e2ec925ad9a3f88efc9b

aten/src/ATen/native/cuda/Sort.cu

index 83fce65..9cb32bc 100644 (file)
@@ -207,6 +207,87 @@ struct offset_t {
 
 }
 
+namespace {
+
+// Segmented sort by full sort algorithm:.
+// Say we are sorting a (2, 3) tensor. We have in flattened form:
+// values       0.4 1.2 5.3 6.2 1.3 2.3
+// indices        0   1   2   0   1   2
+// segment_id     0   0   0   1   1   1
+
+// First we sort by values, globally:
+// values       6.2 5.3 2.3 1.2 1.3 0.4
+// indices        0   2   2   1   1   0
+// segment_id     1   0   1   0   1   0
+
+// Then we stable sort by segment id:
+// values       5.3 1.2 0.4 6.2 2.3 1.3
+// indices        2   1   0   0   2   1
+// segment_id     0   0   0   1   1   1
+
+// This method can only work if the slice we are sorting (`dim`) is
+// innermost, and both values and indices are contiguous. We do this
+// by re-arranging the input into this form as needed, which will
+// unfortunately allocate memory if the request is not in this form.
+// Vectorized sort is slower than iterated sort if the number of
+// slices is small (since we're sorting twice, instead of invoking a
+// smaller sort `numSlices` times), but the cub sort
+// implementation here is a catch-all, so we're not looking for
+// efficiency, but instead correctness.
+
+template<typename scalar_t>
+__global__ void sort_postprocess_kernel(const scalar_t *in, scalar_t *out, int64_t *index, const int2 *i_s_ptr, int nsegments, int nsort) {
+  CUDA_KERNEL_LOOP(i, nsegments * nsort) {
+    int segment = i / nsort;
+    int j = i % nsort;
+
+    int offset = segment * nsort;
+    const scalar_t *in_ = in + offset;
+    scalar_t *out_ = out + offset;
+    int64_t *index_ = index + offset;
+    const int2 *i_s_ptr_ = i_s_ptr + offset;
+
+    int idx = i_s_ptr_[j].y;
+    index_[j] = idx;
+    out_[j] = in_[idx];
+  }
+}
+
+template<typename scalar_t>
+inline void segmented_sort_pairs_by_full_sort(
+  int64_t nsegments, int64_t nsort, int64_t n, bool descending, const Tensor &indices,
+  const scalar_t *self_ptr, scalar_t *values_ptr, int64_t *indices_ptr
+) {
+  int64_t segment_bits = std::max<int64_t>(1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
+
+  auto int_options = indices.options().dtype(kInt);
+  auto indices_and_segment = at::empty({nsegments, nsort, 2}, int_options);
+  indices_and_segment.select(-1, 0).copy_(  // segment id
+    at::arange(nsegments, int_options).view({nsegments, 1}).expand({nsegments, nsort}));
+  indices_and_segment.select(-1, 1).copy_(  // reverse indices
+    at::arange(nsort, int_options).view({1, nsort}).expand({nsegments, nsort}));
+
+  auto i_s_ptr = reinterpret_cast<int2 *>(indices_and_segment.data_ptr<int>());
+  auto indices_and_segment2 = at::empty_like(indices_and_segment);
+  auto i_s_ptr2 = reinterpret_cast<int2 *>(indices_and_segment2.data_ptr<int>());
+
+  at::cuda::cub::sort_pairs<scalar_t, int2>(
+    self_ptr, nullptr, i_s_ptr, i_s_ptr2,
+    n, descending);
+
+  TORCH_INTERNAL_ASSERT(segment_bits <= 32);
+
+  // sort on lower 32bits, i.e. segment index
+  at::cuda::cub::sort_keys<int64_t>(
+    reinterpret_cast<int64_t *>(i_s_ptr2), reinterpret_cast<int64_t *>(i_s_ptr),
+    n, false, 0, segment_bits);
+
+  sort_postprocess_kernel<<<(n + 511) / 512, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
+    self_ptr, values_ptr, indices_ptr, i_s_ptr, nsegments, nsort);
+}
+
+}  // namespace
+
 // We perform a segmented sort in cub with inputs that have
 // more than 1024/2048 elements along the selected dimension.
 // Otherwise, we do an inplace bitonic sort (see sortKeyValueInplace).
@@ -349,11 +430,15 @@ std::tuple<Tensor &,Tensor &> sort_out_stable_cuda(const Tensor & self, c10::opt
         int64_t n = std::min(remaining, nbatch);
         int64_t nsegments = n / nsort;
 
-        auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
-
-        at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
-          reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
-          offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
+        if (nsegments < 128) {
+          segmented_sort_pairs_by_full_sort(nsegments, nsort, n, descending,
+            indices, self_ptr, values_ptr, indices_ptr);
+        } else {
+          auto reverse_indices = at::arange(nsort, indices.options()).view({1, nsort}).expand({nsegments, nsort}).contiguous();
+          at::cuda::cub::segmented_sort_pairs(self_ptr, values_ptr,
+            reverse_indices.data_ptr<int64_t>(), indices_ptr, n, nsegments,
+            offset_t{(int)nsort, 0}, offset_t{(int)nsort, 1}, descending);
+        }
 
         remaining -= n;
         self_ptr += n;