From 60241e94b3e71f7637230648dcbbd068b71030b4 Mon Sep 17 00:00:00 2001 From: Xiaomeng Yang Date: Wed, 23 Jan 2019 23:55:06 -0800 Subject: [PATCH] optimize group_norm (#16216) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16216 Optimize GroupNormOp Reviewed By: houseroad Differential Revision: D13754145 fbshipit-source-id: 650f64c81486c6c9d276f2e3325392d5838751ba --- caffe2/operators/group_norm_op.cu | 584 +++++++++++++++++++++++++------------- 1 file changed, 383 insertions(+), 201 deletions(-) diff --git a/caffe2/operators/group_norm_op.cu b/caffe2/operators/group_norm_op.cu index dc8f9d8..c1f86e5 100644 --- a/caffe2/operators/group_norm_op.cu +++ b/caffe2/operators/group_norm_op.cu @@ -11,7 +11,7 @@ #include #include "caffe2/core/context_gpu.h" -#include "caffe2/utils/math_utils.h" +#include "caffe2/utils/math.h" namespace caffe2 { @@ -20,47 +20,41 @@ namespace { template using BlockReduce = cub::BlockReduce; +template +using BlockReduce2D = cub:: + BlockReduce; + template __global__ void ComputeFusedParamsCUDAKernel( - const int N, const int G, - const int D, + const int K, const T* mu, const T* rsig, const T* gamma, const T* beta, T* scale, T* bias) { - const int outer_size = N * G; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - const int g = i % G; + const int n = blockIdx.x; + const int g = blockIdx.y; + const int i_mu = n * G + g; + for (int i = threadIdx.x; i < K; i += blockDim.x) { + const int index = i_mu * K + i; + const int i_gamma = g * K + i; #if __CUDA_ARCH__ >= 350 - const T mu_val = __ldg(mu + i); - const T rsig_val = __ldg(rsig + i); + const T scale_val = __ldg(gamma + i_gamma) * __ldg(rsig + i_mu); + scale[index] = scale_val; + bias[index] = __ldg(beta + i_gamma) - scale_val * __ldg(mu + i_mu); #else - const T mu_val = mu[i]; - const T rsig_val = rsig[i]; + const T scale_val = gamma[i_gamma] * rsig[i_mu]; + scale[index] = scale_val; + bias[index] = beta[i_gamma] - scale_val * mu[i_mu]; #endif - for (int j = threadIdx.x; j < D; j += blockDim.x) { - const int index = i * D + j; - const int i_gamma = g * D + j; -#if __CUDA_ARCH__ >= 350 - const T scale_val = __ldg(gamma + i_gamma) * rsig_val; - scale[index] = scale_val; - bias[index] = __ldg(beta + i_gamma) - scale_val * mu_val; -#else - const T scale_val = gamma[i_gamma] * rsig_val; - scale[index] = scale_val; - bias[index] = beta[i_gamma] - scale_val * mu_val; -#endif - } } } -template -__global__ void GroupNormForwardCUDAKernel( - const int N, - const int C, +template +__global__ void GroupNormForwardNCHWCUDAKernel( + const int K, const int HxW, const T* X, const T* scale, @@ -68,97 +62,129 @@ __global__ void GroupNormForwardCUDAKernel( T* Y); template <> -__global__ void GroupNormForwardCUDAKernel( - const int N, - const int C, +__global__ void GroupNormForwardNCHWCUDAKernel( + const int W, const int HxW, const float* X, const float* scale, const float* bias, float* Y) { - const int outer_size = N * C; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { -#if __CUDA_ARCH__ >= 350 - const float scale_val = __ldg(scale + i); - const float bias_val = __ldg(bias + i); -#else - const float scale_val = scale[i]; - const float bias_val = bias[i]; -#endif - for (int j = threadIdx.x; j < HxW; j += blockDim.x) { - const int index = i * HxW + j; + const int nc = blockIdx.x / W; + const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (hw < HxW) { + const int index = nc * HxW + hw; #if __CUDA_ARCH__ >= 350 - Y[index] = __ldg(X + index) * scale_val + bias_val; + 
Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc)); #else - Y[index] = X[index] * scale_val + bias_val; + Y[index] = fmaf(X[index], scale[nc], bias[nc]); #endif - } } } +template +__global__ void GroupNormForwardNHWCCUDAKernel( + const int C, + const int HxW, + const T* X, + const T* scale, + const T* bias, + T* Y); + template <> -__global__ void GroupNormForwardCUDAKernel( - const int N, +__global__ void GroupNormForwardNHWCCUDAKernel( const int C, const int HxW, const float* X, const float* scale, const float* bias, float* Y) { - const int outer_size = N * HxW; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - const int n = i / HxW; - for (int j = threadIdx.x; j < C; j += blockDim.x) { - const int index = i * C + j; - const int i_scale = n * C + j; + const int n = blockIdx.x / HxW; + for (int c = threadIdx.x; c < C; c += blockDim.x) { + const int index = blockIdx.x * C + c; + const int nc = n * C + c; #if __CUDA_ARCH__ >= 350 - Y[index] = - __ldg(X + index) * __ldg(scale + i_scale) + __ldg(bias + i_scale); + Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc)); #else - Y[index] = X[index] * scale[i_scale] + bias[i_scale]; + Y[index] = fmaf(X[index], scale[nc], bias[nc]); #endif - } } } -template -__global__ void ComputeInternalGradientsCUDAKernel( - const int N, +template +__global__ void ComputeInternalGradientsNCHWCUDAKernel( const int G, - const int D, + const int K, const int HxW, const T* dY, const T* X, const T* gamma, T* ds, T* db) { - const int outer_size = N * G; - const int inner_size = D * HxW; __shared__ typename BlockReduce::TempStorage ds_storage; __shared__ typename BlockReduce::TempStorage db_storage; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - T ds_val = 0; - T db_val = 0; - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int i_gamma = i % G * D + j / HxW; - const int index = kOrder == StorageOrder::NCHW - ? 
i * inner_size + j - : (i / G * HxW + j % HxW) * G * D + i_gamma; + const int inner_size = K * HxW; + const int n = blockIdx.x; + const int g = blockIdx.y; + const int ng = n * G + g; + T ds_val = 0; + T db_val = 0; + for (int i = threadIdx.x; i < inner_size; i += blockDim.x) { + const int c = g * K + i / HxW; + const int index = ng * inner_size + i; #if __CUDA_ARCH__ >= 350 - ds_val += __ldg(gamma + i_gamma) * __ldg(dY + index) * __ldg(X + index); - db_val += __ldg(gamma + i_gamma) * __ldg(dY + index); + ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index); + db_val += __ldg(gamma + c) * __ldg(dY + index); #else - ds_val += gamma[i_gamma] * dY[index] * X[index]; - db_val += gamma[i_gamma] * dY[index]; + ds_val += gamma[c] * dY[index] * X[index]; + db_val += gamma[c] * dY[index]; +#endif + } + ds_val = BlockReduce(ds_storage).Sum(ds_val); + db_val = BlockReduce(db_storage).Sum(db_val); + if (threadIdx.x == 0) { + ds[ng] = ds_val; + db[ng] = db_val; + } +} + +template +__global__ void ComputeInternalGradientsNHWCCUDAKernel( + const int G, + const int K, + const int HxW, + const T* dY, + const T* X, + const T* gamma, + T* ds, + T* db) { + __shared__ + typename BlockReduce2D::TempStorage m_storage; + __shared__ + typename BlockReduce2D::TempStorage v_storage; + const int C = G * K; + const int n = blockIdx.x; + const int g = blockIdx.y; + const int ng = n * G + g; + T ds_val = 0; + T db_val = 0; + for (int i = threadIdx.x; i < HxW; i += blockDim.x) { + for (int j = threadIdx.y; j < K; j += blockDim.y) { + const int c = g * K + j; + const int index = (n * HxW + i) * C + c; +#if __CUDA_ARCH__ >= 350 + ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index); + db_val += __ldg(gamma + c) * __ldg(dY + index); +#else + ds_val += gamma[c] * dY[index] * X[index]; + db_val += gamma[c] * dY[index]; #endif } - ds_val = BlockReduce(ds_storage).Reduce(ds_val, cub::Sum()); - db_val = BlockReduce(db_storage).Reduce(db_val, cub::Sum()); - if (threadIdx.x == 0) { - ds[i] = ds_val; - db[i] = db_val; - } - __syncthreads(); + } + ds_val = BlockReduce2D(m_storage).Sum(ds_val); + db_val = BlockReduce2D(v_storage).Sum(db_val); + if (threadIdx.x == 0 && threadIdx.y == 0) { + ds[ng] = ds_val; + db[ng] = db_val; } } @@ -173,11 +199,50 @@ __global__ void ComputeInternalGradientsCUDAKernel( // db/dX = -u * drsig/dX - rsig * dmu/dX // drsig/dX = -rsig^3 * (X - mu) / n // dmu/dX = 1 / n -template -__global__ void GroupNormBackwardCUDAKernel( - const int size, +template +__global__ void GroupNormBackwardNCHWCUDAKernel( + const int G, + const int K, + const int W, + const int HxW, + const T* dY, + const T* X, + const T* mu, + const T* rsig, + const T* gamma, + const T* ds, + const T* db, + T* dX) { + const int C = G * K; + const T denom = T(1) / static_cast(K * HxW); + const int nc = blockIdx.x / W; + const int n = nc / C; + const int c = nc % C; + const int g = c / K; + const int ng = n * G + g; + const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + const int index = nc * HxW + hw; + if (hw < HxW) { +#if __CUDA_ARCH__ >= 350 + const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) * + (__ldg(X + index) - __ldg(mu + ng)) * + math::utils::Cube(__ldg(rsig + ng)); + const T v = __ldg(db + ng) * __ldg(rsig + ng); + dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) + + (u - v) * denom; +#else + const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) * + math::utils::Cube(rsig[ng]); + const T v = db[ng] * rsig[ng]; + dX[index] = gamma[c] * dY[index] * 
rsig[ng] + (u - v) * denom; +#endif + } +} + +template +__global__ void GroupNormBackwardNHWCCUDAKernel( const int G, - const int D, + const int K, const int HxW, const T* dY, const T* X, @@ -187,34 +252,36 @@ __global__ void GroupNormBackwardCUDAKernel( const T* ds, const T* db, T* dX) { - const int C = G * D; - const T denom = T(1) / static_cast(D * HxW); - CUDA_1D_KERNEL_LOOP(i, size) { - const int i_mu = kOrder == StorageOrder::NCHW - ? i / (D * HxW) - : i / (C * HxW) * G + (i / D % G); - const int i_gamma = kOrder == StorageOrder::NCHW ? (i / HxW) % C : i % C; + const int C = G * K; + const T denom = T(1) / static_cast(K * HxW); + const int x = blockIdx.x; + const int g = blockIdx.y; + const int n = x / HxW; + const int ng = n * G + g; + for (int i = threadIdx.x; i < K; i += blockDim.x) { + const int c = g * K + i; + const int index = x * C + c; #if __CUDA_ARCH__ >= 350 - const T u = (__ldg(db + i_mu) * __ldg(mu + i_mu) - __ldg(ds + i_mu)) * - (__ldg(X + i) - __ldg(mu + i_mu)) * - math::utils::Cube(__ldg(rsig + i_mu)); - const T v = __ldg(db + i_mu) * __ldg(rsig + i_mu); - dX[i] = __ldg(gamma + i_gamma) * __ldg(dY + i) * __ldg(rsig + i_mu) + + const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) * + (__ldg(X + index) - __ldg(mu + ng)) * + math::utils::Cube(__ldg(rsig + ng)); + const T v = __ldg(db + ng) * __ldg(rsig + ng); + dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) + (u - v) * denom; #else - const T u = (db[i_mu] * mu[i_mu] - ds[i_mu]) * (X[i] - mu[i_mu]) * - math::utils::Cube(rsig[i_mu]); - const T v = db[i_mu] * rsig[i_mu]; - dX[i] = gamma[i_gamma] * dY[i] * rsig[i_mu] + (u - v) * denom; + const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) * + math::utils::Cube(rsig[ng]); + const T v = db[ng] * rsig[ng]; + dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom; #endif } } -template -__global__ void GammaBetaBackwardCUDAKernel( +template +__global__ void GammaBetaBackwardNCHWCUDAKernel( const int N, const int G, - const int D, + const int K, const int HxW, const T* dY, const T* X, @@ -222,35 +289,77 @@ __global__ void GammaBetaBackwardCUDAKernel( const T* rsig, T* dgamma, T* dbeta) { - const int outer_size = G * D; - const int inner_size = N * HxW; - __shared__ typename BlockReduce::TempStorage dg_storage; - __shared__ typename BlockReduce::TempStorage db_storage; - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - T dg_val = 0; - T db_val = 0; - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int n = j / HxW; - const int index = kOrder == StorageOrder::NCHW - ? 
(n * outer_size + i) * HxW + j % HxW - : j * outer_size + i; - const int i_mu = n * G + i / D; + __shared__ + typename BlockReduce2D::TempStorage dg_storage; + __shared__ + typename BlockReduce2D::TempStorage db_storage; + const int C = G * K; + const int c = blockIdx.x; + const int g = c / K; + T dg_val = 0; + T db_val = 0; + for (int i = threadIdx.x; i < N; i += blockDim.x) { + for (int j = threadIdx.y; j < HxW; j += blockDim.y) { + const int index = (i * C + c) * HxW + j; + const int ng = i * G + g; #if __CUDA_ARCH__ >= 350 - dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + i_mu)) * - __ldg(rsig + i_mu); + dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) * + __ldg(rsig + ng); db_val += __ldg(dY + index); #else - dg_val += dY[index] * (X[index] - mu[i_mu]) * rsig[i_mu]; + dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng]; db_val += dY[index]; #endif } - dg_val = BlockReduce(dg_storage).Reduce(dg_val, cub::Sum()); - db_val = BlockReduce(db_storage).Reduce(db_val, cub::Sum()); - if (threadIdx.x == 0) { - dgamma[i] = dg_val; - dbeta[i] = db_val; + } + dg_val = BlockReduce2D(dg_storage).Sum(dg_val); + db_val = BlockReduce2D(db_storage).Sum(db_val); + if (threadIdx.x == 0 && threadIdx.y == 0) { + dgamma[c] = dg_val; + dbeta[c] = db_val; + } +} + +template +__global__ void GammaBetaBackwardNHWCCUDAKernel( + const int N, + const int G, + const int K, + const int HxW, + const T* dY, + const T* X, + const T* mu, + const T* rsig, + T* dgamma, + T* dbeta) { + __shared__ + typename BlockReduce2D::TempStorage dg_storage; + __shared__ + typename BlockReduce2D::TempStorage db_storage; + const int C = G * K; + const int c = blockIdx.x; + const int g = c / K; + T dg_val = 0; + T db_val = 0; + for (int i = threadIdx.x; i < N; i += blockDim.x) { + for (int j = threadIdx.y; j < HxW; j += blockDim.y) { + const int index = (i * HxW + j) * C + c; + const int ng = i * G + g; +#if __CUDA_ARCH__ >= 350 + dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) * + __ldg(rsig + ng); + db_val += __ldg(dY + index); +#else + dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng]; + db_val += dY[index]; +#endif } - __syncthreads(); + } + dg_val = BlockReduce2D(dg_storage).Sum(dg_val); + db_val = BlockReduce2D(db_storage).Sum(db_val); + if (threadIdx.x == 0 && threadIdx.y == 0) { + dgamma[c] = dg_val; + dbeta[c] = db_val; } } @@ -260,7 +369,7 @@ template <> void GroupNormOp::ComputeFusedParams( const int N, const int G, - const int D, + const int K, const float* mu, const float* rsig, const float* gamma, @@ -268,10 +377,8 @@ void GroupNormOp::ComputeFusedParams( float* scale, float* bias) { ComputeFusedParamsCUDAKernel - <<>>(N, G, D, mu, rsig, gamma, beta, scale, bias); + <<>>( + G, K, mu, rsig, gamma, beta, scale, bias); } template <> @@ -283,11 +390,10 @@ void GroupNormOp::GroupNormForwardNCHW( const float* scale, const float* bias, float* Y) { - GroupNormForwardCUDAKernel - <<>>(N, C, HxW, X, scale, bias, Y); + const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS); + GroupNormForwardNCHWCUDAKernel + <<>>( + W, HxW, X, scale, bias, Y); } template <> @@ -299,11 +405,9 @@ void GroupNormOp::GroupNormForwardNHWC( const float* scale, const float* bias, float* Y) { - GroupNormForwardCUDAKernel - <<>>(N, C, HxW, X, scale, bias, Y); + GroupNormForwardNHWCCUDAKernel + <<>>( + C, HxW, X, scale, bias, Y); } // Math: @@ -314,7 +418,7 @@ template <> bool GroupNormGradientOp::RunOnDeviceImpl( const int N, const int G, - const int D, + const int K, const int HxW, const float* dY_data, const 
float* X_data, @@ -324,34 +428,26 @@ bool GroupNormGradientOp::RunOnDeviceImpl( float* dX_data, float* dgamma_data, float* dbeta_data) { - const int size = N * G * D * HxW; - const int C = G * D; - ReinitializeTensor( - &ds_, {N, G}, at::dtype().device(CUDA)); - ReinitializeTensor( - &db_, {N, G}, at::dtype().device(CUDA)); + const int C = G * K; + ReinitializeTensor(&ds_, {N, G}, at::dtype().device(CUDA)); + ReinitializeTensor(&db_, {N, G}, at::dtype().device(CUDA)); float* ds_data = ds_.mutable_data(); float* db_data = db_.mutable_data(); if (order_ == StorageOrder::NCHW) { // Computes dL/ds and dL/db. // dL/ds = Sum(dL/dY * gamma * X) // dL/db = Sum(dL/dY * gamma) - ComputeInternalGradientsCUDAKernel - <<>>( - N, G, D, HxW, dY_data, X_data, gamma_data, ds_data, db_data); + ComputeInternalGradientsNCHWCUDAKernel + <<>>( + G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data); // Computes dL/dX. - GroupNormBackwardCUDAKernel - <<>>( - size, + const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS); + GroupNormBackwardNCHWCUDAKernel + <<>>( G, - D, + K, + W, HxW, dY_data, X_data, @@ -363,41 +459,89 @@ bool GroupNormGradientOp::RunOnDeviceImpl( dX_data); // Computes dL/dgamma and dL/dbeta. - GammaBetaBackwardCUDAKernel - <<>>( - N, - G, - D, - HxW, - dY_data, - X_data, - mu_data, - rsig_data, - dgamma_data, - dbeta_data); + if (HxW >= 128) { + GammaBetaBackwardNCHWCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } else if (HxW >= 64) { + GammaBetaBackwardNCHWCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } else if (HxW >= 32) { + GammaBetaBackwardNCHWCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } else { + GammaBetaBackwardNCHWCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } } else { // Computes dL/ds and dL/db. // dL/ds = Sum(dL/dY * gamma * X) // dL/db = Sum(dL/dY * gamma) - ComputeInternalGradientsCUDAKernel - <<>>( - N, G, D, HxW, dY_data, X_data, gamma_data, ds_data, db_data); + if (K >= 128) { + ComputeInternalGradientsNHWCCUDAKernel + <<>>( + G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data); + } else if (K >= 64) { + ComputeInternalGradientsNHWCCUDAKernel + <<>>( + G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data); + } else if (K >= 32) { + ComputeInternalGradientsNHWCCUDAKernel + <<>>( + G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data); + } else { + ComputeInternalGradientsNHWCCUDAKernel + <<>>( + G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data); + } // Computes dL/dX. - GroupNormBackwardCUDAKernel - << + <<>>( - size, G, - D, + K, HxW, dY_data, X_data, @@ -409,21 +553,59 @@ bool GroupNormGradientOp::RunOnDeviceImpl( dX_data); // Computes dL/dgamma and dL/dbeta. 
- GammaBetaBackwardCUDAKernel - <<>>( - N, - G, - D, - HxW, - dY_data, - X_data, - mu_data, - rsig_data, - dgamma_data, - dbeta_data); + if (HxW >= 128) { + GammaBetaBackwardNHWCCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } else if (HxW >= 64) { + GammaBetaBackwardNHWCCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } else if (HxW >= 32) { + GammaBetaBackwardNHWCCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } else { + GammaBetaBackwardNHWCCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_data, + X_data, + mu_data, + rsig_data, + dgamma_data, + dbeta_data); + } } return true; } -- 2.7.4
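
Editor's note (not part of the patch): for readers following the forward-path math comment in the diff (`s = gamma * rsig`, `b = beta - mu * gamma * rsig`, `Y = s * X + b`), the sketch below is a minimal CPU-side C++ reference of what `ComputeFusedParamsCUDAKernel` plus `GroupNormForwardNCHWCUDAKernel` compute for NCHW data. It is illustrative only; the function name `GroupNormForwardNCHWRef` and the flat `std::vector` layout are assumptions made for this example, not identifiers from the patch.

```cpp
// Illustrative CPU reference of the fused GroupNorm forward transform.
// X is laid out NCHW as [N, G, K, HxW]; mu/rsig are per (n, g) group;
// gamma/beta are per channel c = g * K + k. Y must be pre-sized like X.
#include <vector>

void GroupNormForwardNCHWRef(
    int N, int G, int K, int HxW,
    const std::vector<float>& X,
    const std::vector<float>& mu,     // size N * G
    const std::vector<float>& rsig,   // size N * G
    const std::vector<float>& gamma,  // size G * K
    const std::vector<float>& beta,   // size G * K
    std::vector<float>& Y) {
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < G; ++g) {
      const int ng = n * G + g;
      for (int k = 0; k < K; ++k) {
        const int c = g * K + k;
        // Fused parameters, as in ComputeFusedParamsCUDAKernel:
        //   scale = gamma * rsig,  bias = beta - scale * mu
        const float scale = gamma[c] * rsig[ng];
        const float bias = beta[c] - scale * mu[ng];
        const float* x = X.data() + (n * G * K + c) * HxW;
        float* y = Y.data() + (n * G * K + c) * HxW;
        for (int i = 0; i < HxW; ++i) {
          y[i] = x[i] * scale + bias;  // Y = s * X + b
        }
      }
    }
  }
}
```

Precomputing the per-(n, c) `scale`/`bias` pair is what lets the CUDA forward kernels reduce each output element to a single fused multiply-add (`fmaf` in the patch).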
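
Likewise, the backward-path comment in the diff (`dL/dX = dL/dY * gamma * rsig` together with the `drsig/dX = -rsig^3 * (X - mu) / n` and `dmu/dX = 1 / n` terms) corresponds to the per-element formula sketched below. Here `ds` and `db` stand for the group-wise reductions `Sum(dY * gamma * X)` and `Sum(dY * gamma)` produced by the `ComputeInternalGradients*` kernels. This is a hedged reference sketch under those assumptions; the helper name is hypothetical.

```cpp
// Illustrative per-element dX formula mirroring GroupNormBackwardNCHWCUDAKernel.
// dy, x: the dY and X values at one element; mu, rsig, ds, db: per (n, g) group
// statistics; gamma: the element's channel scale; K * HxW: group size n.
float GroupNormDXRef(
    float dy, float x, float mu, float rsig, float gamma,
    float ds, float db, int K, int HxW) {
  const float denom = 1.0f / static_cast<float>(K * HxW);
  // u carries the gradient through rsig, v the gradient through mu.
  const float u = (db * mu - ds) * (x - mu) * rsig * rsig * rsig;
  const float v = db * rsig;
  return gamma * dy * rsig + (u - v) * denom;
}
```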