Added launch bounds in VolumetricConvolution.cu (#14564)
authorMichael Carilli <mcarilli@nvidia.com>
Thu, 29 Nov 2018 22:47:32 +0000 (14:47 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 29 Nov 2018 22:49:29 +0000 (14:49 -0800)
Summary:
A few months ago we were seeing test failures on certain architectures due to invalid launch configurations of the kernels in aten/src/THCUNN/VolumetricConvolution.cu.

This PR ensures that those kernels are always compiled such that at least one block can be resident on an SM, and such errors will not be encountered at runtime on any architecture after compiling for that architecture.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14564

Differential Revision: D13266136

Pulled By: soumith

fbshipit-source-id: 35464b20848bb0a1168e8f3b233172331c50b35b

aten/src/THCUNN/VolumetricConvolution.cu

index 2e405f3..4c3a771 100644 (file)
@@ -8,13 +8,14 @@
 // Borrowed from Theano
 // Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter, Nicolas Ballas
 template <typename Dtype>
-__global__ void im3d2col_kernel(const int64_t n, const Dtype* data_im,
-                                const int64_t height, const int64_t width, const int64_t depth,
-                                const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,
-                                const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
-                                const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
-                                const int64_t height_col, const int64_t width_col, const int64_t depth_col,
-                                Dtype* data_col)
+__global__ void __launch_bounds__(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+im3d2col_kernel(const int64_t n, const Dtype* data_im,
+                const int64_t height, const int64_t width, const int64_t depth,
+                const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,
+                const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
+                const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
+                const int64_t height_col, const int64_t width_col, const int64_t depth_col,
+                Dtype* data_col)
 {
   CUDA_KERNEL_LOOP(index, n)
   {
@@ -86,14 +87,15 @@ void im3d2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
 }
 
 template <typename Dtype, typename Acctype>
-__global__ void col2im3d_kernel(const int64_t n, const Dtype* data_col,
-                                const int64_t height, const int64_t width, const int64_t depth,
-                                const int64_t channels,
-                                const int64_t patch_h, const int64_t patch_w, const int64_t patch_d,
-                                const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
-                                const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
-                                const int64_t height_col, const int64_t width_col, const int64_t depth_col,
-                                Dtype* data_im)
+__global__ void __launch_bounds__(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+col2im3d_kernel(const int64_t n, const Dtype* data_col,
+                const int64_t height, const int64_t width, const int64_t depth,
+                const int64_t channels,
+                const int64_t patch_h, const int64_t patch_w, const int64_t patch_d,
+                const int64_t pad_h, const int64_t pad_w, const int64_t pad_d,
+                const int64_t stride_h, const int64_t stride_w, const int64_t stride_d,
+                const int64_t height_col, const int64_t width_col, const int64_t depth_col,
+                Dtype* data_im)
 {
   CUDA_KERNEL_LOOP(index, n)
   {