From: Johannes M Dieterich
Date: Thu, 20 Dec 2018 22:26:14 +0000 (-0800)
Subject: Add launch bounds needed for ROCm 2.0 (#15400)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2140
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c24a124fa0df4199d9180cc009c5b5e8ff88bba3;p=platform%2Fupstream%2Fpytorch.git

Add launch bounds needed for ROCm 2.0 (#15400)

Summary:
ROCm 2.0's compiler requires launch_bounds annotations if flat work group sizes are larger than the default of 256.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15400

Differential Revision: D13531239

Pulled By: ezyang

fbshipit-source-id: c0b40600a8c332823da6c7113c644d8dba424a9c
---

diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu
index 722af85..269aa40 100644
--- a/aten/src/ATen/native/cuda/LossCTC.cu
+++ b/aten/src/ATen/native/cuda/LossCTC.cu
@@ -44,7 +44,11 @@ __device__ static inline int64_t get_target_prime(const target_t* __restrict__ t
 // In contrast to the cuDNN implementation, we allow large target lengths. For this we need that all previous `s` have been
 // computed when we start a new block_s. This is why we have our own for loop here.
 template<typename scalar_t, typename target_t>
-__global__ void ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
+__global__ void
+#if defined (__HIP_PLATFORM_HCC__)
+__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+#endif
+ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
                                     const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
                                     const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
                                     scalar_t* __restrict__ neg_log_likelihood_data,
@@ -359,7 +363,11 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
 // calculation (with an atomic log add) is similarly in performance, but for large
 // alphabets the inplace nature is a considerable advantage.
 template<typename scalar_t, typename target_t>
-__global__ void ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
+__global__ void
+#if defined (__HIP_PLATFORM_HCC__)
+__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+#endif
+ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                                      const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
                                                      const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
                                                      const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
@@ -404,7 +412,11 @@ __global__ void ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restri
 // This is the naive implementation of equation (16). It is parallelised in batch and input timestep.
 // It appears to be faster than the above method for small batch sizes.
 template<typename scalar_t, typename target_t>
-__global__ void ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
+__global__ void
+#if defined (__HIP_PLATFORM_HCC__)
+__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+#endif
+ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                             const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
                                             const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
                                             const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
diff --git a/aten/src/THC/THCReduceAll.cuh b/aten/src/THC/THCReduceAll.cuh
index af80454..cec7f4c 100644
--- a/aten/src/THC/THCReduceAll.cuh
+++ b/aten/src/THC/THCReduceAll.cuh
@@ -25,6 +25,9 @@ template <typename T,
           typename ReduceOp,
           int ADims>
 __global__ void
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(THC_REDUCE_ALL_BLOCK_SIZE)
+#endif
 kernelReduceAll(TensorInfo<T, IndexType> in,
                 IndexType totalElements,
                 AccT init,
diff --git a/aten/src/THCUNN/MultiLabelMarginCriterion.cu b/aten/src/THCUNN/MultiLabelMarginCriterion.cu
index 3df336a..602daf7 100644
--- a/aten/src/THCUNN/MultiLabelMarginCriterion.cu
+++ b/aten/src/THCUNN/MultiLabelMarginCriterion.cu
@@ -10,6 +10,9 @@
 #define MULTILABELMARGIN_THREADS 1024
 
 template <typename Dtype, typename Acctype>
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(MULTILABELMARGIN_THREADS)
+#endif
 __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output,
                                                                    Dtype *input,
                                                                    THCIndex_t *target,
@@ -77,6 +80,9 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output
 }
 
 template <typename Dtype, typename Acctype>
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(MULTILABELMARGIN_THREADS)
+#endif
 __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
                                                                       Dtype *gradOutput,
                                                                       Dtype *input,
diff --git a/aten/src/THCUNN/SpatialClassNLLCriterion.cu b/aten/src/THCUNN/SpatialClassNLLCriterion.cu
index c012826..6e758bf 100644
--- a/aten/src/THCUNN/SpatialClassNLLCriterion.cu
+++ b/aten/src/THCUNN/SpatialClassNLLCriterion.cu
@@ -67,6 +67,9 @@ __global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel(
 }
 
 template <typename T, typename AccumT>
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(1024)
+#endif
 __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
     T *output,
     T *total_weight,
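
For context on the mechanism: __launch_bounds__(max_threads_per_block, min_blocks_per_multiprocessor) promises the compiler an upper bound on the block size a kernel will ever be launched with, so it can budget registers accordingly. ROCm 2.0's compiler assumes a flat work-group size of at most 256 unless told otherwise, which is why every kernel above that is launched with up to 1024 threads now carries the annotation under __HIP_PLATFORM_HCC__. The sketch below is a minimal, self-contained illustration of the same pattern; the kernel name scale_kernel and its launch configuration are hypothetical, not taken from the patch.

    // Hypothetical sketch, not part of this commit: a HIP-only
    // __launch_bounds__ annotation whose bound depends on the scalar type
    // the kernel is instantiated with, as in the LossCTC.cu hunks above.
    #include <cuda_runtime.h>
    #include <type_traits>

    template <typename scalar_t>
    __global__ void
    #if defined(__HIP_PLATFORM_HCC__)
    // float instantiations promise at most 1024 threads per block, all
    // other types (e.g. double) at most 896, mirroring the patch's bounds.
    __launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
    #endif
    scale_kernel(scalar_t* data, scalar_t factor, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        data[i] *= factor;
      }
    }

    int main() {
      const int n = 1 << 20;
      float* d = nullptr;
      cudaMalloc(reinterpret_cast<void**>(&d), n * sizeof(float));
      // 1024 threads per block exceeds ROCm 2.0's default flat work-group
      // size of 256; without the annotation above, the HIP compiler would
      // reject this kernel.
      scale_kernel<float><<<(n + 1023) / 1024, 1024>>>(d, 2.0f, n);
      cudaDeviceSynchronize();
      cudaFree(d);
      return 0;
    }

Guarding the annotation with #if defined(__HIP_PLATFORM_HCC__) leaves the CUDA build untouched, and making the bound a compile-time ternary on the scalar type lets the double instantiations, which presumably face higher register pressure, declare the smaller 896-thread limit.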