From bd368b867dc09461bbc0710322b24ef10957af1a Mon Sep 17 00:00:00 2001 From: Johannes M Dieterich Date: Fri, 14 Dec 2018 14:45:11 -0800 Subject: [PATCH] Do not ifdef __launch_bounds__ out for ROCm. (#15228) Summary: The compiler understands it and profits from knowing it by not using too many VGPRs, as it would otherwise assume the default workgroup size of 256. Fixes a problem in bringup of ROCm 2.0 on gfx906. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15228 Differential Revision: D13470950 Pulled By: bddppq fbshipit-source-id: f9aa44c7c95299a099c0ea9317b9044cc056acc5 --- aten/src/ATen/cuda/CUDAApplyUtils.cuh | 8 ++++---- aten/src/ATen/native/cuda/Dropout.cu | 2 +- aten/src/ATen/native/cuda/RNN.cu | 8 ++++---- aten/src/ATen/native/cuda/TensorTransformations.cu | 2 +- aten/src/THC/THCReduce.cuh | 4 ++-- aten/src/THCUNN/SpatialCrossMapLRN.cu | 2 +- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/cuda/CUDAApplyUtils.cuh b/aten/src/ATen/cuda/CUDAApplyUtils.cuh index 2f13bd6..45317fb 100644 --- a/aten/src/ATen/cuda/CUDAApplyUtils.cuh +++ b/aten/src/ATen/cuda/CUDAApplyUtils.cuh @@ -271,7 +271,7 @@ template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) #endif __global__ void kernelPointwiseApply1(detail::TensorInfo a, @@ -355,7 +355,7 @@ template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) #endif __global__ void @@ -464,7 +464,7 @@ template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) #endif __global__ void @@ -587,7 +587,7 @@ template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) #endif __global__ void diff --git 
a/aten/src/ATen/native/cuda/Dropout.cu b/aten/src/ATen/native/cuda/Dropout.cu index a3d51f0..5804287 100644 --- a/aten/src/ATen/native/cuda/Dropout.cu +++ b/aten/src/ATen/native/cuda/Dropout.cu @@ -33,7 +33,7 @@ template < typename accscalar_t, typename IndexType, int ADims> -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(256,8) #endif __global__ void diff --git a/aten/src/ATen/native/cuda/RNN.cu b/aten/src/ATen/native/cuda/RNN.cu index d080125..0d4cee7 100644 --- a/aten/src/ATen/native/cuda/RNN.cu +++ b/aten/src/ATen/native/cuda/RNN.cu @@ -80,7 +80,7 @@ T sigmoid(T in) { namespace kernel { template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(32 * 16, 4) #endif __global__ void lstm_cell_forward( @@ -167,7 +167,7 @@ __global__ void lstm_cell_forward( } template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(32 * 16, 4) #endif __global__ void lstm_cell_backward( @@ -232,7 +232,7 @@ __global__ void lstm_cell_backward( } template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(32 * 16, 4) #endif __global__ void gru_cell_forward( @@ -302,7 +302,7 @@ __global__ void gru_cell_forward( } template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(32 * 16, 4) #endif __global__ void gru_cell_backward( diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu index a30bd8a..80f64a9 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -15,7 +15,7 @@ namespace native { #define AT_APPLY_BLOCKS_PER_SM 4 template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM) #endif __global__ void 
diff --git a/aten/src/THC/THCReduce.cuh b/aten/src/THC/THCReduce.cuh index 61e4e9a..6fad806 100644 --- a/aten/src/THC/THCReduce.cuh +++ b/aten/src/THC/THCReduce.cuh @@ -139,7 +139,7 @@ template typename ReduceOp, typename FinalizeOp, int ADims, int BDims> -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(32 * 16, 4) #endif __global__ void kernelReduceNoncontigDim_shared @@ -254,7 +254,7 @@ template -#if __CUDA_ARCH__ >= 350 +#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(32 * 16, 4) #endif __global__ void diff --git a/aten/src/THCUNN/SpatialCrossMapLRN.cu b/aten/src/THCUNN/SpatialCrossMapLRN.cu index 8545261..b5586f5 100644 --- a/aten/src/THCUNN/SpatialCrossMapLRN.cu +++ b/aten/src/THCUNN/SpatialCrossMapLRN.cu @@ -7,7 +7,7 @@ template __global__ void -#if __CUDA_ARCH__ >= 320 +#if __CUDA_ARCH__ >= 320 || defined __HIP_PLATFORM_HCC__ __launch_bounds__(CUDA_NUM_THREADS) #endif LRNFillScale(const int nthreads, const Dtype* const in, -- 2.7.4