improve dim sort performance (#19379)
authorNatalia Gimelshein <ngimelshein@nvidia.com>
Fri, 19 Apr 2019 05:20:44 +0000 (22:20 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 05:25:08 +0000 (22:25 -0700)
Summary:
We are already using custom comparators for sorting (for a good reason), but are still making 2 sorting passes - global sort and stable sorting to bring values into their slices. Using a custom comparator to sort within a slice allows us to avoid second sorting pass and brings up to 50% perf improvement.
t-vi I know you are moving sort to ATen, and changing THC is discouraged, but #18350 seems dormant. I'm fine with #18350 landing first, and then I can put in these changes.
cc umanwizard for review.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19379

Differential Revision: D15011019

Pulled By: soumith

fbshipit-source-id: 48e5f5aef51789b166bb72c75b393707a9aed57c

aten/src/THC/THCTensorSort.cuh
aten/src/THC/generic/THCTensorSort.cu

index def03a3..ffda23a 100644 (file)
@@ -29,6 +29,37 @@ struct ThrustLTOp {
   }
 };
 
+template <typename T, typename IndT, bool handleNaN = true>
+struct ThrustSliceGTOp {
+ThrustSliceGTOp(int64_t size) : sliceSize(size) {}
+  __device__ bool operator()(const thrust::tuple<int64_t, T>& lhs, const thrust::tuple<int64_t, T>& rhs) const {
+    IndT segA = (IndT)thrust::get<0>(lhs) / sliceSize;
+    IndT segB = (IndT)thrust::get<0>(rhs) / sliceSize;
+    if (segA != segB)
+        return segA < segB;
+    else
+        return (handleNaN && THCNumerics<T>::isnan(thrust::get<1>(lhs)) && !THCNumerics<T>::isnan(thrust::get<1>(rhs))) || THCNumerics<T>::gt(thrust::get<1>(lhs), thrust::get<1>(rhs));
+  }
+  const IndT sliceSize;
+};
+
+template <typename T, typename IndT, bool handleNaN = true>
+struct ThrustSliceLTOp {
+ThrustSliceLTOp(int64_t size) : sliceSize(size) {}
+  __device__ bool operator()(const thrust::tuple<int64_t, T>& lhs, const thrust::tuple<int64_t, T>& rhs) const {
+    IndT segA = (IndT)thrust::get<0>(lhs) / sliceSize;
+    IndT segB = (IndT)thrust::get<0>(rhs) / sliceSize;
+    if (segA != segB)
+        return segA < segB;
+    else
+        return (handleNaN && THCNumerics<T>::isnan(thrust::get<1>(rhs)) && !THCNumerics<T>::isnan(thrust::get<1>(lhs))) || THCNumerics<T>::lt(thrust::get<1>(lhs), thrust::get<1>(rhs));
+  }
+  const IndT sliceSize;
+};
+
+
+
+
 // `base` is the base address of a tensor
 // For each slice (defined as a linear point of `out`, from 0 ->
 // (sliceSize - 1) * sliceStride, we fill that slice from `0` to
@@ -55,22 +86,6 @@ fillSliceWithIndex(TensorInfo<int64_t, IndexType> out,
   }
 }
 
-// For slice sorting in Thrust; extracts a slice index from a linear
-// index and uses that for comparison
-struct SliceComp {
-  SliceComp(int64_t size) : sliceSize(size) {}
-
-  __device__ bool operator()(const int64_t& a, const int64_t& b) const {
-    // Since the slices are guaranteed to be innermost,
-    // the segment is just via int64_t division
-    int64_t segA = a / sliceSize;
-    int64_t segB = b / sliceSize;
-    return segA < segB;
-  }
-
-  const int64_t sliceSize;
-};
-
 // For sorting in Thurst; extracts a within-slice index from a linear index
 struct GlobalIndexToPerSliceIndex {
   GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}
index f9b1db8..36ffacf 100644 (file)
@@ -220,40 +220,39 @@ void THCTensor_(sortViaThrust)(THCState* state,
 
   // Fill the indices with a global index across all slices
   thrust::counting_iterator<int64_t> countIter(0);
-
   thrust::copy(
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
     thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
     countIter, countIter + totalElements, indexIter);
-
-  // First, we sort globally (across all slices) according to key
-  // (the values we're sorting)
-  if (dir) {
-    thrust::stable_sort_by_key(
+    auto begin = thrust::make_zip_iterator(thrust::make_tuple(indexIter, keyIter));
+  if (dir){
+    if (totalElements < INT_MAX)
+       thrust::sort(
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
-      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+       thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
-      keyIter, keyIter + totalElements, indexIter, ThrustGTOp<scalar_t, true>());
+       begin, begin + totalElements, ThrustSliceGTOp<scalar_t, int, true>(sliceSize));
+    else
+       thrust::sort(
+#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
+       thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+#endif
+       begin, begin + totalElements, ThrustSliceGTOp<scalar_t, int64_t, true>(sliceSize));
   } else {
-    thrust::stable_sort_by_key(
+    if (totalElements < INT_MAX)
+       thrust::sort(
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
-      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+       thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
-      keyIter, keyIter + totalElements, indexIter, ThrustLTOp<scalar_t, true>());
-  }
-
-  // Then, re-sort according to slice that each index is
-  // in. This completes the segment sort in Thrust, since we're
-  // stably sorting here, preserving the relative order of values
-  // per each slice
-  thrust::stable_sort_by_key(
+       begin, begin + totalElements, ThrustSliceLTOp<scalar_t, int, true>(sliceSize));
+    else
+       thrust::sort(
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
-    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+       thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
-    indexIter, indexIter + totalElements, keyIter,
-    SliceComp(sliceSize));
-
+       begin, begin + totalElements, ThrustSliceLTOp<scalar_t, int64_t, true>(sliceSize));
+  }
   // Translate the global integer 0-based index to a per-slice real
   // Lua index
   thrust::for_each(
@@ -268,7 +267,6 @@ void THCTensor_(sortViaThrust)(THCState* state,
     THCTensor_(transpose)(state, trContigKey, NULL, dim, nDims - 1);
     THCudaLongTensor_transpose(state, trContigIndices, NULL, dim, nDims - 1);
   }
-
   // Then copy back to the expected output
   THCTensor_(freeCopyTo)(state, trContigKey, sorted);
   THCudaLongTensor_freeCopyTo(state, trContigIndices, indices);