Resolves ptxas warnings when compiling for CUDA_ARCH 750 and a memoryType deprecation...
authorSyed Tousif Ahmed <syed.ahmed.emails@gmail.com>
Fri, 11 Jan 2019 05:41:48 +0000 (21:41 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 05:44:39 +0000 (21:44 -0800)
Summary:
When compiling for `TORCH_CUDA_ARCH_LIST=7.5` we were getting ptxas warnings (https://github.com/pytorch/pytorch/issues/14310). This was because we had some hardcoded values when using launch_bounds in kernels. The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5) but 2048 for previous architectures. The hardcoded launch_bounds in the kernel were requesting for 2048 threads when compiling for Turing and hence were generating the warning.

This PR adds a macro that checks for the bounds on the launch bounds value supplied. The max number of threads per block across all architectures is 1024. If a user supplies more than 1024, I just clamp it down to 512. Depending on this value, I set the minimum number of blocks per sm. This PR should resolve https://github.com/pytorch/pytorch/issues/14310. The gradient computation being wrong reported in that PR is probably due to the faulty card.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15461

Differential Revision: D13633952

Pulled By: soumith

fbshipit-source-id: 795aa151109f343ab5433bf3cb070cb6ec896fff

20 files changed:
aten/src/ATen/cuda/CUDAApplyUtils.cuh
aten/src/ATen/native/cuda/Dropout.cu
aten/src/ATen/native/cuda/GridSampler.cu
aten/src/ATen/native/cuda/Loops.cuh
aten/src/ATen/native/cuda/LossCTC.cu
aten/src/ATen/native/cuda/RNN.cu
aten/src/ATen/native/cuda/Reduce.cuh
aten/src/ATen/native/cuda/TensorTransformations.cu
aten/src/THC/THCReduce.cuh
aten/src/THC/THCReduceAll.cuh
aten/src/THC/THCSortUtils.cuh
aten/src/THCUNN/MultiLabelMarginCriterion.cu
aten/src/THCUNN/SpatialClassNLLCriterion.cu
aten/src/THCUNN/SpatialCrossMapLRN.cu
aten/src/THCUNN/SpatialDilatedMaxPooling.cu
aten/src/THCUNN/VolumetricConvolution.cu
aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu
aten/src/THCUNN/im2col.h
c10/macros/Macros.h
torch/csrc/generic/StorageMethods.cpp

index 45317fb..82a3eb0 100644 (file)
@@ -4,6 +4,7 @@
 #include <ATen/TensorUtils.h>
 #include <THC/THCAtomics.cuh>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/macros/Macros.h>
 
 #include <math.h>
 
@@ -198,8 +199,8 @@ inline void rearrangeDims(detail::TensorInfo<T1, IndexType>* aInfo,
 
 // Threads per block for our apply kernel
 // FIXME: use occupancy calculator instead
-#define AT_APPLY_THREADS_PER_BLOCK 32 * 16
-#define AT_APPLY_BLOCKS_PER_SM 4
+constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
+constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
 
 // The `remaining_steps` argument is used to support Op that operates on
 // multiple elements at the same time. Generally, the strategy of ApplyOpN is to
@@ -272,7 +273,7 @@ template <typename Op,
           int ADims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
                                       IndexType totalElements, const Op op) {
@@ -356,7 +357,7 @@ template <typename Op,
           int ADims, int BDims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
@@ -465,7 +466,7 @@ template <typename Op,
           int ADims, int BDims, int CDims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply3(detail::TensorInfo<scalar1, IndexType> a,
@@ -588,7 +589,7 @@ template <typename Op,
           int ADims, int BDims, int CDims, int DDims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply4(detail::TensorInfo<scalar1, IndexType> a,
index 5804287..186aeed 100644 (file)
@@ -3,6 +3,7 @@
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <c10/macros/Macros.h>
 #include <curand_kernel.h>
 
 #include <THC/THCGeneral.h>
@@ -34,7 +35,7 @@ template <
           typename IndexType,
           int ADims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(256,8)
+C10_LAUNCH_BOUNDS(256, 8)
 #endif
 __global__ void
 fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,
index be0b2f6..8dba784 100644 (file)
@@ -5,6 +5,7 @@
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/macros/Macros.h>
 
 namespace at { namespace native {
 
@@ -119,7 +120,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  __launch_bounds__(1024)
+  C10_LAUNCH_BOUNDS(1024)
   __global__ void grid_sampler_2d_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> input,
@@ -227,7 +228,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  __launch_bounds__(1024)
+  C10_LAUNCH_BOUNDS(1024)
   __global__ void grid_sampler_3d_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> input,
@@ -391,7 +392,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  __launch_bounds__(1024)
+  C10_LAUNCH_BOUNDS(1024)
   __global__ void grid_sampler_2d_backward_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> grad_output,
@@ -546,7 +547,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  __launch_bounds__(1024)
+  C10_LAUNCH_BOUNDS(1024)
   __global__ void grid_sampler_3d_backward_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> grad_output,
index 461ed9a..71bb4bc 100644 (file)
@@ -5,7 +5,7 @@
 #include <ATen/cuda/detail/OffsetCalculator.cuh>
 #include <ATen/detail/FunctionTraits.h>
 #include <ATen/native/TensorIterator.h>
-
+#include <c10/macros/Macros.h>
 
 // Marks a lambda as executable on both the host and device. The __host__
 // attribute is important so that we can access static type information from
@@ -26,7 +26,7 @@
 namespace at { namespace native {
 
 template<int nt, int vt, typename func_t>
-__launch_bounds__(nt, 4)
+C10_LAUNCH_BOUNDS(nt, 4)
 __global__ void elementwise_kernel(int N, func_t f) {
   int tid = threadIdx.x;
   int nv = nt * vt;
index 053fe7b..7759b6c 100644 (file)
@@ -10,6 +10,7 @@
 
 #include <ATen/TensorUtils.h>
 #include <c10/util/Exception.h>
+#include <c10/macros/Macros.h>
 
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
@@ -46,7 +47,7 @@ __device__ static inline int64_t get_target_prime(const target_t* __restrict__ t
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
                                     const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
@@ -259,7 +260,7 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
 // alpha kernel above. (As mentioned above, it might make sense do the calculation in the alpha kernel.)
 template<typename scalar_t, typename target_t>
 __global__ void
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
                                       const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
                                       const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
@@ -365,7 +366,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                                      const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
@@ -414,7 +415,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-__launch_bounds__((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                                      const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
index 0d4cee7..8649ee6 100644 (file)
@@ -4,6 +4,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <c10/macros/Macros.h>
 
 namespace at { namespace native {
 
@@ -81,7 +82,7 @@ namespace kernel {
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void lstm_cell_forward(
             TensorInfo<scalar_t, index_type> input,
@@ -168,7 +169,7 @@ __global__ void lstm_cell_forward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void lstm_cell_backward(
               TensorInfo<scalar_t, index_type> storage,
@@ -233,7 +234,7 @@ __global__ void lstm_cell_backward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void gru_cell_forward(
             TensorInfo<scalar_t, index_type> Input,
@@ -303,7 +304,7 @@ __global__ void gru_cell_forward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void gru_cell_backward(
              TensorInfo<scalar_t, index_type> gradInInput,
index db26a17..6409b10 100644 (file)
@@ -10,6 +10,7 @@
 #include <THC/THCGeneral.hpp>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include <c10/macros/Macros.h>
 #include <functional>
 #include <iosfwd>
 #include <tuple>
@@ -146,7 +147,7 @@ struct ReduceConfig {
 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
 
 template<int nt, typename R>
-__launch_bounds__(nt, 4)
+C10_LAUNCH_BOUNDS(nt, 4)
 __global__ void reduce_kernel(R reduction) {
   reduction.run();
 }
index 80f64a9..b72ca43 100644 (file)
@@ -4,6 +4,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/macros/Macros.h>
 
 #include <cstddef>
 #include <vector>
 namespace at {
 namespace native {
 
-#define AT_APPLY_THREADS_PER_BLOCK 32 * 16
-#define AT_APPLY_BLOCKS_PER_SM 4
+constexpr uint32_t AT_APPLY_THREADS_PER_BLOCK = 512;
+constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
 
 template <typename scalar_t, typename IndexType>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernel_pointwise_flip_apply2(const cuda::detail::TensorInfo<scalar_t, IndexType> in_tensor_info,
index 6fad806..f750b53 100644 (file)
@@ -11,6 +11,7 @@
 #include <THC/THCTensorTypeUtils.cuh>
 #include <THC/THCReduceApplyUtils.cuh>
 #include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Threads per thread block
 #define THC_NONCONTIG_REDUCE_BLOCK_SIZE 32 * 16
@@ -140,7 +141,7 @@ template
    typename FinalizeOp,
    int ADims, int BDims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void kernelReduceNoncontigDim_shared
   (TensorInfo<T, IndexType> out,
@@ -255,7 +256,7 @@ template <typename T,
           typename FinalizeOp,
           int ADims, int BDims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(32 * 16, 4)
+C10_LAUNCH_BOUNDS(512, 4)
 #endif
 __global__ void
 kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
index cec7f4c..65c1b53 100644 (file)
@@ -10,6 +10,7 @@
 //
 
 #include <THC/THCReduceApplyUtils.cuh>
+#include <c10/macros/Macros.h>
 
 // Size per each reduction block
 #define THC_REDUCE_ALL_BLOCK_SIZE 1024L
@@ -26,7 +27,7 @@ template <typename T,
           int ADims>
 __global__ void
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(THC_REDUCE_ALL_BLOCK_SIZE)
+C10_LAUNCH_BOUNDS(THC_REDUCE_ALL_BLOCK_SIZE)
 #endif
 kernelReduceAll(TensorInfo<T, IndexType> in,
                 IndexType totalElements,
index d25ac6a..038b413 100644 (file)
@@ -4,6 +4,7 @@
 #include <THC/THCReduceApplyUtils.cuh>
 #include <THC/THCTensorTypeUtils.cuh>
 #include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Collection of kernel sort routines
 template <typename T>
@@ -134,7 +135,7 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
 template <typename K, typename V,
           int KeyDims, int ValueDims,
           typename Comparator, typename IndexType, int Power2SortSize>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void
 bitonicSortKVInPlace(TensorInfo<K, IndexType> keys,
                      IndexType keySlices,
index 602daf7..0257ed9 100644 (file)
@@ -4,6 +4,7 @@
 #include <THC/THCReduceApplyUtils.cuh>
 #include <TH/THHalf.h>
 #include <THCUNN/THCHalfAutoNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 #include <thrust/functional.h>
 
@@ -11,7 +12,7 @@
 
 template <typename Dtype, typename Acctype>
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(MULTILABELMARGIN_THREADS)
+C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
 #endif
 __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output,
                                                                    Dtype *input,
@@ -81,7 +82,7 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output
 
 template <typename Dtype, typename Acctype>
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(MULTILABELMARGIN_THREADS)
+C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
 #endif
 __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
                                                                       Dtype *gradOutput,
index 6e758bf..f50c5c1 100644 (file)
@@ -7,6 +7,7 @@
 #include <THC/THCDeviceTensorUtils.cuh>
 #include <THC/THCDeviceUtils.cuh>
 #include <THC/THCApply.cuh>
+#include <c10/macros/Macros.h>
 
 #include <thrust/functional.h>
 
@@ -68,7 +69,7 @@ __global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel(
 
 template <typename T, typename AccumT>
 #if defined(__HIP_PLATFORM_HCC__)
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 #endif
 __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
           T *output,
index b5586f5..e1fbd33 100644 (file)
@@ -4,11 +4,12 @@
 #include <THC/THCTensor.hpp>
 #include <THC/THCStorage.hpp>
 #include <THCUNN/common.h>
+#include <c10/macros/Macros.h>
 
 template <typename Dtype, typename Acctype>
 __global__ void
 #if __CUDA_ARCH__ >= 320 || defined __HIP_PLATFORM_HCC__
-__launch_bounds__(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
 #endif
 LRNFillScale(const int nthreads, const Dtype* const in,
     const int num, const int channels, const int height,
index 0dc32b6..d0213b1 100644 (file)
@@ -4,6 +4,7 @@
 #include <THCUNN/THCHalfAutoNumerics.cuh>
 #include <THC/THCNumerics.cuh>
 #include <THCUNN/common.h>
+#include <c10/macros/Macros.h>
 
 // kernels borrowed from Caffe
 template <typename Dtype, typename AccType>
@@ -47,7 +48,7 @@ __global__ void MaxPoolForward(const int nthreads, const Dtype* bottom_data,
 const int BACKWARD_THREADS = 256;
 
 template <typename Dtype, typename AccType>
-__launch_bounds__(BACKWARD_THREADS,2048/BACKWARD_THREADS)
+C10_LAUNCH_BOUNDS(BACKWARD_THREADS, 8)
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     const int64_t* top_mask, const int num, const int channels,
     const int height, const int width, const int pooled_height,
index e22520c..3c0b0ca 100644 (file)
@@ -3,12 +3,13 @@
 #include <THCUNN/common.h>
 #include <TH/THHalf.h>
 #include <THCUNN/THCHalfAutoNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Kernel for fast unfold+copy
 // Borrowed from Theano
 // Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter, Nicolas Ballas
 template <typename Dtype>
-__global__ void __launch_bounds__(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
 im3d2col_kernel(const int64_t n, const Dtype* data_im,
                 const int64_t height, const int64_t width, const int64_t depth,
                 const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,
@@ -87,7 +88,7 @@ void im3d2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
 }
 
 template <typename Dtype, typename Acctype>
-__global__ void __launch_bounds__(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
 col2im3d_kernel(const int64_t n, const Dtype* data_col,
                 const int64_t height, const int64_t width, const int64_t depth,
                 const int64_t channels,
index da159bd..7e1ab11 100644 (file)
 #include <TH/THHalf.h>
 #include <THCUNN/THCHalfAutoNumerics.cuh>
 #include <THC/THCAtomics.cuh>
+#include <c10/macros/Macros.h>
 
 template<typename Dtype, typename Acctype>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void caffe_gpu_interp2_kernel(const int n,
     const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
     const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {
@@ -80,7 +81,7 @@ __global__ void caffe_gpu_interp2_kernel(const int n,
 
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template <typename Dtype, typename Acctype>
-__launch_bounds__(1024)
+C10_LAUNCH_BOUNDS(1024)
 __global__ void caffe_gpu_interp2_kernel_backward(const int n,
     const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
     THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){
index 55c97d7..22a4c7a 100644 (file)
@@ -3,11 +3,12 @@
 
 #include <THCUNN/common.h>
 #include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
 
 // Kernel for fast unfold+copy
 // (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
 template <typename Dtype>
-__launch_bounds__(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
 __global__ void im2col_kernel(const int64_t n, const Dtype* data_im,
                               const int64_t height, const int64_t width,
                               const int64_t ksize_h, const int64_t ksize_w,
@@ -59,7 +60,7 @@ void im2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
 }
 
 template <typename Dtype, typename Acctype>
-__launch_bounds__(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
 __global__ void col2im_kernel(const int64_t n, const Dtype* data_col,
                                   const int64_t height, const int64_t width, const int64_t channels,
                                   const int64_t kernel_h, const int64_t kernel_w,
index 29a02df..8834f83 100644 (file)
@@ -109,6 +109,39 @@ namespace at { namespace cuda { using namespace c10::hip; }}
 #define C10_HOST_DEVICE __host__ __device__
 #define C10_DEVICE __device__
 #define C10_HOST __host__
+// constants from (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications)
+// The maximum number of threads per multiprocessor is 1024 for Turing architecture (7.5) 
+// but 2048 for previous architectures. You'll get warnings if you exceed these constants. 
+// Hence, the following macros adjust the input values from the user to resolve potential warnings.
+#if __CUDA_ARCH__ >= 750
+constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024;
+#else
+constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048;
+#endif
+// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently
+constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024;
+// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block size.
+// 256 is a good number for this fallback and should give good occupancy and 
+// versatility across all architectures.
+constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
+// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it
+//       turns out that although __launch_bounds__ can take constexpr, it 
+//       can't take a constexpr that has anything to do with templates. 
+//       Currently we use launch_bounds that depend on template arguments in 
+//       Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK and 
+//       C10_MIN_BLOCKS_PER_SM are kept as macros.
+// Suppose you were planning to write __launch_bounds__(a, b), based on your performance tuning on a modern GPU. 
+// Instead, you should write __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), 
+// which will also properly respect limits on old architectures.
+#define C10_MAX_THREADS_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)
+#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) ((((threads_per_block)*(blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) ? (blocks_per_sm) : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block))))
+// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__
+// https://stackoverflow.com/a/8814003 snippet to have macro with an optional argument
+#define C10_LAUNCH_BOUNDS_0 __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and versatility across all architectures.
+#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
+#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
+#define C10_LAUNCH_BOUNDS_X(x,max_threads_per_block,min_blocks_per_sm,FUNC, ...) FUNC
+#define C10_LAUNCH_BOUNDS(...) C10_LAUNCH_BOUNDS_X(,##__VA_ARGS__, C10_LAUNCH_BOUNDS_2(__VA_ARGS__), C10_LAUNCH_BOUNDS_1(__VA_ARGS__), C10_LAUNCH_BOUNDS_0(__VA_ARGS__))
 #else
 #define C10_HOST_DEVICE
 #define C10_HOST
index 42471b5..fb9247e 100644 (file)
@@ -36,7 +36,11 @@ static PyObject * THPStorage_(isPinned)(THPStorage *self)
     cudaGetLastError();
     Py_RETURN_FALSE;
   }
-  return PyBool_FromLong(attr.memoryType == cudaMemoryTypeHost);
+  #if CUDA_VERSION >= 10000
+    return PyBool_FromLong(attr.type == cudaMemoryTypeHost);
+  #else
+    return PyBool_FromLong(attr.memoryType == cudaMemoryTypeHost);
+  #endif
 #else
   Py_RETURN_FALSE;
 #endif