From 237e5438f532418ecb465ca7209e3e2241e32501 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 20 Feb 2019 03:06:53 -0800 Subject: [PATCH] Add launch bounds for TopK kernel, be more conservative in sorting (#17296) Summary: The particular use case reported is Jetson TX2 and maskrcnn. Fixes #17144 Pull Request resolved: https://github.com/pytorch/pytorch/pull/17296 Differential Revision: D14147886 Pulled By: soumith fbshipit-source-id: 44d5a89aaeb4cc07d1b53dd90121013be93c419c --- aten/src/THC/THCTensorTopK.cuh | 3 +++ aten/src/THC/generic/THCTensorTopK.cu | 15 ++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/aten/src/THC/THCTensorTopK.cuh b/aten/src/THC/THCTensorTopK.cuh index 773232d..edf1bdf 100644 --- a/aten/src/THC/THCTensorTopK.cuh +++ b/aten/src/THC/THCTensorTopK.cuh @@ -1,6 +1,8 @@ #ifndef THC_TENSOR_TOPK_CUH #define THC_TENSOR_TOPK_CUH +#include <c10/macros/Macros.h> + template <typename T> struct TopKTypeConfig {}; @@ -359,6 +361,7 @@ __device__ void radixSelect(DataType* data, } template <typename T, typename IndexType, int Dim, bool Order> +C10_LAUNCH_BOUNDS(1024) __global__ void gatherTopK(TensorInfo<T, IndexType> input, IndexType inputSliceSize, IndexType outputSliceSize, // aka `k` diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index d52ba7e..902b2b5 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -140,7 +140,20 @@ void THCTensor_(topk)(THCState* state, if (sorted) { // FIXME: the k/v inplace sort along slice only works for size <= // 2048 at the moment - if (sliceSize <= 2048) { + // Workaround: + // CUDA 8 uses more shared memory than 7.5 for bitonicSortKVInPlace, + // and so for the double word types, + // we get "too many resources requested for launch" in the 2048 case +#if CUDA_VERSION >= 8000 +#if defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_LONG) + int maxSliceSize = 1024; +#else + int maxSliceSize = 2048; +#endif +#else + int maxSliceSize = 2048; +#endif + if (sliceSize <= maxSliceSize) { // This avoids any 
memory allocations and performs all sorting // work inplace along the slice THCTensor_(sortKeyValueInplace)(state, topK, indices, dim, dir); -- 2.7.4