From f0d98199fb1f20b8a22a645f50b5410d9392a864 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Thu, 18 Apr 2019 22:20:44 -0700 Subject: [PATCH] improve dim sort performance (#19379) Summary: We are already using custom comparators for sorting (for a good reason), but are still making 2 sorting passes - global sort and stable sorting to bring values into their slices. Using a custom comparator to sort within a slice allows us to avoid second sorting pass and brings up to 50% perf improvement. t-vi I know you are moving sort to ATen, and changing THC is discouraged, but #18350 seems dormant. I'm fine with #18350 landing first, and then I can put in these changes. cc umanwizard for review. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19379 Differential Revision: D15011019 Pulled By: soumith fbshipit-source-id: 48e5f5aef51789b166bb72c75b393707a9aed57c --- aten/src/THC/THCTensorSort.cuh | 47 +++++++++++++++++++++++------------ aten/src/THC/generic/THCTensorSort.cu | 44 ++++++++++++++++---------------- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/aten/src/THC/THCTensorSort.cuh b/aten/src/THC/THCTensorSort.cuh index def03a3..ffda23a 100644 --- a/aten/src/THC/THCTensorSort.cuh +++ b/aten/src/THC/THCTensorSort.cuh @@ -29,6 +29,37 @@ struct ThrustLTOp { } }; +template +struct ThrustSliceGTOp { +ThrustSliceGTOp(int64_t size) : sliceSize(size) {} + __device__ bool operator()(const thrust::tuple& lhs, const thrust::tuple& rhs) const { + IndT segA = (IndT)thrust::get<0>(lhs) / sliceSize; + IndT segB = (IndT)thrust::get<0>(rhs) / sliceSize; + if (segA != segB) + return segA < segB; + else + return (handleNaN && THCNumerics::isnan(thrust::get<1>(lhs)) && !THCNumerics::isnan(thrust::get<1>(rhs))) || THCNumerics::gt(thrust::get<1>(lhs), thrust::get<1>(rhs)); + } + const IndT sliceSize; +}; + +template +struct ThrustSliceLTOp { +ThrustSliceLTOp(int64_t size) : sliceSize(size) {} + __device__ bool operator()(const thrust::tuple& lhs, const thrust::tuple& rhs) const { + IndT segA = (IndT)thrust::get<0>(lhs) / sliceSize; + IndT segB = (IndT)thrust::get<0>(rhs) / sliceSize; + if (segA != segB) + return segA < segB; + else + return (handleNaN && THCNumerics::isnan(thrust::get<1>(rhs)) && !THCNumerics::isnan(thrust::get<1>(lhs))) || THCNumerics::lt(thrust::get<1>(lhs), thrust::get<1>(rhs)); + } + const IndT sliceSize; +}; + + + + // `base` is the base address of a tensor // For each slice (defined as a linear point of `out`, from 0 -> // (sliceSize - 1) * sliceStride, we fill that slice from `0` to @@ -55,22 +86,6 @@ fillSliceWithIndex(TensorInfo out, } } -// For slice sorting in Thrust; extracts a slice index from a linear -// index and uses that for comparison -struct SliceComp { - SliceComp(int64_t size) : sliceSize(size) {} - - __device__ bool operator()(const int64_t& a, const int64_t& b) const { - // Since the slices are guaranteed to be innermost, - // the segment is just via int64_t division - int64_t segA = a / sliceSize; - int64_t segB = b / sliceSize; - return segA < segB; - } - - const int64_t sliceSize; -}; - // For sorting in Thurst; extracts a within-slice index from a linear index struct GlobalIndexToPerSliceIndex { GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {} diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu index f9b1db8..36ffacf 100644 --- a/aten/src/THC/generic/THCTensorSort.cu +++ b/aten/src/THC/generic/THCTensorSort.cu @@ -220,40 +220,39 @@ void THCTensor_(sortViaThrust)(THCState* state, // Fill the indices with a global index across all slices thrust::counting_iterator countIter(0); - thrust::copy( #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), #endif countIter, countIter + totalElements, indexIter); - - // First, we sort globally (across all slices) according to key - // (the values we're sorting) - if (dir) { - thrust::stable_sort_by_key( + auto begin = thrust::make_zip_iterator(thrust::make_tuple(indexIter, keyIter)); + if (dir){ + if (totalElements < INT_MAX) + thrust::sort( #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__ - thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), + thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), #endif - keyIter, keyIter + totalElements, indexIter, ThrustGTOp()); + begin, begin + totalElements, ThrustSliceGTOp(sliceSize)); + else + thrust::sort( +#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__ + thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), +#endif + begin, begin + totalElements, ThrustSliceGTOp(sliceSize)); } else { - thrust::stable_sort_by_key( + if (totalElements < INT_MAX) + thrust::sort( #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__ - thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), + thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), #endif - keyIter, keyIter + totalElements, indexIter, ThrustLTOp()); - } - - // Then, re-sort according to slice that each index is - // in. This completes the segment sort in Thrust, since we're - // stably sorting here, preserving the relative order of values - // per each slice - thrust::stable_sort_by_key( + begin, begin + totalElements, ThrustSliceLTOp(sliceSize)); + else + thrust::sort( #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__ - thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), + thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)), #endif - indexIter, indexIter + totalElements, keyIter, - SliceComp(sliceSize)); - + begin, begin + totalElements, ThrustSliceLTOp(sliceSize)); + } // Translate the global integer 0-based index to a per-slice real // Lua index thrust::for_each( @@ -268,7 +267,6 @@ void THCTensor_(sortViaThrust)(THCState* state, THCTensor_(transpose)(state, trContigKey, NULL, dim, nDims - 1); THCudaLongTensor_transpose(state, trContigIndices, NULL, dim, nDims - 1); } - // Then copy back to the expected output THCTensor_(freeCopyTo)(state, trContigKey, sorted); THCudaLongTensor_freeCopyTo(state, trContigIndices, indices); -- 2.7.4