msvc_fixes (#17201)
authorGrigory Arutyunov <arutyunovg@yandex.ru>
Fri, 1 Mar 2019 23:07:18 +0000 (15:07 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 1 Mar 2019 23:17:41 +0000 (15:17 -0800)
Summary:
Fixing MSVC errors

```
  D:\pytorch-scripts\caffe2_builders\v141\pytorch\aten\src\THC/THCReduce.cuh(144): error C4002: too many actual paramet
ers for macro 'C10_LAUNCH_BOUNDS_1' [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2\caffe2_gpu.vcxp
roj]
  D:\pytorch-scripts\caffe2_builders\v141\pytorch\aten\src\THC/THCReduce.cuh(259): error C4002: too many actual paramet
ers for macro 'C10_LAUNCH_BOUNDS_1' [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2\caffe2_gpu.vcxp
roj]
  D:/pytorch-scripts/caffe2_builders/v141/pytorch/aten/src/THCUNN/SpatialDilatedMaxPooling.cu(51): error C4002: too man
y actual parameters for macro 'C10_LAUNCH_BOUNDS_1' [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2
\caffe2_gpu.vcxproj]
```

on variadic C10_LAUNCH_BOUNDS as well as Debug linking issues with at::Half in pool_op_cudnn.cc like this one

```
pool_op_cudnn.obj : error LNK2019: unresolved external symbol "public: bool __cdecl caffe2::MaxPoolFunctor<class caff
e2::CUDAContext>::GlobalPoolingBackward<struct c10::Half,2>(int,int,int,struct c10::Half const *,struct c10::Half const
 ,struct c10::Half const ,struct c10::Half ,class caffe2::CUDAContext )const " (??$GlobalPoolingBackward@UHalf@c10@
@$01@?$MaxPoolFunctor@VCUDAContext@caffe2@@caffe2@QEBA_NHHHPEBUHalf@c10@00PEAU23@PEAVCUDAContext@1@Z) referenced in
 function "public: bool __cdecl caffe2::`anonymous namespace'::CuDNNMaxPoolFunctor::GlobalPoolingBackward<struct c10::H
alf,2>(int,int,int,struct c10::Half const ,struct c10::Half const ,struct c10::Half const ,struct c10::Half ,class
caffe2::CUDAContext *)const " (??$GlobalPoolingBackward@UHalf@c10@@$01@CuDNNMaxPoolFunctor@?A0xb936404a@caffe2@QEBA_NH
HHPEBUHalf@c10@00PEAU34@PEAVCUDAContext@2@Z) [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2\caff
e2_gpu.vcxproj]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17201

Differential Revision: D14165732

Pulled By: ezyang

fbshipit-source-id: 875fd9a5b2db6f83fc483f6d750d2c011260eb8b

21 files changed:
aten/src/ATen/cuda/CUDAApplyUtils.cuh
aten/src/ATen/native/cuda/Dropout.cu
aten/src/ATen/native/cuda/GridSampler.cu
aten/src/ATen/native/cuda/Loops.cuh
aten/src/ATen/native/cuda/LossCTC.cu
aten/src/ATen/native/cuda/RNN.cu
aten/src/ATen/native/cuda/Reduce.cuh
aten/src/ATen/native/cuda/TensorTransformations.cu
aten/src/THC/THCReduce.cuh
aten/src/THC/THCReduceAll.cuh
aten/src/THC/THCSortUtils.cuh
aten/src/THC/THCTensorTopK.cuh
aten/src/THCUNN/MultiLabelMarginCriterion.cu
aten/src/THCUNN/SpatialClassNLLCriterion.cu
aten/src/THCUNN/SpatialCrossMapLRN.cu
aten/src/THCUNN/SpatialDilatedMaxPooling.cu
aten/src/THCUNN/VolumetricConvolution.cu
aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu
aten/src/THCUNN/im2col.h
c10/macros/Macros.h
caffe2/operators/pool_op_cudnn.cc

index 82a3eb0..724510b 100644 (file)
@@ -273,7 +273,7 @@ template <typename Op,
           int ADims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
                                       IndexType totalElements, const Op op) {
@@ -357,7 +357,7 @@ template <typename Op,
           int ADims, int BDims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
@@ -466,7 +466,7 @@ template <typename Op,
           int ADims, int BDims, int CDims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply3(detail::TensorInfo<scalar1, IndexType> a,
@@ -589,7 +589,7 @@ template <typename Op,
           int ADims, int BDims, int CDims, int DDims,
           int step>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernelPointwiseApply4(detail::TensorInfo<scalar1, IndexType> a,
index 186aeed..44ce507 100644 (file)
@@ -35,7 +35,7 @@ template <
           typename IndexType,
           int ADims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(256, 8)
+C10_LAUNCH_BOUNDS_2(256, 8)
 #endif
 __global__ void
 fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,
index 8dba784..706cd44 100644 (file)
@@ -120,7 +120,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  C10_LAUNCH_BOUNDS(1024)
+  C10_LAUNCH_BOUNDS_1(1024)
   __global__ void grid_sampler_2d_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> input,
@@ -228,7 +228,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  C10_LAUNCH_BOUNDS(1024)
+  C10_LAUNCH_BOUNDS_1(1024)
   __global__ void grid_sampler_3d_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> input,
@@ -392,7 +392,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  C10_LAUNCH_BOUNDS(1024)
+  C10_LAUNCH_BOUNDS_1(1024)
   __global__ void grid_sampler_2d_backward_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> grad_output,
@@ -547,7 +547,7 @@ namespace {
   }
 
   template <typename scalar_t>
-  C10_LAUNCH_BOUNDS(1024)
+  C10_LAUNCH_BOUNDS_1(1024)
   __global__ void grid_sampler_3d_backward_kernel(
       const int nthreads,
       TensorInfo<scalar_t, int> grad_output,
index 9c39e36..dfb9af4 100644 (file)
@@ -36,7 +36,7 @@ static constexpr int launch_bound2 = 4;
 namespace at { namespace native {
 
 template<int nt, int vt, typename func_t>
-C10_LAUNCH_BOUNDS(nt, launch_bound2)
+C10_LAUNCH_BOUNDS_2(nt, launch_bound2)
 __global__ void elementwise_kernel(int N, func_t f) {
   int tid = threadIdx.x;
   int nv = nt * vt;
index 88d414e..c152a51 100644 (file)
@@ -47,7 +47,7 @@ __device__ static inline int64_t get_target_prime(const target_t* __restrict__ t
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
                                     const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
@@ -260,7 +260,7 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
 // alpha kernel above. (As mentioned above, it might make sense do the calculation in the alpha kernel.)
 template<typename scalar_t, typename target_t>
 __global__ void
-C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
                                       const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
                                       const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
@@ -366,7 +366,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                                      const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
@@ -418,7 +418,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
 template<typename scalar_t, typename target_t>
 __global__ void
 #if defined (__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
+C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
 #endif
 ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
                                                      const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
index 8649ee6..a510060 100644 (file)
@@ -82,7 +82,7 @@ namespace kernel {
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(512, 4)
+C10_LAUNCH_BOUNDS_2(512, 4)
 #endif
 __global__ void lstm_cell_forward(
             TensorInfo<scalar_t, index_type> input,
@@ -169,7 +169,7 @@ __global__ void lstm_cell_forward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(512, 4)
+C10_LAUNCH_BOUNDS_2(512, 4)
 #endif
 __global__ void lstm_cell_backward(
               TensorInfo<scalar_t, index_type> storage,
@@ -234,7 +234,7 @@ __global__ void lstm_cell_backward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(512, 4)
+C10_LAUNCH_BOUNDS_2(512, 4)
 #endif
 __global__ void gru_cell_forward(
             TensorInfo<scalar_t, index_type> Input,
@@ -304,7 +304,7 @@ __global__ void gru_cell_forward(
 
 template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(512, 4)
+C10_LAUNCH_BOUNDS_2(512, 4)
 #endif
 __global__ void gru_cell_backward(
              TensorInfo<scalar_t, index_type> gradInInput,
index afeed51..891e397 100644 (file)
@@ -172,7 +172,7 @@ struct ReduceConfig {
 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
 
 template<int nt, typename R>
-C10_LAUNCH_BOUNDS(nt, 4)
+C10_LAUNCH_BOUNDS_2(nt, 4)
 __global__ void reduce_kernel(R reduction) {
   reduction.run();
 }
@@ -410,7 +410,7 @@ struct ReduceOp {
 
     return is_last_block_done_shared;
   }
-  
+
   template <bool can_acc>
   C10_DEVICE arg_t accumulate_in_output(
     out_scalar_t* out, arg_t value,
index da57941..4eb25de 100644 (file)
@@ -17,7 +17,7 @@ constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;
 
 template <typename scalar_t, typename IndexType>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
+C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
 #endif
 __global__ void
 kernel_pointwise_flip_apply2(const cuda::detail::TensorInfo<scalar_t, IndexType> in_tensor_info,
index f750b53..856d322 100644 (file)
@@ -141,7 +141,7 @@ template
    typename FinalizeOp,
    int ADims, int BDims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(512, 4)
+C10_LAUNCH_BOUNDS_2(512, 4)
 #endif
 __global__ void kernelReduceNoncontigDim_shared
   (TensorInfo<T, IndexType> out,
@@ -256,7 +256,7 @@ template <typename T,
           typename FinalizeOp,
           int ADims, int BDims>
 #if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(512, 4)
+C10_LAUNCH_BOUNDS_2(512, 4)
 #endif
 __global__ void
 kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
index 65c1b53..1523275 100644 (file)
@@ -27,7 +27,7 @@ template <typename T,
           int ADims>
 __global__ void
 #if defined(__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS(THC_REDUCE_ALL_BLOCK_SIZE)
+C10_LAUNCH_BOUNDS_1(THC_REDUCE_ALL_BLOCK_SIZE)
 #endif
 kernelReduceAll(TensorInfo<T, IndexType> in,
                 IndexType totalElements,
@@ -299,7 +299,7 @@ bool THC_reduceAll(THCState* state,
 
     /*
     Only instantiates the all 1D special case and the fallback all nD case for
-    large (64-bit indexed) tensors to reduce compilation time. 
+    large (64-bit indexed) tensors to reduce compilation time.
     */
     if (inInfo.dims == 1) {
       HANDLE_IN_CASE(uint64_t, 1);
index 8e9bdd2..c60bfe8 100644 (file)
@@ -135,7 +135,7 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
 template <typename K, typename V,
           int KeyDims, int ValueDims,
           typename Comparator, typename IndexType, int Power2SortSize>
-C10_LAUNCH_BOUNDS(1024)
+C10_LAUNCH_BOUNDS_1(1024)
 __global__ void
 bitonicSortKVInPlace(TensorInfo<K, IndexType> keys,
                      IndexType keySlices,
index edf1bdf..2c33054 100644 (file)
@@ -361,7 +361,7 @@ __device__ void radixSelect(DataType* data,
 }
 
 template <typename T, typename IndexType, int Dim, bool Order>
-C10_LAUNCH_BOUNDS(1024)
+C10_LAUNCH_BOUNDS_1(1024)
 __global__ void gatherTopK(TensorInfo<T, IndexType> input,
                            IndexType inputSliceSize,
                            IndexType outputSliceSize, // aka `k`
index 0257ed9..1380b62 100644 (file)
@@ -12,7 +12,7 @@
 
 template <typename Dtype, typename Acctype>
 #if defined(__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
+C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
 #endif
 __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output,
                                                                    Dtype *input,
@@ -82,7 +82,7 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output
 
 template <typename Dtype, typename Acctype>
 #if defined(__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
+C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
 #endif
 __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
                                                                       Dtype *gradOutput,
index f50c5c1..60da0a3 100644 (file)
@@ -69,7 +69,7 @@ __global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel(
 
 template <typename T, typename AccumT>
 #if defined(__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS(1024)
+C10_LAUNCH_BOUNDS_1(1024)
 #endif
 __global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
           T *output,
index e1fbd33..c2211fe 100644 (file)
@@ -9,7 +9,7 @@
 template <typename Dtype, typename Acctype>
 __global__ void
 #if __CUDA_ARCH__ >= 320 || defined __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
 #endif
 LRNFillScale(const int nthreads, const Dtype* const in,
     const int num, const int channels, const int height,
index e1202e0..dc4f016 100644 (file)
@@ -49,9 +49,9 @@ const int BACKWARD_THREADS = 256;
 
 template <typename Dtype, typename AccType>
 #if defined (__HIP_PLATFORM_HCC__)
-C10_LAUNCH_BOUNDS(BACKWARD_THREADS, 4)
+C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 4)
 #else
-C10_LAUNCH_BOUNDS(BACKWARD_THREADS, 8)
+C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 8)
 #endif
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
     const int64_t* top_mask, const int num, const int channels,
index 3c0b0ca..f2163a3 100644 (file)
@@ -9,7 +9,7 @@
 // Borrowed from Theano
 // Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter, Nicolas Ballas
 template <typename Dtype>
-__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+__global__ void C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
 im3d2col_kernel(const int64_t n, const Dtype* data_im,
                 const int64_t height, const int64_t width, const int64_t depth,
                 const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,
@@ -88,7 +88,7 @@ void im3d2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
 }
 
 template <typename Dtype, typename Acctype>
-__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
+__global__ void C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
 col2im3d_kernel(const int64_t n, const Dtype* data_col,
                 const int64_t height, const int64_t width, const int64_t depth,
                 const int64_t channels,
index 7e1ab11..a0cd1d3 100644 (file)
@@ -13,7 +13,7 @@
 #include <c10/macros/Macros.h>
 
 template<typename Dtype, typename Acctype>
-C10_LAUNCH_BOUNDS(1024)
+C10_LAUNCH_BOUNDS_1(1024)
 __global__ void caffe_gpu_interp2_kernel(const int n,
     const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
     const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {
@@ -81,7 +81,7 @@ __global__ void caffe_gpu_interp2_kernel(const int n,
 
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template <typename Dtype, typename Acctype>
-C10_LAUNCH_BOUNDS(1024)
+C10_LAUNCH_BOUNDS_1(1024)
 __global__ void caffe_gpu_interp2_kernel_backward(const int n,
     const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
     THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){
index 22a4c7a..ce763f4 100644 (file)
@@ -8,7 +8,7 @@
 // Kernel for fast unfold+copy
 // (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
 template <typename Dtype>
-C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
 __global__ void im2col_kernel(const int64_t n, const Dtype* data_im,
                               const int64_t height, const int64_t width,
                               const int64_t ksize_h, const int64_t ksize_w,
@@ -60,7 +60,7 @@ void im2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
 }
 
 template <typename Dtype, typename Acctype>
-C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
+C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
 __global__ void col2im_kernel(const int64_t n, const Dtype* data_col,
                                   const int64_t height, const int64_t width, const int64_t channels,
                                   const int64_t kernel_h, const int64_t kernel_w,
index 884abe6..93aa14b 100644 (file)
@@ -143,12 +143,9 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
 #define C10_MAX_THREADS_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)
 #define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) ((((threads_per_block)*(blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) ? (blocks_per_sm) : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block))))
 // C10_LAUNCH_BOUNDS is analogous to __launch_bounds__
-// https://stackoverflow.com/a/8814003 snippet to have macro with an optional argument
 #define C10_LAUNCH_BOUNDS_0 __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and versatility across all architectures.
 #define C10_LAUNCH_BOUNDS_1(max_threads_per_block) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
 #define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
-#define C10_LAUNCH_BOUNDS_X(x,max_threads_per_block,min_blocks_per_sm,FUNC, ...) FUNC
-#define C10_LAUNCH_BOUNDS(...) C10_LAUNCH_BOUNDS_X(,##__VA_ARGS__, C10_LAUNCH_BOUNDS_2(__VA_ARGS__), C10_LAUNCH_BOUNDS_1(__VA_ARGS__), C10_LAUNCH_BOUNDS_0(__VA_ARGS__))
 #else
 #define C10_HOST_DEVICE
 #define C10_HOST
index 5ccf479..0e1160a 100644 (file)
@@ -94,7 +94,7 @@ class CuDNNPoolOp final : public ConvPoolOpBase<CUDAContext> {
   }
 
   bool RunOnDevice() override {
-    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
   }
 
   template <typename T>
@@ -235,7 +235,7 @@ class CuDNNPoolGradientOp final : public ConvPoolOpBase<CUDAContext> {
   }
 
   bool RunOnDevice() override {
-    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
   }
 
   template <typename T>
@@ -359,13 +359,8 @@ struct CuDNNAveragePoolFunctor {
       const T* X,
       T* Y,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for average_pooling.");
-      return false;
-    } else {
       return avg_pool_functor.GlobalPoolingForward<T, kOrder>(
           N, C, HxW, X, Y, context);
-    }
   }
 
   template <typename T, StorageOrder kOrder>
@@ -381,13 +376,8 @@ struct CuDNNAveragePoolFunctor {
       const T* X,
       T* Y,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for average_pooling.");
-      return false;
-    } else {
       return avg_pool_functor.Forward<T, kOrder>(
           N, C, X_dims, Y_dims, kernel, dilation, stride, pads, X, Y, context);
-    }
   }
 
   template <typename T, StorageOrder kOrder>
@@ -400,13 +390,8 @@ struct CuDNNAveragePoolFunctor {
       const T* Y,
       T* dX,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for average_pooling.");
-      return false;
-    } else {
       return avg_pool_functor.GlobalPoolingBackward<T, kOrder>(
           N, C, HxW, dY, X, Y, dX, context);
-    }
   }
 
   template <typename T, StorageOrder kOrder>
@@ -424,10 +409,6 @@ struct CuDNNAveragePoolFunctor {
       const T* Y,
       T* dX,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for average_pooling.");
-      return false;
-    } else {
       return avg_pool_functor.Backward<T, kOrder>(
           N,
           C,
@@ -442,7 +423,6 @@ struct CuDNNAveragePoolFunctor {
           Y,
           dX,
           context);
-    }
   }
 
   const AveragePoolFunctor<CUDAContext> avg_pool_functor;
@@ -469,13 +449,8 @@ struct CuDNNMaxPoolFunctor {
       const T* X,
       T* Y,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for max_pooling.");
-      return false;
-    } else {
       return max_pool_functor.GlobalPoolingForward<T, kOrder>(
           N, C, HxW, X, Y, context);
-    }
   }
 
   template <typename T, StorageOrder kOrder>
@@ -491,13 +466,8 @@ struct CuDNNMaxPoolFunctor {
       const T* X,
       T* Y,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for max_pooling.");
-      return false;
-    } else {
       return max_pool_functor.Forward<T, kOrder>(
           N, C, X_dims, Y_dims, kernel, dilation, stride, pads, X, Y, context);
-    }
   }
 
   template <typename T, StorageOrder kOrder>
@@ -510,13 +480,8 @@ struct CuDNNMaxPoolFunctor {
       const T* Y,
       T* dX,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for max_pooling.");
-      return false;
-    } else {
       return max_pool_functor.GlobalPoolingBackward<T, kOrder>(
           N, C, HxW, dY, X, Y, dX, context);
-    }
   }
 
   template <typename T, StorageOrder kOrder>
@@ -534,10 +499,6 @@ struct CuDNNMaxPoolFunctor {
       const T* Y,
       T* dX,
       CUDAContext* context) const {
-    if (std::is_same<T, at::Half>::value) {
-      CAFFE_THROW("Float16 is not supported for max_pooling.");
-      return false;
-    } else {
       return max_pool_functor.Backward<T, kOrder>(
           N,
           C,
@@ -552,7 +513,6 @@ struct CuDNNMaxPoolFunctor {
           Y,
           dX,
           context);
-    }
   }
 
   const MaxPoolFunctor<CUDAContext> max_pool_functor;