}
};
+// Comparator for a one-pass segmented *descending* sort in Thrust.
+// Operates on (global linear index, value) zip tuples: the slice (segment)
+// id is recovered from the linear index by integer division with sliceSize
+// (slices are assumed innermost/contiguous — TODO(review): confirm at call
+// site). Segments stay in ascending order; within a segment, values compare
+// greatest-first, with NaNs ordered before all non-NaN values when
+// handleNaN is set.
+// NOTE(review): the index is cast to IndT *before* the division, so the
+// caller must guarantee the linear index fits in IndT (the call site picks
+// IndT=int only when totalElements < INT_MAX, int64_t otherwise).
+template <typename T, typename IndT, bool handleNaN = true>
+struct ThrustSliceGTOp {
+ThrustSliceGTOp(int64_t size) : sliceSize(size) {}
+ __device__ bool operator()(const thrust::tuple<int64_t, T>& lhs, const thrust::tuple<int64_t, T>& rhs) const {
+ // Segment (slice) id of each element, derived from its linear index.
+ IndT segA = (IndT)thrust::get<0>(lhs) / sliceSize;
+ IndT segB = (IndT)thrust::get<0>(rhs) / sliceSize;
+ if (segA != segB)
+ return segA < segB;
+ else
+ // Within a segment: NaN counts as "greater" than any non-NaN (so NaNs
+ // land at the front of a descending slice), otherwise plain gt().
+ return (handleNaN && THCNumerics<T>::isnan(thrust::get<1>(lhs)) && !THCNumerics<T>::isnan(thrust::get<1>(rhs))) || THCNumerics<T>::gt(thrust::get<1>(lhs), thrust::get<1>(rhs));
+ }
+ const IndT sliceSize;
+};
+
+// Comparator for a one-pass segmented *ascending* sort in Thrust.
+// Mirror image of ThrustSliceGTOp: operates on (global linear index, value)
+// zip tuples, recovering the slice (segment) id from the linear index by
+// integer division with sliceSize (slices assumed innermost/contiguous —
+// TODO(review): confirm at call site). Segments stay in ascending order;
+// within a segment, values compare least-first, with NaNs ordered after all
+// non-NaN values when handleNaN is set.
+// NOTE(review): the index is cast to IndT *before* the division, so the
+// caller must guarantee the linear index fits in IndT (the call site picks
+// IndT=int only when totalElements < INT_MAX, int64_t otherwise).
+template <typename T, typename IndT, bool handleNaN = true>
+struct ThrustSliceLTOp {
+ThrustSliceLTOp(int64_t size) : sliceSize(size) {}
+ __device__ bool operator()(const thrust::tuple<int64_t, T>& lhs, const thrust::tuple<int64_t, T>& rhs) const {
+ // Segment (slice) id of each element, derived from its linear index.
+ IndT segA = (IndT)thrust::get<0>(lhs) / sliceSize;
+ IndT segB = (IndT)thrust::get<0>(rhs) / sliceSize;
+ if (segA != segB)
+ return segA < segB;
+ else
+ // Within a segment: a non-NaN lhs is "less" than a NaN rhs (so NaNs
+ // land at the back of an ascending slice), otherwise plain lt().
+ return (handleNaN && THCNumerics<T>::isnan(thrust::get<1>(rhs)) && !THCNumerics<T>::isnan(thrust::get<1>(lhs))) || THCNumerics<T>::lt(thrust::get<1>(lhs), thrust::get<1>(rhs));
+ }
+ const IndT sliceSize;
+};
+
+
+
+
// `base` is the base address of a tensor
// For each slice (defined as a linear point of `out`, from 0 ->
// (sliceSize - 1) * sliceStride, we fill that slice from `0` to
}
}
-// For slice sorting in Thrust; extracts a slice index from a linear
-// index and uses that for comparison
-struct SliceComp {
- SliceComp(int64_t size) : sliceSize(size) {}
-
- __device__ bool operator()(const int64_t& a, const int64_t& b) const {
- // Since the slices are guaranteed to be innermost,
- // the segment is just via int64_t division
- int64_t segA = a / sliceSize;
- int64_t segB = b / sliceSize;
- return segA < segB;
- }
-
- const int64_t sliceSize;
-};
-
// For sorting in Thurst; extracts a within-slice index from a linear index
struct GlobalIndexToPerSliceIndex {
GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}
// Fill the indices with a global index across all slices
thrust::counting_iterator<int64_t> countIter(0);
-
thrust::copy(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
countIter, countIter + totalElements, indexIter);
-
- // First, we sort globally (across all slices) according to key
- // (the values we're sorting)
- if (dir) {
- thrust::stable_sort_by_key(
+ auto begin = thrust::make_zip_iterator(thrust::make_tuple(indexIter, keyIter));
+ if (dir){
+ if (totalElements < INT_MAX)
+ thrust::sort(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
- thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
- keyIter, keyIter + totalElements, indexIter, ThrustGTOp<scalar_t, true>());
+ begin, begin + totalElements, ThrustSliceGTOp<scalar_t, int, true>(sliceSize));
+ else
+ thrust::sort(
+#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+#endif
+ begin, begin + totalElements, ThrustSliceGTOp<scalar_t, int64_t, true>(sliceSize));
} else {
- thrust::stable_sort_by_key(
+ if (totalElements < INT_MAX)
+ thrust::sort(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
- thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
- keyIter, keyIter + totalElements, indexIter, ThrustLTOp<scalar_t, true>());
- }
-
- // Then, re-sort according to slice that each index is
- // in. This completes the segment sort in Thrust, since we're
- // stably sorting here, preserving the relative order of values
- // per each slice
- thrust::stable_sort_by_key(
+ begin, begin + totalElements, ThrustSliceLTOp<scalar_t, int, true>(sliceSize));
+ else
+ thrust::sort(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
- thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+ thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
- indexIter, indexIter + totalElements, keyIter,
- SliceComp(sliceSize));
-
+ begin, begin + totalElements, ThrustSliceLTOp<scalar_t, int64_t, true>(sliceSize));
+ }
// Translate the global integer 0-based index to a per-slice real
// Lua index
thrust::for_each(
THCTensor_(transpose)(state, trContigKey, NULL, dim, nDims - 1);
THCudaLongTensor_transpose(state, trContigIndices, NULL, dim, nDims - 1);
}
-
// Then copy back to the expected output
THCTensor_(freeCopyTo)(state, trContigKey, sorted);
THCudaLongTensor_freeCopyTo(state, trContigIndices, indices);