Add launch bounds needed for ROCm 2.0 (#15400)
authorJohannes M Dieterich <johannes.dieterich@amd.com>
Thu, 20 Dec 2018 22:26:14 +0000 (14:26 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 22:39:13 +0000 (14:39 -0800)
Summary:
ROCm 2.0's compiler requires launch_bounds annotations if flat work group sizes are larger than the default of 256.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15400

Differential Revision: D13531239

Pulled By: ezyang

fbshipit-source-id: c0b40600a8c332823da6c7113c644d8dba424a9c

aten/src/ATen/native/cuda/LossCTC.cu
aten/src/THC/THCReduceAll.cuh
aten/src/THCUNN/MultiLabelMarginCriterion.cu
aten/src/THCUNN/SpatialClassNLLCriterion.cu

index 722af85..269aa40 100644 (file)
@@ -44,7 +44,11 @@ __device__ static inline int64_t get_target_prime(const target_t* __restrict__ t
 // In contrast to the cuDNN implementation, we allow large target lengths. For this we need that all previous `s` have been
 // computed when we start a new block_s. This is why we have our own for loop here.
 template<typename scalar_t, typename target_t>
-__global__ void ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
+__global__ void
+#if defined (__HIP_PLATFORM_HCC__)
+__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+#endif
+ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
                                     const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
                                     const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
                                     scalar_t* __restrict__ neg_log_likelihood_data,
@@ -359,7 +363,11 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
 // calculation (with an atomic log add) is similarly in performance, but for large
 // alphabets the inplace nature is a considerable advantage.
 template<typename scalar_t, typename target_t>
-__global__ void ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
+__global__ void
+#if defined (__HIP_PLATFORM_HCC__)
+__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+#endif
+ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                                      const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
                                                      const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
                                                      const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
@@ -404,7 +412,11 @@ __global__ void ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restri
 // This is the naive implementation of equation (16). It is parallelised in batch and input timestep.
 // It appears to be faster than the above method for small batch sizes.
 template<typename scalar_t, typename target_t>
-__global__ void ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
+__global__ void
+#if defined (__HIP_PLATFORM_HCC__)
+__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+#endif
+ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                                      const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
                                                      const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
                                                      const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
index af80454..cec7f4c 100644 (file)
@@ -25,6 +25,9 @@ template <typename T,
           typename ReduceOp,
           int ADims>
 __global__ void
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(THC_REDUCE_ALL_BLOCK_SIZE)
+#endif
 kernelReduceAll(TensorInfo<T, IndexType> in,
                 IndexType totalElements,
                 AccT init,
index 3df336a..602daf7 100644 (file)
@@ -10,6 +10,9 @@
 #define MULTILABELMARGIN_THREADS 1024
 
 template <typename Dtype, typename Acctype>
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(MULTILABELMARGIN_THREADS)
+#endif
 __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output,
                                                                    Dtype *input,
                                                                    THCIndex_t *target,
@@ -77,6 +80,9 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output
 }
 
 template <typename Dtype, typename Acctype>
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(MULTILABELMARGIN_THREADS)
+#endif
 __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
                                                                       Dtype *gradOutput,
                                                                       Dtype *input,
index c012826..6e758bf 100644 (file)
@@ -67,6 +67,9 @@ __global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel(
 }
 
 template <typename T, typename AccumT>
+#if defined(__HIP_PLATFORM_HCC__)
+__launch_bounds__(1024)
+#endif
 __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
           T *output,
           T *total_weight,