ifdef guard some explicit pragma unrolls (#19018)
authorJohannes M Dieterich <johannes.dieterich@amd.com>
Mon, 8 Apr 2019 16:44:08 +0000 (09:44 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 8 Apr 2019 16:47:23 +0000 (09:47 -0700)
Summary:
the ROCm compiler cannot and will not satisfy them, causing compile time warnings.

Reason being a runtime loop trip count.

Some warnings remain arising from other parts of the ROCm stack - tickets are filed and they will be resolved within these components.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19018

Differential Revision: D14832859

Pulled By: ezyang

fbshipit-source-id: 0d66e4aebe4e56af14dd5e2967d3c374a82be25c

aten/src/THC/THCSortUtils.cuh
aten/src/THCUNN/SpatialDepthwiseConvolution.cu

index 4980ee1..7d39b3b 100644 (file)
@@ -66,7 +66,9 @@ __device__ inline void bitonicSort(K keys[Power2SortSize],
   for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
     bool flag = ((threadIdx.x & (size / 2)) != 0);
 
+#ifndef __HIP_PLATFORM_HCC__
 #pragma unroll
+#endif
     for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
 
       __syncthreads();
index baf7610..2ee0417 100644 (file)
@@ -76,7 +76,9 @@ __global__ void spatialDepthwiseConvolutionUpdateOutput(
 
     AccT value = biasEnabled ? ScalarConvert<T, AccT>::to(bias.data()[c]) : ScalarConvert<int, AccT>::to(0);
     const IndexType offset0 = (n * inputChannels + inputChannel) * inputHeight * inputWidth;
+#ifndef __HIP_PLATFORM_HCC__
 #pragma unroll
+#endif
     for (int kH = 0; kH < KH_LIMIT; ++kH) {
 #ifndef __HIP_PLATFORM_HCC__
 #pragma unroll
@@ -136,7 +138,9 @@ __global__ void spatialDepthwiseConvolutionUpdateGradInput(
 
     AccT value = ScalarConvert<int, AccT>::to(0);
 
+#ifndef __HIP_PLATFORM_HCC__
 #pragma unroll
+#endif
     for (int multiplier = 0; multiplier < depthwiseMultiplier; ++multiplier) {
       int och = (c * depthwiseMultiplier) + multiplier;
       int weightOffset = och * kernelHeight * kernelWidth;