Minor change of group_norm_gradient on GPU (#16307)
authorXiaomeng Yang <yangxm@fb.com>
Fri, 25 Jan 2019 08:54:09 +0000 (00:54 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 25 Jan 2019 09:25:29 +0000 (01:25 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16307

Minor change of group_norm_gradient on GPU

Reviewed By: houseroad

Differential Revision: D13800613

fbshipit-source-id: 9e55f93b1e322efe3fc2d684b9c47c3dbb7a0f48

caffe2/operators/group_norm_op.cu

index c1f86e5..62436d2 100644 (file)
@@ -110,7 +110,7 @@ __global__ void GroupNormForwardNHWCCUDAKernel<float>(
   }
 }
 
-template <typename T>
+template <typename T, int kBlockDimX, int kBlockDimY>
 __global__ void ComputeInternalGradientsNCHWCUDAKernel(
     const int G,
     const int K,
@@ -120,28 +120,31 @@ __global__ void ComputeInternalGradientsNCHWCUDAKernel(
     const T* gamma,
     T* ds,
     T* db) {
-  __shared__ typename BlockReduce<T>::TempStorage ds_storage;
-  __shared__ typename BlockReduce<T>::TempStorage db_storage;
-  const int inner_size = K * HxW;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage ds_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
   const int n = blockIdx.x;
   const int g = blockIdx.y;
   const int ng = n * G + g;
   T ds_val = 0;
   T db_val = 0;
-  for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
-    const int c = g * K + i / HxW;
-    const int index = ng * inner_size + i;
+  for (int i = threadIdx.x; i < K; i += blockDim.x) {
+    const int c = g * K + i;
+    for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
+      const int index = (ng * K + i) * HxW + j;
 #if __CUDA_ARCH__ >= 350
-    ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
-    db_val += __ldg(gamma + c) * __ldg(dY + index);
+      ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
+      db_val += __ldg(gamma + c) * __ldg(dY + index);
 #else
-    ds_val += gamma[c] * dY[index] * X[index];
-    db_val += gamma[c] * dY[index];
+      ds_val += gamma[c] * dY[index] * X[index];
+      db_val += gamma[c] * dY[index];
 #endif
+    }
   }
-  ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
-  db_val = BlockReduce<T>(db_storage).Sum(db_val);
-  if (threadIdx.x == 0) {
+  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(ds_storage).Sum(ds_val);
+  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
     ds[ng] = ds_val;
     db[ng] = db_val;
   }
@@ -158,9 +161,9 @@ __global__ void ComputeInternalGradientsNHWCCUDAKernel(
     T* ds,
     T* db) {
   __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage ds_storage;
   __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage v_storage;
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
   const int C = G * K;
   const int n = blockIdx.x;
   const int g = blockIdx.y;
@@ -180,8 +183,8 @@ __global__ void ComputeInternalGradientsNHWCCUDAKernel(
 #endif
     }
   }
-  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(ds_val);
-  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(db_val);
+  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(ds_storage).Sum(ds_val);
+  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
   if (threadIdx.x == 0 && threadIdx.y == 0) {
     ds[ng] = ds_val;
     db[ng] = db_val;
@@ -437,9 +440,23 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
     // Computes dL/ds and dL/db.
     // dL/ds = Sum(dL/dY * gamma * X)
     // dL/db = Sum(dL/dY * gamma)
-    ComputeInternalGradientsNCHWCUDAKernel<float>
-        <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-            G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    if (HxW >= 128) {
+      ComputeInternalGradientsNCHWCUDAKernel<float, 1, 128>
+          <<<dim3(N, G), dim3(1, 128), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    } else if (HxW >= 64) {
+      ComputeInternalGradientsNCHWCUDAKernel<float, 2, 64>
+          <<<dim3(N, G), dim3(2, 64), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    } else if (HxW >= 32) {
+      ComputeInternalGradientsNCHWCUDAKernel<float, 4, 32>
+          <<<dim3(N, G), dim3(4, 32), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    } else {
+      ComputeInternalGradientsNCHWCUDAKernel<float, 8, 16>
+          <<<dim3(N, G), dim3(8, 16), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    }
 
     // Computes dL/dX.
     const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);