optimize group_norm (#16216)
authorXiaomeng Yang <yangxm@fb.com>
Thu, 24 Jan 2019 07:55:06 +0000 (23:55 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 24 Jan 2019 07:57:45 +0000 (23:57 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16216

Optimize GroupNormOp

Reviewed By: houseroad

Differential Revision: D13754145

fbshipit-source-id: 650f64c81486c6c9d276f2e3325392d5838751ba

caffe2/operators/group_norm_op.cu

index dc8f9d8..c1f86e5 100644 (file)
@@ -11,7 +11,7 @@
 #include <cub/block/block_reduce.cuh>
 
 #include "caffe2/core/context_gpu.h"
-#include "caffe2/utils/math_utils.h"
+#include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
@@ -20,47 +20,41 @@ namespace {
 template <typename T>
 using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
 
+template <typename T, int kBlockDimX, int kBlockDimY>
+using BlockReduce2D = cub::
+    BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
+
 template <typename T>
 __global__ void ComputeFusedParamsCUDAKernel(
-    const int N,
     const int G,
-    const int D,
+    const int K,
     const T* mu,
     const T* rsig,
     const T* gamma,
     const T* beta,
     T* scale,
     T* bias) {
-  const int outer_size = N * G;
-  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-    const int g = i % G;
+  const int n = blockIdx.x;
+  const int g = blockIdx.y;
+  const int i_mu = n * G + g;
+  for (int i = threadIdx.x; i < K; i += blockDim.x) {
+    const int index = i_mu * K + i;
+    const int i_gamma = g * K + i;
 #if __CUDA_ARCH__ >= 350
-    const T mu_val = __ldg(mu + i);
-    const T rsig_val = __ldg(rsig + i);
+    const T scale_val = __ldg(gamma + i_gamma) * __ldg(rsig + i_mu);
+    scale[index] = scale_val;
+    bias[index] = __ldg(beta + i_gamma) - scale_val * __ldg(mu + i_mu);
 #else
-    const T mu_val = mu[i];
-    const T rsig_val = rsig[i];
+    const T scale_val = gamma[i_gamma] * rsig[i_mu];
+    scale[index] = scale_val;
+    bias[index] = beta[i_gamma] - scale_val * mu[i_mu];
 #endif
-    for (int j = threadIdx.x; j < D; j += blockDim.x) {
-      const int index = i * D + j;
-      const int i_gamma = g * D + j;
-#if __CUDA_ARCH__ >= 350
-      const T scale_val = __ldg(gamma + i_gamma) * rsig_val;
-      scale[index] = scale_val;
-      bias[index] = __ldg(beta + i_gamma) - scale_val * mu_val;
-#else
-      const T scale_val = gamma[i_gamma] * rsig_val;
-      scale[index] = scale_val;
-      bias[index] = beta[i_gamma] - scale_val * mu_val;
-#endif
-    }
   }
 }
 
-template <typename T, StorageOrder kOrder>
-__global__ void GroupNormForwardCUDAKernel(
-    const int N,
-    const int C,
+template <typename T>
+__global__ void GroupNormForwardNCHWCUDAKernel(
+    const int K,
     const int HxW,
     const T* X,
     const T* scale,
@@ -68,97 +62,129 @@ __global__ void GroupNormForwardCUDAKernel(
     T* Y);
 
 template <>
-__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>(
-    const int N,
-    const int C,
+__global__ void GroupNormForwardNCHWCUDAKernel<float>(
+    const int W,
     const int HxW,
     const float* X,
     const float* scale,
     const float* bias,
     float* Y) {
-  const int outer_size = N * C;
-  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-#if __CUDA_ARCH__ >= 350
-    const float scale_val = __ldg(scale + i);
-    const float bias_val = __ldg(bias + i);
-#else
-    const float scale_val = scale[i];
-    const float bias_val = bias[i];
-#endif
-    for (int j = threadIdx.x; j < HxW; j += blockDim.x) {
-      const int index = i * HxW + j;
+  const int nc = blockIdx.x / W;
+  const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (hw < HxW) {
+    const int index = nc * HxW + hw;
 #if __CUDA_ARCH__ >= 350
-      Y[index] = __ldg(X + index) * scale_val + bias_val;
+    Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
 #else
-      Y[index] = X[index] * scale_val + bias_val;
+    Y[index] = fmaf(X[index], scale[nc], bias[nc]);
 #endif
-    }
   }
 }
 
+template <typename T>
+__global__ void GroupNormForwardNHWCCUDAKernel(
+    const int C,
+    const int HxW,
+    const T* X,
+    const T* scale,
+    const T* bias,
+    T* Y);
+
 template <>
-__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>(
-    const int N,
+__global__ void GroupNormForwardNHWCCUDAKernel<float>(
     const int C,
     const int HxW,
     const float* X,
     const float* scale,
     const float* bias,
     float* Y) {
-  const int outer_size = N * HxW;
-  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-    const int n = i / HxW;
-    for (int j = threadIdx.x; j < C; j += blockDim.x) {
-      const int index = i * C + j;
-      const int i_scale = n * C + j;
+  const int n = blockIdx.x / HxW;
+  for (int c = threadIdx.x; c < C; c += blockDim.x) {
+    const int index = blockIdx.x * C + c;
+    const int nc = n * C + c;
 #if __CUDA_ARCH__ >= 350
-      Y[index] =
-          __ldg(X + index) * __ldg(scale + i_scale) + __ldg(bias + i_scale);
+    Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
 #else
-      Y[index] = X[index] * scale[i_scale] + bias[i_scale];
+    Y[index] = fmaf(X[index], scale[nc], bias[nc]);
 #endif
-    }
   }
 }
 
-template <typename T, StorageOrder kOrder>
-__global__ void ComputeInternalGradientsCUDAKernel(
-    const int N,
+template <typename T>
+__global__ void ComputeInternalGradientsNCHWCUDAKernel(
     const int G,
-    const int D,
+    const int K,
     const int HxW,
     const T* dY,
     const T* X,
     const T* gamma,
     T* ds,
     T* db) {
-  const int outer_size = N * G;
-  const int inner_size = D * HxW;
   __shared__ typename BlockReduce<T>::TempStorage ds_storage;
   __shared__ typename BlockReduce<T>::TempStorage db_storage;
-  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-    T ds_val = 0;
-    T db_val = 0;
-    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
-      const int i_gamma = i % G * D + j / HxW;
-      const int index = kOrder == StorageOrder::NCHW
-          ? i * inner_size + j
-          : (i / G * HxW + j % HxW) * G * D + i_gamma;
+  const int inner_size = K * HxW;
+  const int n = blockIdx.x;
+  const int g = blockIdx.y;
+  const int ng = n * G + g;
+  T ds_val = 0;
+  T db_val = 0;
+  for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
+    const int c = g * K + i / HxW;
+    const int index = ng * inner_size + i;
 #if __CUDA_ARCH__ >= 350
-      ds_val += __ldg(gamma + i_gamma) * __ldg(dY + index) * __ldg(X + index);
-      db_val += __ldg(gamma + i_gamma) * __ldg(dY + index);
+    ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
+    db_val += __ldg(gamma + c) * __ldg(dY + index);
 #else
-      ds_val += gamma[i_gamma] * dY[index] * X[index];
-      db_val += gamma[i_gamma] * dY[index];
+    ds_val += gamma[c] * dY[index] * X[index];
+    db_val += gamma[c] * dY[index];
+#endif
+  }
+  ds_val = BlockReduce<T>(ds_storage).Sum(ds_val);
+  db_val = BlockReduce<T>(db_storage).Sum(db_val);
+  if (threadIdx.x == 0) {
+    ds[ng] = ds_val;
+    db[ng] = db_val;
+  }
+}
+
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void ComputeInternalGradientsNHWCCUDAKernel(
+    const int G,
+    const int K,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* gamma,
+    T* ds,
+    T* db) {
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage v_storage;
+  const int C = G * K;
+  const int n = blockIdx.x;
+  const int g = blockIdx.y;
+  const int ng = n * G + g;
+  T ds_val = 0;
+  T db_val = 0;
+  for (int i = threadIdx.x; i < HxW; i += blockDim.x) {
+    for (int j = threadIdx.y; j < K; j += blockDim.y) {
+      const int c = g * K + j;
+      const int index = (n * HxW + i) * C + c;
+#if __CUDA_ARCH__ >= 350
+      ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
+      db_val += __ldg(gamma + c) * __ldg(dY + index);
+#else
+      ds_val += gamma[c] * dY[index] * X[index];
+      db_val += gamma[c] * dY[index];
 #endif
     }
-    ds_val = BlockReduce<T>(ds_storage).Reduce(ds_val, cub::Sum());
-    db_val = BlockReduce<T>(db_storage).Reduce(db_val, cub::Sum());
-    if (threadIdx.x == 0) {
-      ds[i] = ds_val;
-      db[i] = db_val;
-    }
-    __syncthreads();
+  }
+  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(ds_val);
+  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(db_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    ds[ng] = ds_val;
+    db[ng] = db_val;
   }
 }
 
@@ -173,11 +199,50 @@ __global__ void ComputeInternalGradientsCUDAKernel(
 // db/dX = -u * drsig/dX - rsig * dmu/dX
 // drsig/dX = -rsig^3 * (X - mu) / n
 // dmu/dX = 1 / n
-template <typename T, StorageOrder kOrder>
-__global__ void GroupNormBackwardCUDAKernel(
-    const int size,
+template <typename T>
+__global__ void GroupNormBackwardNCHWCUDAKernel(
+    const int G,
+    const int K,
+    const int W,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* mu,
+    const T* rsig,
+    const T* gamma,
+    const T* ds,
+    const T* db,
+    T* dX) {
+  const int C = G * K;
+  const T denom = T(1) / static_cast<T>(K * HxW);
+  const int nc = blockIdx.x / W;
+  const int n = nc / C;
+  const int c = nc % C;
+  const int g = c / K;
+  const int ng = n * G + g;
+  const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  const int index = nc * HxW + hw;
+  if (hw < HxW) {
+#if __CUDA_ARCH__ >= 350
+    const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) *
+        (__ldg(X + index) - __ldg(mu + ng)) *
+        math::utils::Cube<T>(__ldg(rsig + ng));
+    const T v = __ldg(db + ng) * __ldg(rsig + ng);
+    dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) +
+        (u - v) * denom;
+#else
+    const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) *
+        math::utils::Cube<T>(rsig[ng]);
+    const T v = db[ng] * rsig[ng];
+    dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom;
+#endif
+  }
+}
+
+template <typename T>
+__global__ void GroupNormBackwardNHWCCUDAKernel(
     const int G,
-    const int D,
+    const int K,
     const int HxW,
     const T* dY,
     const T* X,
@@ -187,34 +252,36 @@ __global__ void GroupNormBackwardCUDAKernel(
     const T* ds,
     const T* db,
     T* dX) {
-  const int C = G * D;
-  const T denom = T(1) / static_cast<T>(D * HxW);
-  CUDA_1D_KERNEL_LOOP(i, size) {
-    const int i_mu = kOrder == StorageOrder::NCHW
-        ? i / (D * HxW)
-        : i / (C * HxW) * G + (i / D % G);
-    const int i_gamma = kOrder == StorageOrder::NCHW ? (i / HxW) % C : i % C;
+  const int C = G * K;
+  const T denom = T(1) / static_cast<T>(K * HxW);
+  const int x = blockIdx.x;
+  const int g = blockIdx.y;
+  const int n = x / HxW;
+  const int ng = n * G + g;
+  for (int i = threadIdx.x; i < K; i += blockDim.x) {
+    const int c = g * K + i;
+    const int index = x * C + c;
 #if __CUDA_ARCH__ >= 350
-    const T u = (__ldg(db + i_mu) * __ldg(mu + i_mu) - __ldg(ds + i_mu)) *
-        (__ldg(X + i) - __ldg(mu + i_mu)) *
-        math::utils::Cube<T>(__ldg(rsig + i_mu));
-    const T v = __ldg(db + i_mu) * __ldg(rsig + i_mu);
-    dX[i] = __ldg(gamma + i_gamma) * __ldg(dY + i) * __ldg(rsig + i_mu) +
+    const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) *
+        (__ldg(X + index) - __ldg(mu + ng)) *
+        math::utils::Cube<T>(__ldg(rsig + ng));
+    const T v = __ldg(db + ng) * __ldg(rsig + ng);
+    dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) +
         (u - v) * denom;
 #else
-    const T u = (db[i_mu] * mu[i_mu] - ds[i_mu]) * (X[i] - mu[i_mu]) *
-        math::utils::Cube<T>(rsig[i_mu]);
-    const T v = db[i_mu] * rsig[i_mu];
-    dX[i] = gamma[i_gamma] * dY[i] * rsig[i_mu] + (u - v) * denom;
+    const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) *
+        math::utils::Cube<T>(rsig[ng]);
+    const T v = db[ng] * rsig[ng];
+    dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom;
 #endif
   }
 }
 
-template <typename T, StorageOrder kOrder>
-__global__ void GammaBetaBackwardCUDAKernel(
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void GammaBetaBackwardNCHWCUDAKernel(
     const int N,
     const int G,
-    const int D,
+    const int K,
     const int HxW,
     const T* dY,
     const T* X,
@@ -222,35 +289,77 @@ __global__ void GammaBetaBackwardCUDAKernel(
     const T* rsig,
     T* dgamma,
     T* dbeta) {
-  const int outer_size = G * D;
-  const int inner_size = N * HxW;
-  __shared__ typename BlockReduce<T>::TempStorage dg_storage;
-  __shared__ typename BlockReduce<T>::TempStorage db_storage;
-  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
-    T dg_val = 0;
-    T db_val = 0;
-    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
-      const int n = j / HxW;
-      const int index = kOrder == StorageOrder::NCHW
-          ? (n * outer_size + i) * HxW + j % HxW
-          : j * outer_size + i;
-      const int i_mu = n * G + i / D;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage dg_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
+  const int C = G * K;
+  const int c = blockIdx.x;
+  const int g = c / K;
+  T dg_val = 0;
+  T db_val = 0;
+  for (int i = threadIdx.x; i < N; i += blockDim.x) {
+    for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
+      const int index = (i * C + c) * HxW + j;
+      const int ng = i * G + g;
 #if __CUDA_ARCH__ >= 350
-      dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + i_mu)) *
-          __ldg(rsig + i_mu);
+      dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) *
+          __ldg(rsig + ng);
       db_val += __ldg(dY + index);
 #else
-      dg_val += dY[index] * (X[index] - mu[i_mu]) * rsig[i_mu];
+      dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng];
       db_val += dY[index];
 #endif
     }
-    dg_val = BlockReduce<T>(dg_storage).Reduce(dg_val, cub::Sum());
-    db_val = BlockReduce<T>(db_storage).Reduce(db_val, cub::Sum());
-    if (threadIdx.x == 0) {
-      dgamma[i] = dg_val;
-      dbeta[i] = db_val;
+  }
+  dg_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(dg_storage).Sum(dg_val);
+  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    dgamma[c] = dg_val;
+    dbeta[c] = db_val;
+  }
+}
+
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void GammaBetaBackwardNHWCCUDAKernel(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const T* dY,
+    const T* X,
+    const T* mu,
+    const T* rsig,
+    T* dgamma,
+    T* dbeta) {
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage dg_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
+  const int C = G * K;
+  const int c = blockIdx.x;
+  const int g = c / K;
+  T dg_val = 0;
+  T db_val = 0;
+  for (int i = threadIdx.x; i < N; i += blockDim.x) {
+    for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
+      const int index = (i * HxW + j) * C + c;
+      const int ng = i * G + g;
+#if __CUDA_ARCH__ >= 350
+      dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) *
+          __ldg(rsig + ng);
+      db_val += __ldg(dY + index);
+#else
+      dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng];
+      db_val += dY[index];
+#endif
     }
-    __syncthreads();
+  }
+  dg_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(dg_storage).Sum(dg_val);
+  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    dgamma[c] = dg_val;
+    dbeta[c] = db_val;
   }
 }
 
@@ -260,7 +369,7 @@ template <>
 void GroupNormOp<float, CUDAContext>::ComputeFusedParams(
     const int N,
     const int G,
-    const int D,
+    const int K,
     const float* mu,
     const float* rsig,
     const float* gamma,
@@ -268,10 +377,8 @@ void GroupNormOp<float, CUDAContext>::ComputeFusedParams(
     float* scale,
     float* bias) {
   ComputeFusedParamsCUDAKernel<float>
-      <<<std::min(N * G, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(N, G, D, mu, rsig, gamma, beta, scale, bias);
+      <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          G, K, mu, rsig, gamma, beta, scale, bias);
 }
 
 template <>
@@ -283,11 +390,10 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNCHW(
     const float* scale,
     const float* bias,
     float* Y) {
-  GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>
-      <<<std::min(N * C, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(N, C, HxW, X, scale, bias, Y);
+  const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+  GroupNormForwardNCHWCUDAKernel<float>
+      <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          W, HxW, X, scale, bias, Y);
 }
 
 template <>
@@ -299,11 +405,9 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNHWC(
     const float* scale,
     const float* bias,
     float* Y) {
-  GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>
-      <<<std::min(N * HxW, CAFFE_MAXIMUM_NUM_BLOCKS),
-         CAFFE_CUDA_NUM_THREADS,
-         0,
-         context_.cuda_stream()>>>(N, C, HxW, X, scale, bias, Y);
+  GroupNormForwardNHWCCUDAKernel<float>
+      <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          C, HxW, X, scale, bias, Y);
 }
 
 // Math:
@@ -314,7 +418,7 @@ template <>
 bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
     const int N,
     const int G,
-    const int D,
+    const int K,
     const int HxW,
     const float* dY_data,
     const float* X_data,
@@ -324,34 +428,26 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
     float* dX_data,
     float* dgamma_data,
     float* dbeta_data) {
-  const int size = N * G * D * HxW;
-  const int C = G * D;
-  ReinitializeTensor(
-      &ds_, {N, G}, at::dtype<float>().device(CUDA));
-  ReinitializeTensor(
-      &db_, {N, G}, at::dtype<float>().device(CUDA));
+  const int C = G * K;
+  ReinitializeTensor(&ds_, {N, G}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&db_, {N, G}, at::dtype<float>().device(CUDA));
   float* ds_data = ds_.mutable_data<float>();
   float* db_data = db_.mutable_data<float>();
   if (order_ == StorageOrder::NCHW) {
     // Computes dL/ds and dL/db.
     // dL/ds = Sum(dL/dY * gamma * X)
     // dL/db = Sum(dL/dY * gamma)
-    ComputeInternalGradientsCUDAKernel<float, StorageOrder::NCHW>
-        <<<std::min(N * G, CAFFE_MAXIMUM_NUM_BLOCKS),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context_.cuda_stream()>>>(
-            N, G, D, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    ComputeInternalGradientsNCHWCUDAKernel<float>
+        <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+            G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
 
     // Computes dL/dX.
-    GroupNormBackwardCUDAKernel<float, StorageOrder::NCHW>
-        <<<CAFFE_GET_BLOCKS(size),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context_.cuda_stream()>>>(
-            size,
+    const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+    GroupNormBackwardNCHWCUDAKernel<float>
+        <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
             G,
-            D,
+            K,
+            W,
             HxW,
             dY_data,
             X_data,
@@ -363,41 +459,89 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
             dX_data);
 
     // Computes dL/dgamma and dL/dbeta.
-    GammaBetaBackwardCUDAKernel<float, StorageOrder::NCHW>
-        <<<std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context_.cuda_stream()>>>(
-            N,
-            G,
-            D,
-            HxW,
-            dY_data,
-            X_data,
-            mu_data,
-            rsig_data,
-            dgamma_data,
-            dbeta_data);
+    if (HxW >= 128) {
+      GammaBetaBackwardNCHWCUDAKernel<float, 1, 128>
+          <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    } else if (HxW >= 64) {
+      GammaBetaBackwardNCHWCUDAKernel<float, 2, 64>
+          <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    } else if (HxW >= 32) {
+      GammaBetaBackwardNCHWCUDAKernel<float, 4, 32>
+          <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    } else {
+      GammaBetaBackwardNCHWCUDAKernel<float, 8, 16>
+          <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    }
   } else {
     // Computes dL/ds and dL/db.
     // dL/ds = Sum(dL/dY * gamma * X)
     // dL/db = Sum(dL/dY * gamma)
-    ComputeInternalGradientsCUDAKernel<float, StorageOrder::NHWC>
-        <<<std::min(N * G, CAFFE_MAXIMUM_NUM_BLOCKS),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context_.cuda_stream()>>>(
-            N, G, D, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    if (K >= 128) {
+      ComputeInternalGradientsNHWCCUDAKernel<float, 1, 128>
+          <<<dim3(N, G), dim3(1, 128), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    } else if (K >= 64) {
+      ComputeInternalGradientsNHWCCUDAKernel<float, 2, 64>
+          <<<dim3(N, G), dim3(2, 64), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    } else if (K >= 32) {
+      ComputeInternalGradientsNHWCCUDAKernel<float, 4, 32>
+          <<<dim3(N, G), dim3(4, 32), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    } else {
+      ComputeInternalGradientsNHWCCUDAKernel<float, 8, 16>
+          <<<dim3(N, G), dim3(8, 16), 0, context_.cuda_stream()>>>(
+              G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
+    }
 
     // Computes dL/dX.
-    GroupNormBackwardCUDAKernel<float, StorageOrder::NHWC>
-        <<<CAFFE_GET_BLOCKS(size),
+    GroupNormBackwardNHWCCUDAKernel<float>
+        <<<dim3(N * HxW, G),
            CAFFE_CUDA_NUM_THREADS,
            0,
            context_.cuda_stream()>>>(
-            size,
             G,
-            D,
+            K,
             HxW,
             dY_data,
             X_data,
@@ -409,21 +553,59 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
             dX_data);
 
     // Computes dL/dgamma and dL/dbeta.
-    GammaBetaBackwardCUDAKernel<float, StorageOrder::NHWC>
-        <<<std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context_.cuda_stream()>>>(
-            N,
-            G,
-            D,
-            HxW,
-            dY_data,
-            X_data,
-            mu_data,
-            rsig_data,
-            dgamma_data,
-            dbeta_data);
+    if (HxW >= 128) {
+      GammaBetaBackwardNHWCCUDAKernel<float, 1, 128>
+          <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    } else if (HxW >= 64) {
+      GammaBetaBackwardNHWCCUDAKernel<float, 2, 64>
+          <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    } else if (HxW >= 32) {
+      GammaBetaBackwardNHWCCUDAKernel<float, 4, 32>
+          <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    } else {
+      GammaBetaBackwardNHWCCUDAKernel<float, 8, 16>
+          <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
+              N,
+              G,
+              K,
+              HxW,
+              dY_data,
+              X_data,
+              mu_data,
+              rsig_data,
+              dgamma_data,
+              dbeta_data);
+    }
   }
   return true;
 }