From: Xiaomeng Yang Date: Thu, 21 Mar 2019 19:56:20 +0000 (-0700) Subject: Optimize group_norm_op (#17945) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~701 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=43a5c636e2915dfd1a19ec73861cc17f61e5ca1e;p=platform%2Fupstream%2Fpytorch.git Optimize group_norm_op (#17945) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17945 Optimize group_norm_op Reviewed By: houseroad Differential Revision: D14419908 fbshipit-source-id: 4024b5c5dbeff97f4f026d61fc44af1f0e98ed68 --- diff --git a/caffe2/operators/group_norm_op.cc b/caffe2/operators/group_norm_op.cc index ad37aa3..80c0152 100644 --- a/caffe2/operators/group_norm_op.cc +++ b/caffe2/operators/group_norm_op.cc @@ -8,85 +8,376 @@ #include "caffe2/operators/group_norm_op.h" -namespace caffe2 { - -namespace { +#include "caffe2/utils/eigen_utils.h" -template -void ComputeInternalGradients( - const std::array& dims, - const T* dY, - const T* X, - const T* gamma, - T* ds, - T* db) { - constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2; - constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3; - const int size = dims[0] * dims[1] * dims[2] * dims[3]; - std::array index = {0, 0, 0, 0}; - for (int i = 0; i < size; ++i) { - const int i_mu = index[0] * dims[kGDim] + index[kGDim]; - const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim]; - ds[i_mu] += gamma[i_gamma] * dY[i] * X[i]; - db[i_mu] += gamma[i_gamma] * dY[i]; - math::utils::IncreaseIndexInDims(4, dims.data(), index.data()); - } -} +namespace caffe2 { // Math: // Y = gamma * (X - mu) * rsig + beta // let s = gamma * rsig // let b = beta - mu * rsig // Y = s * X + b -// let n = D * HxW +// let n = K * HxW // dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX) // d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX // db/dX = -u * drsig/dX - rsig * dmu/dX // drsig/dX = -rsig^3 * (X - mu) / n // dmu/dX = 1 / n + +namespace { + template -void GroupNormBackward( - const std::array& dims, +void ComputeInternalGradients( + int N, + int C, + int HxW, const T* dY, const T* X, + T* ds, + T* db); + +template <> +void ComputeInternalGradients( + const int N, + const int C, + const int HxW, + const float* dY, + const float* X, + float* ds, + float* db) { + ConstEigenArrayMap dY_arr(dY, HxW, N * C); + ConstEigenArrayMap X_arr(X, HxW, N * C); + for (int i = 0; i < N * C; ++i) { + ds[i] = (dY_arr.col(i) * X_arr.col(i)).sum(); + db[i] = dY_arr.col(i).sum(); + } +} + +template <> +void ComputeInternalGradients( + const int N, + const int C, + const int HxW, + const float* dY, + const float* X, + float* ds, + float* db) { + EigenArrayMap ds_arr(ds, C, N); + EigenArrayMap db_arr(db, C, N); + for (int i = 0; i < N; ++i) { + ConstEigenArrayMap dY_arr(dY + i * C * HxW, C, HxW); + ConstEigenArrayMap X_arr(X + i * C * HxW, C, HxW); + ds_arr.col(i) = dY_arr.col(0) * X_arr.col(0); + db_arr.col(i) = dY_arr.col(0); + for (int j = 1; j < HxW; ++j) { + ds_arr.col(i) += dY_arr.col(j) * X_arr.col(j); + db_arr.col(i) += dY_arr.col(j); + } + } +} + +template +void ComputeGradientFusedParams( + const int N, + const int G, + const int K, + const int HxW, + const T* ds, + const T* db, const T* mu, const T* rsig, const T* gamma, + T* dY_scale, + T* X_scale, + T* bias) { + ConstEigenArrayMap rsig_arr(rsig, G, N); + ConstEigenArrayMap gamma_arr(gamma, K, G); + for (int i = 0; i < N; ++i) { + EigenArrayMap(dY_scale + i * G * K, K, G) = + gamma_arr.rowwise() * (rsig_arr.col(i).transpose()); + } + ConstEigenVectorArrayMap mu_arr(mu, 
N * G); + ConstEigenVectorArrayMap rsig_vec(rsig, N * G); + EigenVectorArrayMap X_scale_arr(X_scale, N * G); + EigenVectorArrayMap bias_arr(bias, N * G); + for (int i = 0; i < N; ++i) { + ConstEigenArrayMap ds_arr(ds + i * G * K, K, G); + ConstEigenArrayMap db_arr(db + i * G * K, K, G); + for (int j = 0; j < G; ++j) { + X_scale_arr(i * G + j) = (ds_arr.col(j) * gamma_arr.col(j)).sum(); + bias_arr(i * G + j) = (db_arr.col(j) * gamma_arr.col(j)).sum(); + } + } + const T alpha = T(1) / static_cast(K * HxW); + X_scale_arr = (bias_arr * mu_arr - X_scale_arr) * rsig_vec.cube() * alpha; + bias_arr = -X_scale_arr * mu_arr - bias_arr * rsig_vec * alpha; +} + +template +void GroupNormBackward( + int N, + int G, + int K, + int HxW, + const T* dY_scale, + const T* dY, + const T* X_scale, + const T* X, + const T* bias, + T* dX); + +template <> +void GroupNormBackward( + const int N, + const int G, + const int K, + const int HxW, + const float* dY_scale, + const float* dY, + const float* X_scale, + const float* X, + const float* bias, + float* dX) { + const int C = G * K; + ConstEigenArrayMap dY_arr(dY, HxW, N * C); + ConstEigenArrayMap X_arr(X, HxW, N * C); + EigenArrayMap dX_arr(dX, HxW, N * C); + for (int i = 0; i < N * G; ++i) { + for (int j = 0; j < K; ++j) { + const int c = i * K + j; + dX_arr.col(c) = + dY_arr.col(c) * dY_scale[c] + X_arr.col(c) * X_scale[i] + bias[i]; + } + } +} + +template <> +void GroupNormBackward( + const int N, + const int G, + const int K, + const int HxW, + const float* dY_scale, + const float* dY, + const float* X_scale, + const float* X, + const float* bias, + float* dX) { + const int C = G * K; + ConstEigenArrayMap X_scale_arr(X_scale, G, N); + ConstEigenArrayMap bias_arr(bias, G, N); + for (int n = 0; n < N; ++n) { + ConstEigenArrayMap dY_scale_arr(dY_scale + n * C, K, G); + for (int i = 0; i < HxW; ++i) { + const int m = n * HxW + i; + ConstEigenArrayMap dY_arr(dY + m * C, K, G); + ConstEigenArrayMap X_arr(X + m * C, K, G); + EigenArrayMap dX_arr(dX + m * C, K, G); + dX_arr = (dY_arr * dY_scale_arr + + X_arr.rowwise() * X_scale_arr.col(n).transpose()) + .rowwise() + + bias_arr.col(n).transpose(); + } + } +} + +template +void GammaBetaBackward( + const int N, + const int G, + const int K, const T* ds, const T* db, - T* dX, + const T* mu, + const T* rsig, T* dgamma, T* dbeta) { - constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2; - constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3; - const int size = dims[0] * dims[1] * dims[2] * dims[3]; - const int HxW = kOrder == StorageOrder::NCHW ? 
dims[3] : dims[1]; - const T denom = T(1) / static_cast(dims[kDDim] * HxW); - std::array index = {0, 0, 0, 0}; - for (int i = 0; i < size; ++i) { - const int i_mu = index[0] * dims[kGDim] + index[kGDim]; - const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim]; - const T u = (db[i_mu] * mu[i_mu] - ds[i_mu]) * (X[i] - mu[i_mu]) * - math::utils::Cube(rsig[i_mu]); - const T v = db[i_mu] * rsig[i_mu]; - dX[i] = gamma[i_gamma] * dY[i] * rsig[i_mu] + (u - v) * denom; - dgamma[i_gamma] += dY[i] * (X[i] - mu[i_mu]) * rsig[i_mu]; - dbeta[i_gamma] += dY[i]; - math::utils::IncreaseIndexInDims(4, dims.data(), index.data()); + const int C = G * K; + ConstEigenArrayMap ds0_arr(ds, K, G); + ConstEigenArrayMap db0_arr(db, K, G); + ConstEigenArrayMap mu_arr(mu, G, N); + ConstEigenArrayMap rsig_arr(rsig, G, N); + EigenArrayMap dgamma_arr(dgamma, K, G); + EigenArrayMap dbeta_arr(dbeta, K, G); + dgamma_arr = + (ds0_arr - db0_arr.rowwise() * mu_arr.col(0).transpose()).rowwise() * + rsig_arr.col(0).transpose(); + dbeta_arr = db0_arr; + for (int i = 1; i < N; ++i) { + ConstEigenArrayMap dsi_arr(ds + i * C, K, G); + ConstEigenArrayMap dbi_arr(db + i * C, K, G); + dgamma_arr += + (dsi_arr - dbi_arr.rowwise() * mu_arr.col(i).transpose()).rowwise() * + rsig_arr.col(i).transpose(); + dbeta_arr += dbi_arr; } } } // namespace +template <> +void GroupNormOp::ComputeFusedParams( + const int N, + const int G, + const int K, + const float* mu, + const float* rsig, + const float* gamma, + const float* beta, + float* scale, + float* bias) { + const int C = G * K; + ConstEigenArrayMap mu_arr(mu, G, N); + ConstEigenArrayMap rsig_arr(rsig, G, N); + ConstEigenArrayMap gamma_arr(gamma, K, G); + ConstEigenArrayMap beta_arr(beta, K, G); + for (int i = 0; i < N; ++i) { + EigenArrayMap scale_arr(scale + i * C, K, G); + EigenArrayMap bias_arr(bias + i * C, K, G); + scale_arr = gamma_arr.rowwise() * rsig_arr.col(i).transpose(); + bias_arr = beta_arr - scale_arr.rowwise() * mu_arr.col(i).transpose(); + } +} + +template <> +void GroupNormOp::GroupNormForwardNCHW( + const int N, + const int C, + const int HxW, + const float* X, + const float* scale, + const float* bias, + float* Y) { + EigenArrayMap(Y, HxW, N * C) = + (ConstEigenArrayMap(X, HxW, N * C).rowwise() * + ConstEigenVectorArrayMap(scale, N * C).transpose()) + .rowwise() + + ConstEigenVectorArrayMap(bias, N * C).transpose(); +} + +template <> +void GroupNormOp::GroupNormForwardNHWC( + const int N, + const int C, + const int HxW, + const float* X, + const float* scale, + const float* bias, + float* Y) { + const int stride = HxW * C; + for (int i = 0; i < N; ++i) { + EigenArrayMap(Y + i * stride, C, HxW) = + (ConstEigenArrayMap(X + i * stride, C, HxW).colwise() * + ConstEigenVectorArrayMap(scale + i * C, C)) + .colwise() + + ConstEigenVectorArrayMap(bias + i * C, C); + } +} + +template <> +bool GroupNormOp::RunOnDeviceWithOrderNHWC( + const int N, + const int G, + const int K, + const int HxW, + const float* X, + const float* gamma, + const float* beta, + float* Y, + float* mu, + float* rsig) { + const int C = G * K; + ReinitializeTensor(&scale_, {N, C}, at::dtype().device(CPU)); + ReinitializeTensor(&bias_, {N, C}, at::dtype().device(CPU)); + float* scale_data = scale_.mutable_data(); + float* bias_data = bias_.mutable_data(); + EigenVectorArrayMap mu_arr(mu, N * G); + EigenVectorArrayMap rsig_arr(rsig, N * G); + mu_arr.setZero(); + rsig_arr.setZero(); + for (int n = 0; n < N; ++n) { + for (int i = 0; i < HxW; ++i) { + const int m = n * HxW + i; + ConstEigenArrayMap X_arr(X 
+ m * C, K, G); + for (int j = 0; j < G; ++j) { + mu_arr(n * G + j) += X_arr.col(j).sum(); + rsig_arr(n * G + j) += X_arr.col(j).square().sum(); + } + } + } + const float scale = 1.0f / static_cast(K * HxW); + mu_arr *= scale; + rsig_arr = (rsig_arr * scale - mu_arr.square() + epsilon_).rsqrt(); + ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data); + GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y); + return true; +} + // Math: // let: s = gamma * rsig // let: b = beta - mu * gamma * rsig // then: Y = s * X + b +template <> +bool GroupNormGradientOp::RunOnDeviceWithOrderNCHW( + const int N, + const int G, + const int K, + const int HxW, + const float* dY_data, + const float* X_data, + const float* mu_data, + const float* rsig_data, + const float* gamma_data, + float* dX_data, + float* dgamma_data, + float* dbeta_data) { + const int C = G * K; + ReinitializeTensor(&ds_, {N, C}, at::dtype().device(CPU)); + ReinitializeTensor(&db_, {N, C}, at::dtype().device(CPU)); + ReinitializeTensor(&dY_scale_, {N, C}, at::dtype().device(CPU)); + ReinitializeTensor(&X_scale_, {N, G}, at::dtype().device(CPU)); + ReinitializeTensor(&bias_, {N, G}, at::dtype().device(CPU)); + float* ds_data = ds_.mutable_data(); + float* db_data = db_.mutable_data(); + float* dY_scale_data = dY_scale_.mutable_data(); + float* X_scale_data = X_scale_.mutable_data(); + float* bias_data = bias_.mutable_data(); + ComputeInternalGradients( + N, C, HxW, dY_data, X_data, ds_data, db_data); + ComputeGradientFusedParams( + N, + G, + K, + HxW, + ds_data, + db_data, + mu_data, + rsig_data, + gamma_data, + dY_scale_data, + X_scale_data, + bias_data); + GroupNormBackward( + N, + G, + K, + HxW, + dY_scale_data, + dY_data, + X_scale_data, + X_data, + bias_data, + dX_data); + GammaBetaBackward( + N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data); + return true; +} + template -bool GroupNormGradientOp::RunOnDeviceImpl( +bool GroupNormGradientOp::RunOnDeviceWithOrderNHWC( const int N, const int G, - const int D, + const int K, const int HxW, const T* dY_data, const T* X_data, @@ -96,60 +387,45 @@ bool GroupNormGradientOp::RunOnDeviceImpl( T* dX_data, T* dgamma_data, T* dbeta_data) { - const std::array dims = order_ == StorageOrder::NCHW - ? std::array{N, G, D, HxW} - : std::array{N, HxW, G, D}; - - // Computes dL/ds and dL/db. - // dL/ds = Sum(dL/dY * gamma * X) - // dL/db = Sum(dL/dY * gamma) - const int C = G * D; - ReinitializeTensor( - &ds_, {N, G}, at::dtype().device(Context::GetDeviceType())); - ReinitializeTensor( - &db_, {N, G}, at::dtype().device(Context::GetDeviceType())); - T* ds_data = ds_.template mutable_data(); - T* db_data = db_.template mutable_data(); - math::Set(N * G, T(0), ds_data, &context_); - math::Set(N * G, T(0), db_data, &context_); - if (order_ == StorageOrder::NCHW) { - ComputeInternalGradients( - dims, dY_data, X_data, gamma_data, ds_data, db_data); - } else { - ComputeInternalGradients( - dims, dY_data, X_data, gamma_data, ds_data, db_data); - } - - // Computes dL/dX, dL/dgamma and dL/dbeta. 
- math::Set(C, T(0), dgamma_data, &context_); - math::Set(C, T(0), dbeta_data, &context_); - if (order_ == StorageOrder::NCHW) { - GroupNormBackward( - dims, - dY_data, - X_data, - mu_data, - rsig_data, - gamma_data, - ds_data, - db_data, - dX_data, - dgamma_data, - dbeta_data); - } else { - GroupNormBackward( - dims, - dY_data, - X_data, - mu_data, - rsig_data, - gamma_data, - ds_data, - db_data, - dX_data, - dgamma_data, - dbeta_data); - } + const int C = G * K; + ReinitializeTensor(&ds_, {N, C}, at::dtype().device(CPU)); + ReinitializeTensor(&db_, {N, C}, at::dtype().device(CPU)); + ReinitializeTensor(&dY_scale_, {N, C}, at::dtype().device(CPU)); + ReinitializeTensor(&X_scale_, {N, G}, at::dtype().device(CPU)); + ReinitializeTensor(&bias_, {N, G}, at::dtype().device(CPU)); + float* ds_data = ds_.mutable_data(); + float* db_data = db_.mutable_data(); + float* dY_scale_data = dY_scale_.mutable_data(); + float* X_scale_data = X_scale_.mutable_data(); + float* bias_data = bias_.mutable_data(); + ComputeInternalGradients( + N, C, HxW, dY_data, X_data, ds_data, db_data); + ComputeGradientFusedParams( + N, + G, + K, + HxW, + ds_data, + db_data, + mu_data, + rsig_data, + gamma_data, + dY_scale_data, + X_scale_data, + bias_data); + GroupNormBackward( + N, + G, + K, + HxW, + dY_scale_data, + dY_data, + X_scale_data, + X_data, + bias_data, + dX_data); + GammaBetaBackward( + N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data); return true; } @@ -201,17 +477,21 @@ Group Normalization (GN) operation: https://arxiv.org/abs/1803.08494 // Input: dY, X, gamma, beta, mu, sig; Output: dX, dgamma, dbeta OPERATOR_SCHEMA(GroupNormGradient).NumInputs(6).NumOutputs(3); +namespace { + class GetGroupNormGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; - vector GetGradientDefs() override { + std::vector GetGradientDefs() override { return SingleGradientDef( "GroupNormGradient", "", - vector{GO(0), I(0), I(1), I(2), O(1), O(2)}, - vector{GI(0), GI(1), GI(2)}); + std::vector{GO(0), I(0), I(1), I(2), O(1), O(2)}, + std::vector{GI(0), GI(1), GI(2)}); } }; +} // namespace + REGISTER_GRADIENT(GroupNorm, GetGroupNormGradient); } // namespace caffe2 diff --git a/caffe2/operators/group_norm_op.cu b/caffe2/operators/group_norm_op.cu index 92d56ed..3e1dcf5 100644 --- a/caffe2/operators/group_norm_op.cu +++ b/caffe2/operators/group_norm_op.cu @@ -8,9 +8,6 @@ #include "caffe2/operators/group_norm_op.h" -#include -#include - #include "caffe2/core/context_gpu.h" #include "caffe2/utils/math.h" #include "caffe2/utils/math/reduce.cuh" @@ -21,6 +18,7 @@ namespace { template __global__ void ComputeFusedParamsCUDAKernel( + const int N, const int G, const int K, const T* mu, @@ -28,28 +26,40 @@ __global__ void ComputeFusedParamsCUDAKernel( const T* gamma, const T* beta, T* scale, - T* bias) { - const int n = blockIdx.x; - const int g = blockIdx.y; - const int i_mu = n * G + g; - for (int i = threadIdx.x; i < K; i += blockDim.x) { - const int index = i_mu * K + i; - const int i_gamma = g * K + i; -#if __CUDA_ARCH__ >= 350 - const T scale_val = __ldg(gamma + i_gamma) * __ldg(rsig + i_mu); + T* bias); + +template <> +__global__ void ComputeFusedParamsCUDAKernel( + const int N, + const int G, + const int K, + const float* mu, + const float* rsig, + const float* gamma, + const float* beta, + float* scale, + float* bias) { + const int C = G * K; + const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (index < N * C) { + const int ng = index / K; + const int c = index % 
C; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + const float scale_val = __ldg(gamma + c) * __ldg(rsig + ng); scale[index] = scale_val; - bias[index] = __ldg(beta + i_gamma) - scale_val * __ldg(mu + i_mu); + bias[index] = fmaf(-scale_val, __ldg(mu + ng), __ldg(beta + c)); #else - const T scale_val = gamma[i_gamma] * rsig[i_mu]; + const float scale_val = gamma[c] * rsig[ng]; scale[index] = scale_val; - bias[index] = beta[i_gamma] - scale_val * mu[i_mu]; + bias[index] = fmaf(-scale_val, mu[ng], beta[c]); #endif } } -template -__global__ void GroupNormForwardNCHWCUDAKernel( - const int M, +template +__global__ void GroupNormForwardCUDAKernel( + const int N, + const int C, const int HxW, const T* X, const T* scale, @@ -57,18 +67,18 @@ __global__ void GroupNormForwardNCHWCUDAKernel( T* Y); template <> -__global__ void GroupNormForwardNCHWCUDAKernel( - const int M, +__global__ void GroupNormForwardCUDAKernel( + const int N, + const int C, const int HxW, const float* X, const float* scale, const float* bias, float* Y) { - const int nc = blockIdx.x / M; - const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (hw < HxW) { - const int index = nc * HxW + hw; -#if __CUDA_ARCH__ >= 350 + const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (index < N * C * HxW) { + const int nc = index / HxW; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc)); #else Y[index] = fmaf(X[index], scale[nc], bias[nc]); @@ -76,29 +86,19 @@ __global__ void GroupNormForwardNCHWCUDAKernel( } } -template -__global__ void GroupNormForwardNHWCCUDAKernel( - const int C, - const int HxW, - const T* X, - const T* scale, - const T* bias, - T* Y); - template <> -__global__ void GroupNormForwardNHWCCUDAKernel( +__global__ void GroupNormForwardCUDAKernel( + const int N, const int C, const int HxW, const float* X, const float* scale, const float* bias, float* Y) { - const int n = blockIdx.x / HxW; - const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (c < C) { - const int index = blockIdx.x * C + c; - const int nc = n * C + c; -#if __CUDA_ARCH__ >= 350 + const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (index < N * C * HxW) { + const int nc = index / (HxW * C) * C + index % C; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc)); #else Y[index] = fmaf(X[index], scale[nc], bias[nc]); @@ -106,84 +106,33 @@ __global__ void GroupNormForwardNHWCCUDAKernel( } } -template +template __global__ void ComputeInternalGradientsNCHWCUDAKernel( - const int G, - const int K, const int HxW, const T* dY, const T* X, - const T* gamma, T* ds, T* db) { - __shared__ - typename BlockReduce2D::TempStorage ds_storage; - __shared__ - typename BlockReduce2D::TempStorage db_storage; - const int n = blockIdx.x; - const int g = blockIdx.y; - const int ng = n * G + g; - T ds_val = 0; - T db_val = 0; - for (int i = threadIdx.x; i < K; i += blockDim.x) { - const int c = g * K + i; - for (int j = threadIdx.y; j < HxW; j += blockDim.y) { - const int index = (ng * K + i) * HxW + j; -#if __CUDA_ARCH__ >= 350 - ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index); - db_val += __ldg(gamma + c) * __ldg(dY + index); -#else - ds_val += gamma[c] * dY[index] * X[index]; - db_val += gamma[c] * dY[index]; -#endif - } - } - ds_val = BlockReduce2D(ds_storage).Sum(ds_val); - db_val = BlockReduce2D(db_storage).Sum(db_val); - 
if (threadIdx.x == 0 && threadIdx.y == 0) { - ds[ng] = ds_val; - db[ng] = db_val; - } -} - -template -__global__ void ComputeInternalGradientsNHWCCUDAKernel( - const int G, - const int K, - const int HxW, - const T* dY, - const T* X, - const T* gamma, - T* ds, - T* db) { - __shared__ - typename BlockReduce2D::TempStorage ds_storage; - __shared__ - typename BlockReduce2D::TempStorage db_storage; - const int C = G * K; - const int n = blockIdx.x; - const int g = blockIdx.y; - const int ng = n * G + g; - T ds_val = 0; - T db_val = 0; + __shared__ typename BlockReduce::TempStorage ds_storage; + __shared__ typename BlockReduce::TempStorage db_storage; + const int nc = blockIdx.x; + T ds_sum = 0; + T db_sum = 0; for (int i = threadIdx.x; i < HxW; i += blockDim.x) { - for (int j = threadIdx.y; j < K; j += blockDim.y) { - const int c = g * K + j; - const int index = (n * HxW + i) * C + c; -#if __CUDA_ARCH__ >= 350 - ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index); - db_val += __ldg(gamma + c) * __ldg(dY + index); + const int index = nc * HxW + i; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + ds_sum += __ldg(dY + index) * __ldg(X + index); + db_sum += __ldg(dY + index); #else - ds_val += gamma[c] * dY[index] * X[index]; - db_val += gamma[c] * dY[index]; + ds_sum += dY[index] * X[index]; + db_sum += dY[index]; #endif - } } - ds_val = BlockReduce2D(ds_storage).Sum(ds_val); - db_val = BlockReduce2D(db_storage).Sum(db_val); - if (threadIdx.x == 0 && threadIdx.y == 0) { - ds[ng] = ds_val; - db[ng] = db_val; + ds_sum = BlockReduce(ds_storage).Sum(ds_sum); + db_sum = BlockReduce(db_storage).Sum(db_sum); + if (threadIdx.x == 0) { + ds[nc] = ds_sum; + db[nc] = db_sum; } } @@ -192,174 +141,212 @@ __global__ void ComputeInternalGradientsNHWCCUDAKernel( // let s = gamma * rsig // let b = beta - mu * rsig // Y = s * X + b -// let n = D * HxW +// let n = K * HxW // dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX) // d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX // db/dX = -u * drsig/dX - rsig * dmu/dX // drsig/dX = -rsig^3 * (X - mu) / n // dmu/dX = 1 / n template -__global__ void GroupNormBackwardNCHWCUDAKernel( +__global__ void ComputeYGradientScaleCUDAKernel( + const int N, const int G, const int K, - const int M, - const int HxW, - const T* dY, - const T* X, - const T* mu, const T* rsig, const T* gamma, - const T* ds, - const T* db, - T* dX) { + T* dY_scale) { const int C = G * K; - const T denom = T(1) / static_cast(K * HxW); - const int nc = blockIdx.x / M; - const int n = nc / C; - const int c = nc % C; - const int g = c / K; - const int ng = n * G + g; - const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - const int index = nc * HxW + hw; - if (hw < HxW) { -#if __CUDA_ARCH__ >= 350 - const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) * - (__ldg(X + index) - __ldg(mu + ng)) * - math::utils::Cube(__ldg(rsig + ng)); - const T v = __ldg(db + ng) * __ldg(rsig + ng); - dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) + - (u - v) * denom; + const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (index < N * C) { + const int ng = index / K; + const int c = index % C; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dY_scale[index] = __ldg(gamma + c) * __ldg(rsig + ng); #else - const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) * - math::utils::Cube(rsig[ng]); - const T v = db[ng] * rsig[ng]; - dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom; + dY_scale[index] = 
gamma[c] * rsig[ng]; #endif } } template -__global__ void GroupNormBackwardNHWCCUDAKernel( +__global__ void ComputeXScaleAndBiasCUDAKernel( const int G, const int K, - const int HxW, - const T* dY, - const T* X, + const T alpha, + const T* ds, + const T* db, const T* mu, const T* rsig, const T* gamma, - const T* ds, - const T* db, - T* dX) { - const int C = G * K; - const T denom = T(1) / static_cast(K * HxW); - const int x = blockIdx.x; + T* X_scale, + T* bias); + +template <> +__global__ void ComputeXScaleAndBiasCUDAKernel( + const int G, + const int K, + const float alpha, + const float* ds, + const float* db, + const float* mu, + const float* rsig, + const float* gamma, + float* X_scale, + float* bias) { + __shared__ typename BlockReduce::TempStorage ds_storage; + __shared__ typename BlockReduce::TempStorage db_storage; + const int n = blockIdx.x; const int g = blockIdx.y; - const int n = x / HxW; const int ng = n * G + g; - const int i = blockIdx.z * CAFFE_CUDA_NUM_THREADS + threadIdx.x; - if (i < K) { + float ds_sum = 0; + float db_sum = 0; + for (int i = threadIdx.x; i < K; i += blockDim.x) { + const int index = ng * K + i; const int c = g * K + i; - const int index = x * C + c; -#if __CUDA_ARCH__ >= 350 - const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) * - (__ldg(X + index) - __ldg(mu + ng)) * - math::utils::Cube(__ldg(rsig + ng)); - const T v = __ldg(db + ng) * __ldg(rsig + ng); - dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) + - (u - v) * denom; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + ds_sum += __ldg(ds + index) * __ldg(gamma + c); + db_sum += __ldg(db + index) * __ldg(gamma + c); +#else + ds_sum += ds[index] * gamma[c]; + db_sum += db[index] * gamma[c]; +#endif + } + ds_sum = BlockReduce(ds_storage).Sum(ds_sum); + db_sum = BlockReduce(db_storage).Sum(db_sum); + if (threadIdx.x == 0) { +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + const float x = fmaf(db_sum, __ldg(mu + ng), -ds_sum) * + math::utils::Cube(__ldg(rsig + ng)) * alpha; + X_scale[ng] = x; + bias[ng] = -fmaf(x, __ldg(mu + ng), db_sum * __ldg(rsig + ng) * alpha); #else - const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) * - math::utils::Cube(rsig[ng]); - const T v = db[ng] * rsig[ng]; - dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom; + const float x = fmaf(db_sum, mu[ng], -ds_sum) * + math::utils::Cube(rsig[ng]) * alpha; + X_scale[ng] = x; + bias[ng] = -fmaf(x, mu[ng], db_sum * rsig[ng] * alpha); #endif } } -template -__global__ void GammaBetaBackwardNCHWCUDAKernel( +template +__global__ void GroupNormBackwardCUDAKernel( const int N, const int G, const int K, const int HxW, + const T* dY_scale, const T* dY, + const T* X_scale, const T* X, - const T* mu, - const T* rsig, - T* dgamma, - T* dbeta) { - __shared__ - typename BlockReduce2D::TempStorage dg_storage; - __shared__ - typename BlockReduce2D::TempStorage db_storage; + const T* bias, + T* dX); + +template <> +__global__ void GroupNormBackwardCUDAKernel( + const int N, + const int G, + const int K, + const int HxW, + const float* dY_scale, + const float* dY, + const float* X_scale, + const float* X, + const float* bias, + float* dX) { const int C = G * K; - const int c = blockIdx.x; - const int g = c / K; - T dg_val = 0; - T db_val = 0; - for (int i = threadIdx.x; i < N; i += blockDim.x) { - for (int j = threadIdx.y; j < HxW; j += blockDim.y) { - const int index = (i * C + c) * HxW + j; - const int ng = i * G + g; -#if __CUDA_ARCH__ >= 350 - dg_val += __ldg(dY + index) * 
(__ldg(X + index) - __ldg(mu + ng)) * - __ldg(rsig + ng); - db_val += __ldg(dY + index); + const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (index < N * C * HxW) { + const int nc = index / HxW; + const int ng = nc / K; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dX[index] = fmaf( + __ldg(dY_scale + nc), + __ldg(dY + index), + fmaf(__ldg(X_scale + ng), __ldg(X + index), __ldg(bias + ng))); #else - dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng]; - db_val += dY[index]; + dX[index] = + fmaf(dY_scale[nc], dY[index], fmaf(X_scale[ng], X[index], bias[ng])); #endif - } - } - dg_val = BlockReduce2D(dg_storage).Sum(dg_val); - db_val = BlockReduce2D(db_storage).Sum(db_val); - if (threadIdx.x == 0 && threadIdx.y == 0) { - dgamma[c] = dg_val; - dbeta[c] = db_val; } } -template -__global__ void GammaBetaBackwardNHWCCUDAKernel( +template <> +__global__ void GroupNormBackwardCUDAKernel( const int N, const int G, const int K, const int HxW, - const T* dY, - const T* X, + const float* dY_scale, + const float* dY, + const float* X_scale, + const float* X, + const float* bias, + float* dX) { + const int C = G * K; + const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; + if (index < N * C * HxW) { + const int nc = index / (HxW * C) * C + index % C; + const int ng = nc / K; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dX[index] = fmaf( + __ldg(dY_scale + nc), + __ldg(dY + index), + fmaf(__ldg(X_scale + ng), __ldg(X + index), __ldg(bias + ng))); +#else + dX[index] = + fmaf(dY_scale[nc], dY[index], fmaf(X_scale[ng], X[index], bias[ng])); +#endif + } +} + +template +__global__ void GammaBetaBackwardCUDAKernel( + const int N, + const int G, + const int K, + const T* ds, + const T* db, const T* mu, const T* rsig, T* dgamma, - T* dbeta) { - __shared__ - typename BlockReduce2D::TempStorage dg_storage; - __shared__ - typename BlockReduce2D::TempStorage db_storage; + T* dbeta); + +template <> +__global__ void GammaBetaBackwardCUDAKernel( + const int N, + const int G, + const int K, + const float* ds, + const float* db, + const float* mu, + const float* rsig, + float* dgamma, + float* dbeta) { + __shared__ typename BlockReduce::TempStorage dg_storage; + __shared__ typename BlockReduce::TempStorage db_storage; const int C = G * K; - const int c = blockIdx.x; - const int g = c / K; - T dg_val = 0; - T db_val = 0; + const int g = blockIdx.x; + const int k = blockIdx.y; + const int c = g * K + k; + float dg_sum = 0; + float db_sum = 0; for (int i = threadIdx.x; i < N; i += blockDim.x) { - for (int j = threadIdx.y; j < HxW; j += blockDim.y) { - const int index = (i * HxW + j) * C + c; - const int ng = i * G + g; -#if __CUDA_ARCH__ >= 350 - dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) * - __ldg(rsig + ng); - db_val += __ldg(dY + index); + const int nc = i * C + c; + const int ng = i * G + g; +#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__) + dg_sum += fmaf(-__ldg(db + nc), __ldg(mu + ng), __ldg(ds + nc)) * + __ldg(rsig + ng); + db_sum += __ldg(db + nc); #else - dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng]; - db_val += dY[index]; + dg_sum += fmaf(-db[nc], mu[ng], ds[nc]) * rsig[ng]; + db_sum += db[nc]; #endif - } } - dg_val = BlockReduce2D(dg_storage).Sum(dg_val); - db_val = BlockReduce2D(db_storage).Sum(db_val); - if (threadIdx.x == 0 && threadIdx.y == 0) { - dgamma[c] = dg_val; - dbeta[c] = db_val; + dg_sum = BlockReduce(dg_storage).Sum(dg_sum); + db_sum = BlockReduce(db_storage).Sum(db_sum); + if 
(threadIdx.x == 0) { + dgamma[c] = dg_sum; + dbeta[c] = db_sum; } } @@ -376,9 +363,10 @@ void GroupNormOp::ComputeFusedParams( const float* beta, float* scale, float* bias) { + const int M = math::DivUp(N * G * K, CAFFE_CUDA_NUM_THREADS); ComputeFusedParamsCUDAKernel - <<>>( - G, K, mu, rsig, gamma, beta, scale, bias); + <<>>( + N, G, K, mu, rsig, gamma, beta, scale, bias); } template <> @@ -390,10 +378,10 @@ void GroupNormOp::GroupNormForwardNCHW( const float* scale, const float* bias, float* Y) { - const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS); - GroupNormForwardNCHWCUDAKernel - <<>>( - M, HxW, X, scale, bias, Y); + const int M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS); + GroupNormForwardCUDAKernel + <<>>( + N, C, HxW, X, scale, bias, Y); } template <> @@ -405,10 +393,10 @@ void GroupNormOp::GroupNormForwardNHWC( const float* scale, const float* bias, float* Y) { - const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS); - GroupNormForwardNHWCCUDAKernel - <<>>( - C, HxW, X, scale, bias, Y); + const int M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS); + GroupNormForwardCUDAKernel + <<>>( + N, C, HxW, X, scale, bias, Y); } // Math: @@ -416,7 +404,7 @@ void GroupNormOp::GroupNormForwardNHWC( // let: b = beta - mu * gamma * rsig // then: Y = s * X + b template <> -bool GroupNormGradientOp::RunOnDeviceImpl( +bool GroupNormGradientOp::RunOnDeviceWithOrderNCHW( const int N, const int G, const int K, @@ -430,119 +418,158 @@ bool GroupNormGradientOp::RunOnDeviceImpl( float* dgamma_data, float* dbeta_data) { const int C = G * K; - ReinitializeTensor(&ds_, {N, G}, at::dtype().device(CUDA)); - ReinitializeTensor(&db_, {N, G}, at::dtype().device(CUDA)); + ReinitializeTensor(&ds_, {N, C}, at::dtype().device(CUDA)); + ReinitializeTensor(&db_, {N, C}, at::dtype().device(CUDA)); + ReinitializeTensor(&dY_scale_, {N, C}, at::dtype().device(CUDA)); + ReinitializeTensor(&X_scale_, {N, G}, at::dtype().device(CUDA)); + ReinitializeTensor(&bias_, {N, G}, at::dtype().device(CUDA)); float* ds_data = ds_.mutable_data(); float* db_data = db_.mutable_data(); - if (order_ == StorageOrder::NCHW) { - // Computes dL/ds and dL/db. - // dL/ds = Sum(dL/dY * gamma * X) - // dL/db = Sum(dL/dY * gamma) - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( - HxW, - ComputeInternalGradientsNCHWCUDAKernel, - float, - dim3(N, G), - context_.cuda_stream(), - G, - K, - HxW, - dY_data, - X_data, - gamma_data, - ds_data, - db_data); - - // Computes dL/dX. - const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS); - GroupNormBackwardNCHWCUDAKernel - <<>>( - G, - K, - M, - HxW, - dY_data, - X_data, - mu_data, - rsig_data, - gamma_data, - ds_data, - db_data, - dX_data); - - // Computes dL/dgamma and dL/dbeta. - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( - HxW, - GammaBetaBackwardNCHWCUDAKernel, - float, - C, - context_.cuda_stream(), - N, - G, - K, - HxW, - dY_data, - X_data, - mu_data, - rsig_data, - dgamma_data, - dbeta_data); - } else { - // Computes dL/ds and dL/db. - // dL/ds = Sum(dL/dY * gamma * X) - // dL/db = Sum(dL/dY * gamma) - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( - K, - ComputeInternalGradientsNHWCCUDAKernel, - float, - dim3(N, G), - context_.cuda_stream(), - G, - K, - HxW, - dY_data, - X_data, - gamma_data, - ds_data, - db_data); - - // Computes dL/dX. 
- const int M = math::DivUp(K, CAFFE_CUDA_NUM_THREADS); - GroupNormBackwardNHWCCUDAKernel - <<>>( - G, - K, - HxW, - dY_data, - X_data, - mu_data, - rsig_data, - gamma_data, - ds_data, - db_data, - dX_data); - - // Computes dL/dgamma and dL/dbeta. - DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1( - HxW, - GammaBetaBackwardNHWCCUDAKernel, - float, - C, - context_.cuda_stream(), - N, - G, - K, - HxW, - dY_data, - X_data, - mu_data, - rsig_data, - dgamma_data, - dbeta_data); - } + float* dY_scale_data = dY_scale_.mutable_data(); + float* X_scale_data = X_scale_.mutable_data(); + float* bias_data = bias_.mutable_data(); + + ComputeInternalGradientsNCHWCUDAKernel + <<>>( + HxW, dY_data, X_data, ds_data, db_data); + + // Computes dL/dX. + int M = math::DivUp(N * C, CAFFE_CUDA_NUM_THREADS); + ComputeYGradientScaleCUDAKernel + <<>>( + N, G, K, rsig_data, gamma_data, dY_scale_data); + ComputeXScaleAndBiasCUDAKernel + <<>>( + G, + K, + 1.0f / static_cast(K * HxW), + ds_data, + db_data, + mu_data, + rsig_data, + gamma_data, + X_scale_data, + bias_data); + M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS); + GroupNormBackwardCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_scale_data, + dY_data, + X_scale_data, + X_data, + bias_data, + dX_data); + + // Computes dL/dgamma and dL/dbeta. + GammaBetaBackwardCUDAKernel< + float><<>>( + N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data); + return true; +} + +template <> +bool GroupNormGradientOp::RunOnDeviceWithOrderNHWC( + const int N, + const int G, + const int K, + const int HxW, + const float* dY_data, + const float* X_data, + const float* mu_data, + const float* rsig_data, + const float* gamma_data, + float* dX_data, + float* dgamma_data, + float* dbeta_data) { + const int C = G * K; + ReinitializeTensor(&ds_, {N, C}, at::dtype().device(CUDA)); + ReinitializeTensor(&db_, {N, C}, at::dtype().device(CUDA)); + ReinitializeTensor(&dY_scale_, {N, C}, at::dtype().device(CUDA)); + ReinitializeTensor(&X_scale_, {N, G}, at::dtype().device(CUDA)); + ReinitializeTensor(&bias_, {N, G}, at::dtype().device(CUDA)); + ReinitializeTensor(&ones_, {HxW}, at::dtype().device(CUDA)); + float* ds_data = ds_.mutable_data(); + float* db_data = db_.mutable_data(); + float* dY_scale_data = dY_scale_.mutable_data(); + float* X_scale_data = X_scale_.mutable_data(); + float* bias_data = bias_.mutable_data(); + float* ones_data = ones_.mutable_data(); + + math::Set(HxW, 1.0f, ones_data, &context_); + math::Mul( + N * C * HxW, dY_data, X_data, dX_data, &context_); + math::GemmStridedBatched( + CblasTrans, + CblasNoTrans, + N, + C, + 1, + HxW, + 1.0f, + dX_data, + C * HxW, + ones_data, + 0, + 0.0f, + ds_data, + C, + &context_); + math::GemmStridedBatched( + CblasTrans, + CblasNoTrans, + N, + C, + 1, + HxW, + 1.0f, + dY_data, + C * HxW, + ones_data, + 0, + 0.0f, + db_data, + C, + &context_); + + // Computes dL/dX. + int M = math::DivUp(N * C, CAFFE_CUDA_NUM_THREADS); + ComputeYGradientScaleCUDAKernel + <<>>( + N, G, K, rsig_data, gamma_data, dY_scale_data); + ComputeXScaleAndBiasCUDAKernel + <<>>( + G, + K, + 1.0f / static_cast(K * HxW), + ds_data, + db_data, + mu_data, + rsig_data, + gamma_data, + X_scale_data, + bias_data); + M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS); + GroupNormBackwardCUDAKernel + <<>>( + N, + G, + K, + HxW, + dY_scale_data, + dY_data, + X_scale_data, + X_data, + bias_data, + dX_data); + + // Computes dL/dgamma and dL/dbeta. 
+ GammaBetaBackwardCUDAKernel< + float><<>>( + N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data); return true; } diff --git a/caffe2/operators/group_norm_op.h b/caffe2/operators/group_norm_op.h index 826eb17..8143f73 100644 --- a/caffe2/operators/group_norm_op.h +++ b/caffe2/operators/group_norm_op.h @@ -8,7 +8,6 @@ #include "caffe2/core/common.h" #include "caffe2/core/context.h" #include "caffe2/core/operator.h" -#include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { @@ -47,8 +46,7 @@ class GroupNormOp final : public Operator { CAFFE_ENFORCE_EQ(gamma.numel(), C); CAFFE_ENFORCE_EQ(beta.numel(), C); const int G = group_; - const int D = C / G; - + const int K = C / G; auto* Y = Output(OUTPUT, X.sizes(), at::dtype()); T* mu_data = nullptr; T* rsig_data = nullptr; @@ -65,24 +63,38 @@ class GroupNormOp final : public Operator { mu_data = mu_.template mutable_data(); rsig_data = rsig_.template mutable_data(); } - return RunOnDeviceImpl( - N, - G, - D, - HxW, - X.template data(), - gamma.template data(), - beta.template data(), - Y->template mutable_data(), - mu_data, - rsig_data); + if (order_ == StorageOrder::NCHW) { + return RunOnDeviceWithOrderNCHW( + N, + G, + K, + HxW, + X.template data(), + gamma.template data(), + beta.template data(), + Y->template mutable_data(), + mu_data, + rsig_data); + } else { + return RunOnDeviceWithOrderNHWC( + N, + G, + K, + HxW, + X.template data(), + gamma.template data(), + beta.template data(), + Y->template mutable_data(), + mu_data, + rsig_data); + } } - protected: - bool RunOnDeviceImpl( + private: + bool RunOnDeviceWithOrderNCHW( const int N, const int G, - const int D, + const int K, const int HxW, const T* X, const T* gamma, @@ -90,57 +102,63 @@ class GroupNormOp final : public Operator { T* Y, T* mu, T* rsig) { - const int C = G * D; + const int C = G * K; ReinitializeTensor( &scale_, {N, C}, at::dtype().device(Context::GetDeviceType())); ReinitializeTensor( &bias_, {N, C}, at::dtype().device(Context::GetDeviceType())); T* scale_data = scale_.template mutable_data(); T* bias_data = bias_.template mutable_data(); - if (order_ == StorageOrder::NCHW) { - const std::array X_dims = {N * G, D * HxW}; - const std::array Y_dims = {N * G, 1}; - math::Moments( - 2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_); - math::InvStd( - N * G, static_cast(epsilon_), rsig, rsig, &context_); - ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data); - GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y); - } else { - const std::array X_dims = {N, HxW, G, D}; - const std::array Y_dims = {N, 1, G, 1}; - math::Moments( - 4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_); - math::InvStd( - N * G, static_cast(epsilon_), rsig, rsig, &context_); - ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data); - GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y); - } + const std::array X_dims = {N * G, K * HxW}; + const std::array Y_dims = {N * G, 1}; + math::Moments( + 2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_); + math::InvStd( + N * G, static_cast(epsilon_), rsig, rsig, &context_); + ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data); + GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y); return true; } - void ComputeFusedParams( + bool RunOnDeviceWithOrderNHWC( const int N, const int G, - const int D, + const int K, + const int HxW, + const T* X, + const T* gamma, + const T* beta, + T* Y, + T* mu, + T* rsig) { + 
const int C = G * K; + ReinitializeTensor( + &scale_, {N, C}, at::dtype().device(Context::GetDeviceType())); + ReinitializeTensor( + &bias_, {N, C}, at::dtype().device(Context::GetDeviceType())); + T* scale_data = scale_.template mutable_data(); + T* bias_data = bias_.template mutable_data(); + const std::array X_dims = {N, HxW, G, K}; + const std::array Y_dims = {N, 1, G, 1}; + math::Moments( + 4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_); + math::InvStd( + N * G, static_cast(epsilon_), rsig, rsig, &context_); + ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data); + GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y); + return true; + } + + void ComputeFusedParams( + int N, + int G, + int K, const T* mu, const T* rsig, const T* gamma, const T* beta, T* scale, - T* bias) { - const int C = G * D; - ConstEigenArrayMap gamma_arr(gamma, D, G); - ConstEigenArrayMap beta_arr(beta, D, G); - for (int i = 0; i < N; ++i) { - EigenArrayMap scale_arr(scale + i * C, D, G); - scale_arr = gamma_arr.rowwise() * - ConstEigenVectorArrayMap(rsig + i * G, G).transpose(); - EigenArrayMap(bias + i * C, D, G) = beta_arr - - scale_arr.rowwise() * - ConstEigenVectorArrayMap(mu + i * G, G).transpose(); - } - } + T* bias); void GroupNormForwardNCHW( const int N, @@ -149,13 +167,7 @@ class GroupNormOp final : public Operator { const T* X, const T* scale, const T* bias, - T* Y) { - EigenArrayMap(Y, HxW, N * C) = - (ConstEigenArrayMap(X, HxW, N * C).rowwise() * - ConstEigenVectorArrayMap(scale, N * C).transpose()) - .rowwise() + - ConstEigenVectorArrayMap(bias, N * C).transpose(); - } + T* Y); void GroupNormForwardNHWC( const int N, @@ -164,16 +176,7 @@ class GroupNormOp final : public Operator { const T* X, const T* scale, const T* bias, - T* Y) { - const int stride = HxW * C; - for (int i = 0; i < N; ++i) { - EigenArrayMap(Y + i * stride, C, HxW) = - (ConstEigenArrayMap(X + i * stride, C, HxW).colwise() * - ConstEigenVectorArrayMap(scale + i * C, C)) - .colwise() + - ConstEigenVectorArrayMap(bias + i * C, C); - } - } + T* Y); const int group_; const float epsilon_; @@ -223,32 +226,61 @@ class GroupNormGradientOp final : public Operator { CAFFE_ENFORCE_EQ(gamma.numel(), C); CAFFE_ENFORCE_EQ(beta.numel(), C); const int G = group_; - const int D = C / G; - + const int K = C / G; auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype()); auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype()); auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype()); - return RunOnDeviceImpl( - N, - G, - D, - HxW, - dY.template data(), - X.template data(), - mu.template data(), - rsig.template data(), - gamma.template data(), - dX->template mutable_data(), - dgamma->template mutable_data(), - dbeta->template mutable_data()); + if (order_ == StorageOrder::NCHW) { + return RunOnDeviceWithOrderNCHW( + N, + G, + K, + HxW, + dY.template data(), + X.template data(), + mu.template data(), + rsig.template data(), + gamma.template data(), + dX->template mutable_data(), + dgamma->template mutable_data(), + dbeta->template mutable_data()); + } else { + return RunOnDeviceWithOrderNHWC( + N, + G, + K, + HxW, + dY.template data(), + X.template data(), + mu.template data(), + rsig.template data(), + gamma.template data(), + dX->template mutable_data(), + dgamma->template mutable_data(), + dbeta->template mutable_data()); + } } protected: - bool RunOnDeviceImpl( - const int N, - const int G, - const int D, - const int HxW, + bool RunOnDeviceWithOrderNCHW( + int N, + int G, + int K, + int HxW, + const T* 
dY_data, + const T* X_data, + const T* mu_data, + const T* rsig_data, + const T* gamma_data, + T* dX_data, + T* dgamma_data, + T* dbeta_data); + + bool RunOnDeviceWithOrderNHWC( + int N, + int G, + int K, + int HxW, const T* dY_data, const T* X_data, const T* mu_data, @@ -263,6 +295,10 @@ class GroupNormGradientOp final : public Operator { Tensor ds_; Tensor db_; + Tensor dY_scale_; + Tensor X_scale_; + Tensor bias_; + Tensor ones_; // Input: dY, X, gamma, beta, mu, inv_sig // Output: dX, dgamma, dbeta
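
For readers who want the math without the Eigen/CUDA plumbing, below is a minimal, self-contained C++ sketch of the fused-parameter forward pass (NCHW) that this patch builds around: compute mu and rsig per (n, g), fold gamma/beta into per-channel scale and bias, then apply a single multiply-add per element. The function name echoes the helper in the diff, but the signature and std::vector interface are illustrative only, not the Caffe2 API; it is a reference sketch under the assumption that X is laid out as N x (G*K) x HxW.

// Standalone sketch of the fused-parameter GroupNorm forward (NCHW):
//   scale = gamma * rsig, bias = beta - scale * mu, then Y = scale * X + bias,
// mirroring ComputeFusedParams + GroupNormForwardNCHW in the patch.
// Plain C++ illustration only; not the Caffe2 operator interface.
#include <cmath>
#include <cstdio>
#include <vector>

void GroupNormForwardNCHW(
    int N, int G, int K, int HxW, float epsilon,
    const std::vector<float>& X,      // size N * G * K * HxW
    const std::vector<float>& gamma,  // size G * K
    const std::vector<float>& beta,   // size G * K
    std::vector<float>& Y) {          // size N * G * K * HxW
  const int C = G * K;
  std::vector<float> scale(N * C), bias(N * C);
  for (int n = 0; n < N; ++n) {
    for (int g = 0; g < G; ++g) {
      // Per-(n, g) moments over the K * HxW elements of the group.
      double sum = 0.0, sumsq = 0.0;
      for (int k = 0; k < K; ++k) {
        const float* x = X.data() + (n * C + g * K + k) * HxW;
        for (int i = 0; i < HxW; ++i) {
          sum += x[i];
          sumsq += static_cast<double>(x[i]) * x[i];
        }
      }
      const double inv_cnt = 1.0 / (static_cast<double>(K) * HxW);
      const float mu = static_cast<float>(sum * inv_cnt);
      const float var = static_cast<float>(sumsq * inv_cnt) - mu * mu;
      const float rsig = 1.0f / std::sqrt(var + epsilon);
      // Fused per-channel affine parameters: Y = scale * X + bias.
      for (int k = 0; k < K; ++k) {
        const int c = g * K + k;
        scale[n * C + c] = gamma[c] * rsig;
        bias[n * C + c] = beta[c] - scale[n * C + c] * mu;
      }
    }
  }
  for (int n = 0; n < N; ++n) {
    for (int c = 0; c < C; ++c) {
      const float s = scale[n * C + c];
      const float b = bias[n * C + c];
      const float* x = X.data() + (n * C + c) * HxW;
      float* y = Y.data() + (n * C + c) * HxW;
      for (int i = 0; i < HxW; ++i) {
        y[i] = s * x[i] + b;  // one multiply-add per element, as in the kernels above
      }
    }
  }
}

int main() {
  const int N = 2, G = 2, K = 3, HxW = 4;
  std::vector<float> X(N * G * K * HxW), Y(X.size());
  std::vector<float> gamma(G * K, 1.0f), beta(G * K, 0.0f);
  for (size_t i = 0; i < X.size(); ++i) X[i] = 0.1f * static_cast<float>(i % 7);
  GroupNormForwardNCHW(N, G, K, HxW, 1e-5f, X, gamma, beta, Y);
  std::printf("Y[0] = %f\n", Y[0]);
  return 0;
}

The same precomputation is what makes the optimized backward pass cheap: once ds and db are reduced per (n, c), the patch folds them into dY_scale, X_scale, and bias so that dX is again a single fused multiply-add per element, dX = dY_scale * dY + X_scale * X + bias.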