Optimize group_norm_op (#17945)
authorXiaomeng Yang <yangxm@fb.com>
Thu, 21 Mar 2019 19:56:20 +0000 (12:56 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Mar 2019 20:05:01 +0000 (13:05 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17945

Optimize group_norm_op

Reviewed By: houseroad

Differential Revision: D14419908

fbshipit-source-id: 4024b5c5dbeff97f4f026d61fc44af1f0e98ed68

caffe2/operators/group_norm_op.cc
caffe2/operators/group_norm_op.cu
caffe2/operators/group_norm_op.h

index ad37aa3..80c0152 100644 (file)
 
 #include "caffe2/operators/group_norm_op.h"
 
-namespace caffe2 {
-
-namespace {
+#include "caffe2/utils/eigen_utils.h"
 
-template <typename T, StorageOrder kOrder>
-void ComputeInternalGradients(
-    const std::array<int, 4>& dims,
-    const T* dY,
-    const T* X,
-    const T* gamma,
-    T* ds,
-    T* db) {
-  constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
-  constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
-  const int size = dims[0] * dims[1] * dims[2] * dims[3];
-  std::array<int, 4> index = {0, 0, 0, 0};
-  for (int i = 0; i < size; ++i) {
-    const int i_mu = index[0] * dims[kGDim] + index[kGDim];
-    const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim];
-    ds[i_mu] += gamma[i_gamma] * dY[i] * X[i];
-    db[i_mu] += gamma[i_gamma] * dY[i];
-    math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
-  }
-}
+namespace caffe2 {
 
 // Math:
 // Y = gamma * (X - mu) * rsig + beta
 // let s = gamma * rsig
 // let b = beta - mu * rsig
 // Y = s * X + b
-// let n = D * HxW
+// let n = K * HxW
 // dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX)
 // d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX
 // db/dX = -u * drsig/dX - rsig * dmu/dX
 // drsig/dX = -rsig^3 * (X - mu) / n
 // dmu/dX = 1 / n
+
+namespace {
+
 template <typename T, StorageOrder kOrder>
-void GroupNormBackward(
-    const std::array<int, 4>& dims,
+void ComputeInternalGradients(
+    int N,
+    int C,
+    int HxW,
     const T* dY,
     const T* X,
+    T* ds,
+    T* db);
+
+template <>
+void ComputeInternalGradients<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* dY,
+    const float* X,
+    float* ds,
+    float* db) {
+  ConstEigenArrayMap<float> dY_arr(dY, HxW, N * C);
+  ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
+  for (int i = 0; i < N * C; ++i) {
+    ds[i] = (dY_arr.col(i) * X_arr.col(i)).sum();
+    db[i] = dY_arr.col(i).sum();
+  }
+}
+
+template <>
+void ComputeInternalGradients<float, StorageOrder::NHWC>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* dY,
+    const float* X,
+    float* ds,
+    float* db) {
+  EigenArrayMap<float> ds_arr(ds, C, N);
+  EigenArrayMap<float> db_arr(db, C, N);
+  for (int i = 0; i < N; ++i) {
+    ConstEigenArrayMap<float> dY_arr(dY + i * C * HxW, C, HxW);
+    ConstEigenArrayMap<float> X_arr(X + i * C * HxW, C, HxW);
+    ds_arr.col(i) = dY_arr.col(0) * X_arr.col(0);
+    db_arr.col(i) = dY_arr.col(0);
+    for (int j = 1; j < HxW; ++j) {
+      ds_arr.col(i) += dY_arr.col(j) * X_arr.col(j);
+      db_arr.col(i) += dY_arr.col(j);
+    }
+  }
+}
+
+template <typename T>
+void ComputeGradientFusedParams(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const T* ds,
+    const T* db,
     const T* mu,
     const T* rsig,
     const T* gamma,
+    T* dY_scale,
+    T* X_scale,
+    T* bias) {
+  ConstEigenArrayMap<T> rsig_arr(rsig, G, N);
+  ConstEigenArrayMap<T> gamma_arr(gamma, K, G);
+  for (int i = 0; i < N; ++i) {
+    EigenArrayMap<T>(dY_scale + i * G * K, K, G) =
+        gamma_arr.rowwise() * (rsig_arr.col(i).transpose());
+  }
+  ConstEigenVectorArrayMap<T> mu_arr(mu, N * G);
+  ConstEigenVectorArrayMap<T> rsig_vec(rsig, N * G);
+  EigenVectorArrayMap<T> X_scale_arr(X_scale, N * G);
+  EigenVectorArrayMap<T> bias_arr(bias, N * G);
+  for (int i = 0; i < N; ++i) {
+    ConstEigenArrayMap<T> ds_arr(ds + i * G * K, K, G);
+    ConstEigenArrayMap<T> db_arr(db + i * G * K, K, G);
+    for (int j = 0; j < G; ++j) {
+      X_scale_arr(i * G + j) = (ds_arr.col(j) * gamma_arr.col(j)).sum();
+      bias_arr(i * G + j) = (db_arr.col(j) * gamma_arr.col(j)).sum();
+    }
+  }
+  const T alpha = T(1) / static_cast<T>(K * HxW);
+  X_scale_arr = (bias_arr * mu_arr - X_scale_arr) * rsig_vec.cube() * alpha;
+  bias_arr = -X_scale_arr * mu_arr - bias_arr * rsig_vec * alpha;
+}
+
+template <typename T, StorageOrder kOrder>
+void GroupNormBackward(
+    int N,
+    int G,
+    int K,
+    int HxW,
+    const T* dY_scale,
+    const T* dY,
+    const T* X_scale,
+    const T* X,
+    const T* bias,
+    T* dX);
+
+template <>
+void GroupNormBackward<float, StorageOrder::NCHW>(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const float* dY_scale,
+    const float* dY,
+    const float* X_scale,
+    const float* X,
+    const float* bias,
+    float* dX) {
+  const int C = G * K;
+  ConstEigenArrayMap<float> dY_arr(dY, HxW, N * C);
+  ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
+  EigenArrayMap<float> dX_arr(dX, HxW, N * C);
+  for (int i = 0; i < N * G; ++i) {
+    for (int j = 0; j < K; ++j) {
+      const int c = i * K + j;
+      dX_arr.col(c) =
+          dY_arr.col(c) * dY_scale[c] + X_arr.col(c) * X_scale[i] + bias[i];
+    }
+  }
+}
+
+template <>
+void GroupNormBackward<float, StorageOrder::NHWC>(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const float* dY_scale,
+    const float* dY,
+    const float* X_scale,
+    const float* X,
+    const float* bias,
+    float* dX) {
+  const int C = G * K;
+  ConstEigenArrayMap<float> X_scale_arr(X_scale, G, N);
+  ConstEigenArrayMap<float> bias_arr(bias, G, N);
+  for (int n = 0; n < N; ++n) {
+    ConstEigenArrayMap<float> dY_scale_arr(dY_scale + n * C, K, G);
+    for (int i = 0; i < HxW; ++i) {
+      const int m = n * HxW + i;
+      ConstEigenArrayMap<float> dY_arr(dY + m * C, K, G);
+      ConstEigenArrayMap<float> X_arr(X + m * C, K, G);
+      EigenArrayMap<float> dX_arr(dX + m * C, K, G);
+      dX_arr = (dY_arr * dY_scale_arr +
+                X_arr.rowwise() * X_scale_arr.col(n).transpose())
+                   .rowwise() +
+          bias_arr.col(n).transpose();
+    }
+  }
+}
+
+template <typename T>
+void GammaBetaBackward(
+    const int N,
+    const int G,
+    const int K,
     const T* ds,
     const T* db,
-    T* dX,
+    const T* mu,
+    const T* rsig,
     T* dgamma,
     T* dbeta) {
-  constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
-  constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
-  const int size = dims[0] * dims[1] * dims[2] * dims[3];
-  const int HxW = kOrder == StorageOrder::NCHW ? dims[3] : dims[1];
-  const T denom = T(1) / static_cast<T>(dims[kDDim] * HxW);
-  std::array<int, 4> index = {0, 0, 0, 0};
-  for (int i = 0; i < size; ++i) {
-    const int i_mu = index[0] * dims[kGDim] + index[kGDim];
-    const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim];
-    const T u = (db[i_mu] * mu[i_mu] - ds[i_mu]) * (X[i] - mu[i_mu]) *
-        math::utils::Cube(rsig[i_mu]);
-    const T v = db[i_mu] * rsig[i_mu];
-    dX[i] = gamma[i_gamma] * dY[i] * rsig[i_mu] + (u - v) * denom;
-    dgamma[i_gamma] += dY[i] * (X[i] - mu[i_mu]) * rsig[i_mu];
-    dbeta[i_gamma] += dY[i];
-    math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
+  const int C = G * K;
+  ConstEigenArrayMap<T> ds0_arr(ds, K, G);
+  ConstEigenArrayMap<T> db0_arr(db, K, G);
+  ConstEigenArrayMap<T> mu_arr(mu, G, N);
+  ConstEigenArrayMap<T> rsig_arr(rsig, G, N);
+  EigenArrayMap<T> dgamma_arr(dgamma, K, G);
+  EigenArrayMap<T> dbeta_arr(dbeta, K, G);
+  dgamma_arr =
+      (ds0_arr - db0_arr.rowwise() * mu_arr.col(0).transpose()).rowwise() *
+      rsig_arr.col(0).transpose();
+  dbeta_arr = db0_arr;
+  for (int i = 1; i < N; ++i) {
+    ConstEigenArrayMap<T> dsi_arr(ds + i * C, K, G);
+    ConstEigenArrayMap<T> dbi_arr(db + i * C, K, G);
+    dgamma_arr +=
+        (dsi_arr - dbi_arr.rowwise() * mu_arr.col(i).transpose()).rowwise() *
+        rsig_arr.col(i).transpose();
+    dbeta_arr += dbi_arr;
   }
 }
 
 } // namespace
 
+template <>
+void GroupNormOp<float, CPUContext>::ComputeFusedParams(
+    const int N,
+    const int G,
+    const int K,
+    const float* mu,
+    const float* rsig,
+    const float* gamma,
+    const float* beta,
+    float* scale,
+    float* bias) {
+  const int C = G * K;
+  ConstEigenArrayMap<float> mu_arr(mu, G, N);
+  ConstEigenArrayMap<float> rsig_arr(rsig, G, N);
+  ConstEigenArrayMap<float> gamma_arr(gamma, K, G);
+  ConstEigenArrayMap<float> beta_arr(beta, K, G);
+  for (int i = 0; i < N; ++i) {
+    EigenArrayMap<float> scale_arr(scale + i * C, K, G);
+    EigenArrayMap<float> bias_arr(bias + i * C, K, G);
+    scale_arr = gamma_arr.rowwise() * rsig_arr.col(i).transpose();
+    bias_arr = beta_arr - scale_arr.rowwise() * mu_arr.col(i).transpose();
+  }
+}
+
+template <>
+void GroupNormOp<float, CPUContext>::GroupNormForwardNCHW(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    const float* scale,
+    const float* bias,
+    float* Y) {
+  EigenArrayMap<float>(Y, HxW, N * C) =
+      (ConstEigenArrayMap<float>(X, HxW, N * C).rowwise() *
+       ConstEigenVectorArrayMap<float>(scale, N * C).transpose())
+          .rowwise() +
+      ConstEigenVectorArrayMap<float>(bias, N * C).transpose();
+}
+
+template <>
+void GroupNormOp<float, CPUContext>::GroupNormForwardNHWC(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    const float* scale,
+    const float* bias,
+    float* Y) {
+  const int stride = HxW * C;
+  for (int i = 0; i < N; ++i) {
+    EigenArrayMap<float>(Y + i * stride, C, HxW) =
+        (ConstEigenArrayMap<float>(X + i * stride, C, HxW).colwise() *
+         ConstEigenVectorArrayMap<float>(scale + i * C, C))
+            .colwise() +
+        ConstEigenVectorArrayMap<float>(bias + i * C, C);
+  }
+}
+
+template <>
+bool GroupNormOp<float, CPUContext>::RunOnDeviceWithOrderNHWC(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const float* X,
+    const float* gamma,
+    const float* beta,
+    float* Y,
+    float* mu,
+    float* rsig) {
+  const int C = G * K;
+  ReinitializeTensor(&scale_, {N, C}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&bias_, {N, C}, at::dtype<float>().device(CPU));
+  float* scale_data = scale_.mutable_data<float>();
+  float* bias_data = bias_.mutable_data<float>();
+  EigenVectorArrayMap<float> mu_arr(mu, N * G);
+  EigenVectorArrayMap<float> rsig_arr(rsig, N * G);
+  mu_arr.setZero();
+  rsig_arr.setZero();
+  for (int n = 0; n < N; ++n) {
+    for (int i = 0; i < HxW; ++i) {
+      const int m = n * HxW + i;
+      ConstEigenArrayMap<float> X_arr(X + m * C, K, G);
+      for (int j = 0; j < G; ++j) {
+        mu_arr(n * G + j) += X_arr.col(j).sum();
+        rsig_arr(n * G + j) += X_arr.col(j).square().sum();
+      }
+    }
+  }
+  const float scale = 1.0f / static_cast<float>(K * HxW);
+  mu_arr *= scale;
+  rsig_arr = (rsig_arr * scale - mu_arr.square() + epsilon_).rsqrt();
+  ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
+  GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
+  return true;
+}
+
 // Math:
 // let: s = gamma * rsig
 // let: b = beta - mu * gamma * rsig
 // then: Y = s * X + b
+template <>
+bool GroupNormGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const float* dY_data,
+    const float* X_data,
+    const float* mu_data,
+    const float* rsig_data,
+    const float* gamma_data,
+    float* dX_data,
+    float* dgamma_data,
+    float* dbeta_data) {
+  const int C = G * K;
+  ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CPU));
+  float* ds_data = ds_.mutable_data<float>();
+  float* db_data = db_.mutable_data<float>();
+  float* dY_scale_data = dY_scale_.mutable_data<float>();
+  float* X_scale_data = X_scale_.mutable_data<float>();
+  float* bias_data = bias_.mutable_data<float>();
+  ComputeInternalGradients<float, StorageOrder::NCHW>(
+      N, C, HxW, dY_data, X_data, ds_data, db_data);
+  ComputeGradientFusedParams<float>(
+      N,
+      G,
+      K,
+      HxW,
+      ds_data,
+      db_data,
+      mu_data,
+      rsig_data,
+      gamma_data,
+      dY_scale_data,
+      X_scale_data,
+      bias_data);
+  GroupNormBackward<float, StorageOrder::NCHW>(
+      N,
+      G,
+      K,
+      HxW,
+      dY_scale_data,
+      dY_data,
+      X_scale_data,
+      X_data,
+      bias_data,
+      dX_data);
+  GammaBetaBackward<float>(
+      N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
+  return true;
+}
+
 template <typename T, class Context>
-bool GroupNormGradientOp<T, Context>::RunOnDeviceImpl(
+bool GroupNormGradientOp<T, Context>::RunOnDeviceWithOrderNHWC(
     const int N,
     const int G,
-    const int D,
+    const int K,
     const int HxW,
     const T* dY_data,
     const T* X_data,
@@ -96,60 +387,45 @@ bool GroupNormGradientOp<T, Context>::RunOnDeviceImpl(
     T* dX_data,
     T* dgamma_data,
     T* dbeta_data) {
-  const std::array<int, 4> dims = order_ == StorageOrder::NCHW
-      ? std::array<int, 4>{N, G, D, HxW}
-      : std::array<int, 4>{N, HxW, G, D};
-
-  // Computes dL/ds and dL/db.
-  // dL/ds = Sum(dL/dY * gamma * X)
-  // dL/db = Sum(dL/dY * gamma)
-  const int C = G * D;
-  ReinitializeTensor(
-      &ds_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
-  ReinitializeTensor(
-      &db_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
-  T* ds_data = ds_.template mutable_data<T>();
-  T* db_data = db_.template mutable_data<T>();
-  math::Set<T, Context>(N * G, T(0), ds_data, &context_);
-  math::Set<T, Context>(N * G, T(0), db_data, &context_);
-  if (order_ == StorageOrder::NCHW) {
-    ComputeInternalGradients<T, StorageOrder::NCHW>(
-        dims, dY_data, X_data, gamma_data, ds_data, db_data);
-  } else {
-    ComputeInternalGradients<T, StorageOrder::NHWC>(
-        dims, dY_data, X_data, gamma_data, ds_data, db_data);
-  }
-
-  // Computes dL/dX, dL/dgamma and dL/dbeta.
-  math::Set<T, Context>(C, T(0), dgamma_data, &context_);
-  math::Set<T, Context>(C, T(0), dbeta_data, &context_);
-  if (order_ == StorageOrder::NCHW) {
-    GroupNormBackward<T, StorageOrder::NCHW>(
-        dims,
-        dY_data,
-        X_data,
-        mu_data,
-        rsig_data,
-        gamma_data,
-        ds_data,
-        db_data,
-        dX_data,
-        dgamma_data,
-        dbeta_data);
-  } else {
-    GroupNormBackward<T, StorageOrder::NHWC>(
-        dims,
-        dY_data,
-        X_data,
-        mu_data,
-        rsig_data,
-        gamma_data,
-        ds_data,
-        db_data,
-        dX_data,
-        dgamma_data,
-        dbeta_data);
-  }
+  const int C = G * K;
+  ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CPU));
+  ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CPU));
+  float* ds_data = ds_.mutable_data<float>();
+  float* db_data = db_.mutable_data<float>();
+  float* dY_scale_data = dY_scale_.mutable_data<float>();
+  float* X_scale_data = X_scale_.mutable_data<float>();
+  float* bias_data = bias_.mutable_data<float>();
+  ComputeInternalGradients<float, StorageOrder::NHWC>(
+      N, C, HxW, dY_data, X_data, ds_data, db_data);
+  ComputeGradientFusedParams<float>(
+      N,
+      G,
+      K,
+      HxW,
+      ds_data,
+      db_data,
+      mu_data,
+      rsig_data,
+      gamma_data,
+      dY_scale_data,
+      X_scale_data,
+      bias_data);
+  GroupNormBackward<float, StorageOrder::NHWC>(
+      N,
+      G,
+      K,
+      HxW,
+      dY_scale_data,
+      dY_data,
+      X_scale_data,
+      X_data,
+      bias_data,
+      dX_data);
+  GammaBetaBackward<float>(
+      N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
   return true;
 }
 
@@ -201,17 +477,21 @@ Group Normalization (GN) operation: https://arxiv.org/abs/1803.08494
 // Input: dY, X, gamma, beta, mu, sig; Output: dX, dgamma, dbeta
 OPERATOR_SCHEMA(GroupNormGradient).NumInputs(6).NumOutputs(3);
 
+namespace {
+
 class GetGroupNormGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
-  vector<OperatorDef> GetGradientDefs() override {
+  std::vector<OperatorDef> GetGradientDefs() override {
     return SingleGradientDef(
         "GroupNormGradient",
         "",
-        vector<string>{GO(0), I(0), I(1), I(2), O(1), O(2)},
-        vector<string>{GI(0), GI(1), GI(2)});
+        std::vector<std::string>{GO(0), I(0), I(1), I(2), O(1), O(2)},
+        std::vector<std::string>{GI(0), GI(1), GI(2)});
   }
 };
 
+} // namespace
+
 REGISTER_GRADIENT(GroupNorm, GetGroupNormGradient);
 
 } // namespace caffe2
index 92d56ed..3e1dcf5 100644 (file)
@@ -8,9 +8,6 @@
 
 #include "caffe2/operators/group_norm_op.h"
 
-#include <cub/block/block_reduce.cuh>
-#include <cub/cub.cuh>
-
 #include "caffe2/core/context_gpu.h"
 #include "caffe2/utils/math.h"
 #include "caffe2/utils/math/reduce.cuh"
@@ -21,6 +18,7 @@ namespace {
 
 template <typename T>
 __global__ void ComputeFusedParamsCUDAKernel(
+    const int N,
     const int G,
     const int K,
     const T* mu,
@@ -28,28 +26,40 @@ __global__ void ComputeFusedParamsCUDAKernel(
     const T* gamma,
     const T* beta,
     T* scale,
-    T* bias) {
-  const int n = blockIdx.x;
-  const int g = blockIdx.y;
-  const int i_mu = n * G + g;
-  for (int i = threadIdx.x; i < K; i += blockDim.x) {
-    const int index = i_mu * K + i;
-    const int i_gamma = g * K + i;
-#if __CUDA_ARCH__ >= 350
-    const T scale_val = __ldg(gamma + i_gamma) * __ldg(rsig + i_mu);
+    T* bias);
+
+template <>
+__global__ void ComputeFusedParamsCUDAKernel<float>(
+    const int N,
+    const int G,
+    const int K,
+    const float* mu,
+    const float* rsig,
+    const float* gamma,
+    const float* beta,
+    float* scale,
+    float* bias) {
+  const int C = G * K;
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N * C) {
+    const int ng = index / K;
+    const int c = index % C;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    const float scale_val = __ldg(gamma + c) * __ldg(rsig + ng);
     scale[index] = scale_val;
-    bias[index] = __ldg(beta + i_gamma) - scale_val * __ldg(mu + i_mu);
+    bias[index] = fmaf(-scale_val, __ldg(mu + ng), __ldg(beta + c));
 #else
-    const T scale_val = gamma[i_gamma] * rsig[i_mu];
+    const float scale_val = gamma[c] * rsig[ng];
     scale[index] = scale_val;
-    bias[index] = beta[i_gamma] - scale_val * mu[i_mu];
+    bias[index] = fmaf(-scale_val, mu[ng], beta[c]);
 #endif
   }
 }
 
-template <typename T>
-__global__ void GroupNormForwardNCHWCUDAKernel(
-    const int M,
+template <typename T, StorageOrder kOrder>
+__global__ void GroupNormForwardCUDAKernel(
+    const int N,
+    const int C,
     const int HxW,
     const T* X,
     const T* scale,
@@ -57,18 +67,18 @@ __global__ void GroupNormForwardNCHWCUDAKernel(
     T* Y);
 
 template <>
-__global__ void GroupNormForwardNCHWCUDAKernel<float>(
-    const int M,
+__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>(
+    const int N,
+    const int C,
     const int HxW,
     const float* X,
     const float* scale,
     const float* bias,
     float* Y) {
-  const int nc = blockIdx.x / M;
-  const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (hw < HxW) {
-    const int index = nc * HxW + hw;
-#if __CUDA_ARCH__ >= 350
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N * C * HxW) {
+    const int nc = index / HxW;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
     Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
 #else
     Y[index] = fmaf(X[index], scale[nc], bias[nc]);
@@ -76,29 +86,19 @@ __global__ void GroupNormForwardNCHWCUDAKernel<float>(
   }
 }
 
-template <typename T>
-__global__ void GroupNormForwardNHWCCUDAKernel(
-    const int C,
-    const int HxW,
-    const T* X,
-    const T* scale,
-    const T* bias,
-    T* Y);
-
 template <>
-__global__ void GroupNormForwardNHWCCUDAKernel<float>(
+__global__ void GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>(
+    const int N,
     const int C,
     const int HxW,
     const float* X,
     const float* scale,
     const float* bias,
     float* Y) {
-  const int n = blockIdx.x / HxW;
-  const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (c < C) {
-    const int index = blockIdx.x * C + c;
-    const int nc = n * C + c;
-#if __CUDA_ARCH__ >= 350
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N * C * HxW) {
+    const int nc = index / (HxW * C) * C + index % C;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
     Y[index] = fmaf(__ldg(X + index), __ldg(scale + nc), __ldg(bias + nc));
 #else
     Y[index] = fmaf(X[index], scale[nc], bias[nc]);
@@ -106,84 +106,33 @@ __global__ void GroupNormForwardNHWCCUDAKernel<float>(
   }
 }
 
-template <typename T, int kBlockDimX, int kBlockDimY>
+template <typename T>
 __global__ void ComputeInternalGradientsNCHWCUDAKernel(
-    const int G,
-    const int K,
     const int HxW,
     const T* dY,
     const T* X,
-    const T* gamma,
     T* ds,
     T* db) {
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage ds_storage;
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
-  const int n = blockIdx.x;
-  const int g = blockIdx.y;
-  const int ng = n * G + g;
-  T ds_val = 0;
-  T db_val = 0;
-  for (int i = threadIdx.x; i < K; i += blockDim.x) {
-    const int c = g * K + i;
-    for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
-      const int index = (ng * K + i) * HxW + j;
-#if __CUDA_ARCH__ >= 350
-      ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
-      db_val += __ldg(gamma + c) * __ldg(dY + index);
-#else
-      ds_val += gamma[c] * dY[index] * X[index];
-      db_val += gamma[c] * dY[index];
-#endif
-    }
-  }
-  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(ds_storage).Sum(ds_val);
-  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
-  if (threadIdx.x == 0 && threadIdx.y == 0) {
-    ds[ng] = ds_val;
-    db[ng] = db_val;
-  }
-}
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-__global__ void ComputeInternalGradientsNHWCCUDAKernel(
-    const int G,
-    const int K,
-    const int HxW,
-    const T* dY,
-    const T* X,
-    const T* gamma,
-    T* ds,
-    T* db) {
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage ds_storage;
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
-  const int C = G * K;
-  const int n = blockIdx.x;
-  const int g = blockIdx.y;
-  const int ng = n * G + g;
-  T ds_val = 0;
-  T db_val = 0;
+  __shared__ typename BlockReduce<T>::TempStorage ds_storage;
+  __shared__ typename BlockReduce<T>::TempStorage db_storage;
+  const int nc = blockIdx.x;
+  T ds_sum = 0;
+  T db_sum = 0;
   for (int i = threadIdx.x; i < HxW; i += blockDim.x) {
-    for (int j = threadIdx.y; j < K; j += blockDim.y) {
-      const int c = g * K + j;
-      const int index = (n * HxW + i) * C + c;
-#if __CUDA_ARCH__ >= 350
-      ds_val += __ldg(gamma + c) * __ldg(dY + index) * __ldg(X + index);
-      db_val += __ldg(gamma + c) * __ldg(dY + index);
+    const int index = nc * HxW + i;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    ds_sum += __ldg(dY + index) * __ldg(X + index);
+    db_sum += __ldg(dY + index);
 #else
-      ds_val += gamma[c] * dY[index] * X[index];
-      db_val += gamma[c] * dY[index];
+    ds_sum += dY[index] * X[index];
+    db_sum += dY[index];
 #endif
-    }
   }
-  ds_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(ds_storage).Sum(ds_val);
-  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
-  if (threadIdx.x == 0 && threadIdx.y == 0) {
-    ds[ng] = ds_val;
-    db[ng] = db_val;
+  ds_sum = BlockReduce<T>(ds_storage).Sum(ds_sum);
+  db_sum = BlockReduce<T>(db_storage).Sum(db_sum);
+  if (threadIdx.x == 0) {
+    ds[nc] = ds_sum;
+    db[nc] = db_sum;
   }
 }
 
@@ -192,174 +141,212 @@ __global__ void ComputeInternalGradientsNHWCCUDAKernel(
 // let s = gamma * rsig
 // let b = beta - mu * rsig
 // Y = s * X + b
-// let n = D * HxW
+// let n = K * HxW
 // dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX)
 // d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX
 // db/dX = -u * drsig/dX - rsig * dmu/dX
 // drsig/dX = -rsig^3 * (X - mu) / n
 // dmu/dX = 1 / n
 template <typename T>
-__global__ void GroupNormBackwardNCHWCUDAKernel(
+__global__ void ComputeYGradientScaleCUDAKernel(
+    const int N,
     const int G,
     const int K,
-    const int M,
-    const int HxW,
-    const T* dY,
-    const T* X,
-    const T* mu,
     const T* rsig,
     const T* gamma,
-    const T* ds,
-    const T* db,
-    T* dX) {
+    T* dY_scale) {
   const int C = G * K;
-  const T denom = T(1) / static_cast<T>(K * HxW);
-  const int nc = blockIdx.x / M;
-  const int n = nc / C;
-  const int c = nc % C;
-  const int g = c / K;
-  const int ng = n * G + g;
-  const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  const int index = nc * HxW + hw;
-  if (hw < HxW) {
-#if __CUDA_ARCH__ >= 350
-    const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) *
-        (__ldg(X + index) - __ldg(mu + ng)) *
-        math::utils::Cube<T>(__ldg(rsig + ng));
-    const T v = __ldg(db + ng) * __ldg(rsig + ng);
-    dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) +
-        (u - v) * denom;
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N * C) {
+    const int ng = index / K;
+    const int c = index % C;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    dY_scale[index] = __ldg(gamma + c) * __ldg(rsig + ng);
 #else
-    const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) *
-        math::utils::Cube<T>(rsig[ng]);
-    const T v = db[ng] * rsig[ng];
-    dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom;
+    dY_scale[index] = gamma[c] * rsig[ng];
 #endif
   }
 }
 
 template <typename T>
-__global__ void GroupNormBackwardNHWCCUDAKernel(
+__global__ void ComputeXScaleAndBiasCUDAKernel(
     const int G,
     const int K,
-    const int HxW,
-    const T* dY,
-    const T* X,
+    const T alpha,
+    const T* ds,
+    const T* db,
     const T* mu,
     const T* rsig,
     const T* gamma,
-    const T* ds,
-    const T* db,
-    T* dX) {
-  const int C = G * K;
-  const T denom = T(1) / static_cast<T>(K * HxW);
-  const int x = blockIdx.x;
+    T* X_scale,
+    T* bias);
+
+template <>
+__global__ void ComputeXScaleAndBiasCUDAKernel<float>(
+    const int G,
+    const int K,
+    const float alpha,
+    const float* ds,
+    const float* db,
+    const float* mu,
+    const float* rsig,
+    const float* gamma,
+    float* X_scale,
+    float* bias) {
+  __shared__ typename BlockReduce<float>::TempStorage ds_storage;
+  __shared__ typename BlockReduce<float>::TempStorage db_storage;
+  const int n = blockIdx.x;
   const int g = blockIdx.y;
-  const int n = x / HxW;
   const int ng = n * G + g;
-  const int i = blockIdx.z * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
-  if (i < K) {
+  float ds_sum = 0;
+  float db_sum = 0;
+  for (int i = threadIdx.x; i < K; i += blockDim.x) {
+    const int index = ng * K + i;
     const int c = g * K + i;
-    const int index = x * C + c;
-#if __CUDA_ARCH__ >= 350
-    const T u = (__ldg(db + ng) * __ldg(mu + ng) - __ldg(ds + ng)) *
-        (__ldg(X + index) - __ldg(mu + ng)) *
-        math::utils::Cube<T>(__ldg(rsig + ng));
-    const T v = __ldg(db + ng) * __ldg(rsig + ng);
-    dX[index] = __ldg(gamma + c) * __ldg(dY + index) * __ldg(rsig + ng) +
-        (u - v) * denom;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    ds_sum += __ldg(ds + index) * __ldg(gamma + c);
+    db_sum += __ldg(db + index) * __ldg(gamma + c);
+#else
+    ds_sum += ds[index] * gamma[c];
+    db_sum += db[index] * gamma[c];
+#endif
+  }
+  ds_sum = BlockReduce<float>(ds_storage).Sum(ds_sum);
+  db_sum = BlockReduce<float>(db_storage).Sum(db_sum);
+  if (threadIdx.x == 0) {
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    const float x = fmaf(db_sum, __ldg(mu + ng), -ds_sum) *
+        math::utils::Cube<float>(__ldg(rsig + ng)) * alpha;
+    X_scale[ng] = x;
+    bias[ng] = -fmaf(x, __ldg(mu + ng), db_sum * __ldg(rsig + ng) * alpha);
 #else
-    const T u = (db[ng] * mu[ng] - ds[ng]) * (X[index] - mu[ng]) *
-        math::utils::Cube<T>(rsig[ng]);
-    const T v = db[ng] * rsig[ng];
-    dX[index] = gamma[c] * dY[index] * rsig[ng] + (u - v) * denom;
+    const float x = fmaf(db_sum, mu[ng], -ds_sum) *
+        math::utils::Cube<float>(rsig[ng]) * alpha;
+    X_scale[ng] = x;
+    bias[ng] = -fmaf(x, mu[ng], db_sum * rsig[ng] * alpha);
 #endif
   }
 }
 
-template <typename T, int kBlockDimX, int kBlockDimY>
-__global__ void GammaBetaBackwardNCHWCUDAKernel(
+template <typename T, StorageOrder kOrder>
+__global__ void GroupNormBackwardCUDAKernel(
     const int N,
     const int G,
     const int K,
     const int HxW,
+    const T* dY_scale,
     const T* dY,
+    const T* X_scale,
     const T* X,
-    const T* mu,
-    const T* rsig,
-    T* dgamma,
-    T* dbeta) {
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage dg_storage;
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
+    const T* bias,
+    T* dX);
+
+template <>
+__global__ void GroupNormBackwardCUDAKernel<float, StorageOrder::NCHW>(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const float* dY_scale,
+    const float* dY,
+    const float* X_scale,
+    const float* X,
+    const float* bias,
+    float* dX) {
   const int C = G * K;
-  const int c = blockIdx.x;
-  const int g = c / K;
-  T dg_val = 0;
-  T db_val = 0;
-  for (int i = threadIdx.x; i < N; i += blockDim.x) {
-    for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
-      const int index = (i * C + c) * HxW + j;
-      const int ng = i * G + g;
-#if __CUDA_ARCH__ >= 350
-      dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) *
-          __ldg(rsig + ng);
-      db_val += __ldg(dY + index);
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N * C * HxW) {
+    const int nc = index / HxW;
+    const int ng = nc / K;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    dX[index] = fmaf(
+        __ldg(dY_scale + nc),
+        __ldg(dY + index),
+        fmaf(__ldg(X_scale + ng), __ldg(X + index), __ldg(bias + ng)));
 #else
-      dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng];
-      db_val += dY[index];
+    dX[index] =
+        fmaf(dY_scale[nc], dY[index], fmaf(X_scale[ng], X[index], bias[ng]));
 #endif
-    }
-  }
-  dg_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(dg_storage).Sum(dg_val);
-  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
-  if (threadIdx.x == 0 && threadIdx.y == 0) {
-    dgamma[c] = dg_val;
-    dbeta[c] = db_val;
   }
 }
 
-template <typename T, int kBlockDimX, int kBlockDimY>
-__global__ void GammaBetaBackwardNHWCCUDAKernel(
+template <>
+__global__ void GroupNormBackwardCUDAKernel<float, StorageOrder::NHWC>(
     const int N,
     const int G,
     const int K,
     const int HxW,
-    const T* dY,
-    const T* X,
+    const float* dY_scale,
+    const float* dY,
+    const float* X_scale,
+    const float* X,
+    const float* bias,
+    float* dX) {
+  const int C = G * K;
+  const int index = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+  if (index < N * C * HxW) {
+    const int nc = index / (HxW * C) * C + index % C;
+    const int ng = nc / K;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    dX[index] = fmaf(
+        __ldg(dY_scale + nc),
+        __ldg(dY + index),
+        fmaf(__ldg(X_scale + ng), __ldg(X + index), __ldg(bias + ng)));
+#else
+    dX[index] =
+        fmaf(dY_scale[nc], dY[index], fmaf(X_scale[ng], X[index], bias[ng]));
+#endif
+  }
+}
+
+template <typename T>
+__global__ void GammaBetaBackwardCUDAKernel(
+    const int N,
+    const int G,
+    const int K,
+    const T* ds,
+    const T* db,
     const T* mu,
     const T* rsig,
     T* dgamma,
-    T* dbeta) {
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage dg_storage;
-  __shared__
-      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage db_storage;
+    T* dbeta);
+
+template <>
+__global__ void GammaBetaBackwardCUDAKernel<float>(
+    const int N,
+    const int G,
+    const int K,
+    const float* ds,
+    const float* db,
+    const float* mu,
+    const float* rsig,
+    float* dgamma,
+    float* dbeta) {
+  __shared__ typename BlockReduce<float>::TempStorage dg_storage;
+  __shared__ typename BlockReduce<float>::TempStorage db_storage;
   const int C = G * K;
-  const int c = blockIdx.x;
-  const int g = c / K;
-  T dg_val = 0;
-  T db_val = 0;
+  const int g = blockIdx.x;
+  const int k = blockIdx.y;
+  const int c = g * K + k;
+  float dg_sum = 0;
+  float db_sum = 0;
   for (int i = threadIdx.x; i < N; i += blockDim.x) {
-    for (int j = threadIdx.y; j < HxW; j += blockDim.y) {
-      const int index = (i * HxW + j) * C + c;
-      const int ng = i * G + g;
-#if __CUDA_ARCH__ >= 350
-      dg_val += __ldg(dY + index) * (__ldg(X + index) - __ldg(mu + ng)) *
-          __ldg(rsig + ng);
-      db_val += __ldg(dY + index);
+    const int nc = i * C + c;
+    const int ng = i * G + g;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    dg_sum += fmaf(-__ldg(db + nc), __ldg(mu + ng), __ldg(ds + nc)) *
+        __ldg(rsig + ng);
+    db_sum += __ldg(db + nc);
 #else
-      dg_val += dY[index] * (X[index] - mu[ng]) * rsig[ng];
-      db_val += dY[index];
+    dg_sum += fmaf(-db[nc], mu[ng], ds[nc]) * rsig[ng];
+    db_sum += db[nc];
 #endif
-    }
   }
-  dg_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(dg_storage).Sum(dg_val);
-  db_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(db_storage).Sum(db_val);
-  if (threadIdx.x == 0 && threadIdx.y == 0) {
-    dgamma[c] = dg_val;
-    dbeta[c] = db_val;
+  dg_sum = BlockReduce<float>(dg_storage).Sum(dg_sum);
+  db_sum = BlockReduce<float>(db_storage).Sum(db_sum);
+  if (threadIdx.x == 0) {
+    dgamma[c] = dg_sum;
+    dbeta[c] = db_sum;
   }
 }
 
@@ -376,9 +363,10 @@ void GroupNormOp<float, CUDAContext>::ComputeFusedParams(
     const float* beta,
     float* scale,
     float* bias) {
+  const int M = math::DivUp(N * G * K, CAFFE_CUDA_NUM_THREADS);
   ComputeFusedParamsCUDAKernel<float>
-      <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          G, K, mu, rsig, gamma, beta, scale, bias);
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N, G, K, mu, rsig, gamma, beta, scale, bias);
 }
 
 template <>
@@ -390,10 +378,10 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNCHW(
     const float* scale,
     const float* bias,
     float* Y) {
-  const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
-  GroupNormForwardNCHWCUDAKernel<float>
-      <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          M, HxW, X, scale, bias, Y);
+  const int M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
+  GroupNormForwardCUDAKernel<float, StorageOrder::NCHW>
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N, C, HxW, X, scale, bias, Y);
 }
 
 template <>
@@ -405,10 +393,10 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNHWC(
     const float* scale,
     const float* bias,
     float* Y) {
-  const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
-  GroupNormForwardNHWCCUDAKernel<float>
-      <<<dim3(N * HxW, M), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          C, HxW, X, scale, bias, Y);
+  const int M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
+  GroupNormForwardCUDAKernel<float, StorageOrder::NHWC>
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N, C, HxW, X, scale, bias, Y);
 }
 
 // Math:
@@ -416,7 +404,7 @@ void GroupNormOp<float, CUDAContext>::GroupNormForwardNHWC(
 // let: b = beta - mu * gamma * rsig
 // then: Y = s * X + b
 template <>
-bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
+bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNCHW(
     const int N,
     const int G,
     const int K,
@@ -430,119 +418,158 @@ bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceImpl(
     float* dgamma_data,
     float* dbeta_data) {
   const int C = G * K;
-  ReinitializeTensor(&ds_, {N, G}, at::dtype<float>().device(CUDA));
-  ReinitializeTensor(&db_, {N, G}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CUDA));
   float* ds_data = ds_.mutable_data<float>();
   float* db_data = db_.mutable_data<float>();
-  if (order_ == StorageOrder::NCHW) {
-    // Computes dL/ds and dL/db.
-    // dL/ds = Sum(dL/dY * gamma * X)
-    // dL/db = Sum(dL/dY * gamma)
-    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
-        HxW,
-        ComputeInternalGradientsNCHWCUDAKernel,
-        float,
-        dim3(N, G),
-        context_.cuda_stream(),
-        G,
-        K,
-        HxW,
-        dY_data,
-        X_data,
-        gamma_data,
-        ds_data,
-        db_data);
-
-    // Computes dL/dX.
-    const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
-    GroupNormBackwardNCHWCUDAKernel<float>
-        <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-            G,
-            K,
-            M,
-            HxW,
-            dY_data,
-            X_data,
-            mu_data,
-            rsig_data,
-            gamma_data,
-            ds_data,
-            db_data,
-            dX_data);
-
-    // Computes dL/dgamma and dL/dbeta.
-    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
-        HxW,
-        GammaBetaBackwardNCHWCUDAKernel,
-        float,
-        C,
-        context_.cuda_stream(),
-        N,
-        G,
-        K,
-        HxW,
-        dY_data,
-        X_data,
-        mu_data,
-        rsig_data,
-        dgamma_data,
-        dbeta_data);
-  } else {
-    // Computes dL/ds and dL/db.
-    // dL/ds = Sum(dL/dY * gamma * X)
-    // dL/db = Sum(dL/dY * gamma)
-    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
-        K,
-        ComputeInternalGradientsNHWCCUDAKernel,
-        float,
-        dim3(N, G),
-        context_.cuda_stream(),
-        G,
-        K,
-        HxW,
-        dY_data,
-        X_data,
-        gamma_data,
-        ds_data,
-        db_data);
-
-    // Computes dL/dX.
-    const int M = math::DivUp(K, CAFFE_CUDA_NUM_THREADS);
-    GroupNormBackwardNHWCCUDAKernel<float>
-        <<<dim3(N * HxW, G, M),
-           CAFFE_CUDA_NUM_THREADS,
-           0,
-           context_.cuda_stream()>>>(
-            G,
-            K,
-            HxW,
-            dY_data,
-            X_data,
-            mu_data,
-            rsig_data,
-            gamma_data,
-            ds_data,
-            db_data,
-            dX_data);
-
-    // Computes dL/dgamma and dL/dbeta.
-    DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
-        HxW,
-        GammaBetaBackwardNHWCCUDAKernel,
-        float,
-        C,
-        context_.cuda_stream(),
-        N,
-        G,
-        K,
-        HxW,
-        dY_data,
-        X_data,
-        mu_data,
-        rsig_data,
-        dgamma_data,
-        dbeta_data);
-  }
+  float* dY_scale_data = dY_scale_.mutable_data<float>();
+  float* X_scale_data = X_scale_.mutable_data<float>();
+  float* bias_data = bias_.mutable_data<float>();
+
+  ComputeInternalGradientsNCHWCUDAKernel<float>
+      <<<N * C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          HxW, dY_data, X_data, ds_data, db_data);
+
+  // Computes dL/dX.
+  int M = math::DivUp(N * C, CAFFE_CUDA_NUM_THREADS);
+  ComputeYGradientScaleCUDAKernel<float>
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N, G, K, rsig_data, gamma_data, dY_scale_data);
+  ComputeXScaleAndBiasCUDAKernel<float>
+      <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          G,
+          K,
+          1.0f / static_cast<float>(K * HxW),
+          ds_data,
+          db_data,
+          mu_data,
+          rsig_data,
+          gamma_data,
+          X_scale_data,
+          bias_data);
+  M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
+  GroupNormBackwardCUDAKernel<float, StorageOrder::NCHW>
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N,
+          G,
+          K,
+          HxW,
+          dY_scale_data,
+          dY_data,
+          X_scale_data,
+          X_data,
+          bias_data,
+          dX_data);
+
+  // Computes dL/dgamma and dL/dbeta.
+  GammaBetaBackwardCUDAKernel<
+      float><<<dim3(G, K), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+      N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
+  return true;
+}
+
+template <>
+bool GroupNormGradientOp<float, CUDAContext>::RunOnDeviceWithOrderNHWC(
+    const int N,
+    const int G,
+    const int K,
+    const int HxW,
+    const float* dY_data,
+    const float* X_data,
+    const float* mu_data,
+    const float* rsig_data,
+    const float* gamma_data,
+    float* dX_data,
+    float* dgamma_data,
+    float* dbeta_data) {
+  const int C = G * K;
+  ReinitializeTensor(&ds_, {N, C}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&db_, {N, C}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&dY_scale_, {N, C}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&X_scale_, {N, G}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&bias_, {N, G}, at::dtype<float>().device(CUDA));
+  ReinitializeTensor(&ones_, {HxW}, at::dtype<float>().device(CUDA));
+  float* ds_data = ds_.mutable_data<float>();
+  float* db_data = db_.mutable_data<float>();
+  float* dY_scale_data = dY_scale_.mutable_data<float>();
+  float* X_scale_data = X_scale_.mutable_data<float>();
+  float* bias_data = bias_.mutable_data<float>();
+  float* ones_data = ones_.mutable_data<float>();
+
+  math::Set<float, CUDAContext>(HxW, 1.0f, ones_data, &context_);
+  math::Mul<float, CUDAContext>(
+      N * C * HxW, dY_data, X_data, dX_data, &context_);
+  math::GemmStridedBatched<float, CUDAContext>(
+      CblasTrans,
+      CblasNoTrans,
+      N,
+      C,
+      1,
+      HxW,
+      1.0f,
+      dX_data,
+      C * HxW,
+      ones_data,
+      0,
+      0.0f,
+      ds_data,
+      C,
+      &context_);
+  math::GemmStridedBatched<float, CUDAContext>(
+      CblasTrans,
+      CblasNoTrans,
+      N,
+      C,
+      1,
+      HxW,
+      1.0f,
+      dY_data,
+      C * HxW,
+      ones_data,
+      0,
+      0.0f,
+      db_data,
+      C,
+      &context_);
+
+  // Computes dL/dX.
+  int M = math::DivUp(N * C, CAFFE_CUDA_NUM_THREADS);
+  ComputeYGradientScaleCUDAKernel<float>
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N, G, K, rsig_data, gamma_data, dY_scale_data);
+  ComputeXScaleAndBiasCUDAKernel<float>
+      <<<dim3(N, G), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          G,
+          K,
+          1.0f / static_cast<float>(K * HxW),
+          ds_data,
+          db_data,
+          mu_data,
+          rsig_data,
+          gamma_data,
+          X_scale_data,
+          bias_data);
+  M = math::DivUp(N * C * HxW, CAFFE_CUDA_NUM_THREADS);
+  GroupNormBackwardCUDAKernel<float, StorageOrder::NHWC>
+      <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N,
+          G,
+          K,
+          HxW,
+          dY_scale_data,
+          dY_data,
+          X_scale_data,
+          X_data,
+          bias_data,
+          dX_data);
+
+  // Computes dL/dgamma and dL/dbeta.
+  GammaBetaBackwardCUDAKernel<
+      float><<<dim3(G, K), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+      N, G, K, ds_data, db_data, mu_data, rsig_data, dgamma_data, dbeta_data);
   return true;
 }
 
index 826eb17..8143f73 100644 (file)
@@ -8,7 +8,6 @@
 #include "caffe2/core/common.h"
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
-#include "caffe2/utils/eigen_utils.h"
 #include "caffe2/utils/math.h"
 
 namespace caffe2 {
@@ -47,8 +46,7 @@ class GroupNormOp final : public Operator<Context> {
     CAFFE_ENFORCE_EQ(gamma.numel(), C);
     CAFFE_ENFORCE_EQ(beta.numel(), C);
     const int G = group_;
-    const int D = C / G;
-
+    const int K = C / G;
     auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
     T* mu_data = nullptr;
     T* rsig_data = nullptr;
@@ -65,24 +63,38 @@ class GroupNormOp final : public Operator<Context> {
       mu_data = mu_.template mutable_data<T>();
       rsig_data = rsig_.template mutable_data<T>();
     }
-    return RunOnDeviceImpl(
-        N,
-        G,
-        D,
-        HxW,
-        X.template data<T>(),
-        gamma.template data<T>(),
-        beta.template data<T>(),
-        Y->template mutable_data<T>(),
-        mu_data,
-        rsig_data);
+    if (order_ == StorageOrder::NCHW) {
+      return RunOnDeviceWithOrderNCHW(
+          N,
+          G,
+          K,
+          HxW,
+          X.template data<T>(),
+          gamma.template data<T>(),
+          beta.template data<T>(),
+          Y->template mutable_data<T>(),
+          mu_data,
+          rsig_data);
+    } else {
+      return RunOnDeviceWithOrderNHWC(
+          N,
+          G,
+          K,
+          HxW,
+          X.template data<T>(),
+          gamma.template data<T>(),
+          beta.template data<T>(),
+          Y->template mutable_data<T>(),
+          mu_data,
+          rsig_data);
+    }
   }
 
- protected:
-  bool RunOnDeviceImpl(
+ private:
+  bool RunOnDeviceWithOrderNCHW(
       const int N,
       const int G,
-      const int D,
+      const int K,
       const int HxW,
       const T* X,
       const T* gamma,
@@ -90,57 +102,63 @@ class GroupNormOp final : public Operator<Context> {
       T* Y,
       T* mu,
       T* rsig) {
-    const int C = G * D;
+    const int C = G * K;
     ReinitializeTensor(
         &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
     ReinitializeTensor(
         &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
     T* scale_data = scale_.template mutable_data<T>();
     T* bias_data = bias_.template mutable_data<T>();
-    if (order_ == StorageOrder::NCHW) {
-      const std::array<int, 2> X_dims = {N * G, D * HxW};
-      const std::array<int, 2> Y_dims = {N * G, 1};
-      math::Moments<T, Context>(
-          2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
-      math::InvStd<T, Context>(
-          N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
-      ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
-      GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
-    } else {
-      const std::array<int, 4> X_dims = {N, HxW, G, D};
-      const std::array<int, 4> Y_dims = {N, 1, G, 1};
-      math::Moments<T, Context>(
-          4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
-      math::InvStd<T, Context>(
-          N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
-      ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
-      GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
-    }
+    const std::array<int, 2> X_dims = {N * G, K * HxW};
+    const std::array<int, 2> Y_dims = {N * G, 1};
+    math::Moments<T, Context>(
+        2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
+    math::InvStd<T, Context>(
+        N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
+    ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
+    GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
     return true;
   }
 
-  void ComputeFusedParams(
+  bool RunOnDeviceWithOrderNHWC(
       const int N,
       const int G,
-      const int D,
+      const int K,
+      const int HxW,
+      const T* X,
+      const T* gamma,
+      const T* beta,
+      T* Y,
+      T* mu,
+      T* rsig) {
+    const int C = G * K;
+    ReinitializeTensor(
+        &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
+    ReinitializeTensor(
+        &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
+    T* scale_data = scale_.template mutable_data<T>();
+    T* bias_data = bias_.template mutable_data<T>();
+    const std::array<int, 4> X_dims = {N, HxW, G, K};
+    const std::array<int, 4> Y_dims = {N, 1, G, 1};
+    math::Moments<T, Context>(
+        4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
+    math::InvStd<T, Context>(
+        N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
+    ComputeFusedParams(N, G, K, mu, rsig, gamma, beta, scale_data, bias_data);
+    GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
+    return true;
+  }
+
+  void ComputeFusedParams(
+      int N,
+      int G,
+      int K,
       const T* mu,
       const T* rsig,
       const T* gamma,
       const T* beta,
       T* scale,
-      T* bias) {
-    const int C = G * D;
-    ConstEigenArrayMap<float> gamma_arr(gamma, D, G);
-    ConstEigenArrayMap<float> beta_arr(beta, D, G);
-    for (int i = 0; i < N; ++i) {
-      EigenArrayMap<T> scale_arr(scale + i * C, D, G);
-      scale_arr = gamma_arr.rowwise() *
-          ConstEigenVectorArrayMap<T>(rsig + i * G, G).transpose();
-      EigenArrayMap<T>(bias + i * C, D, G) = beta_arr -
-          scale_arr.rowwise() *
-              ConstEigenVectorArrayMap<T>(mu + i * G, G).transpose();
-    }
-  }
+      T* bias);
 
   void GroupNormForwardNCHW(
       const int N,
@@ -149,13 +167,7 @@ class GroupNormOp final : public Operator<Context> {
       const T* X,
       const T* scale,
       const T* bias,
-      T* Y) {
-    EigenArrayMap<float>(Y, HxW, N * C) =
-        (ConstEigenArrayMap<float>(X, HxW, N * C).rowwise() *
-         ConstEigenVectorArrayMap<float>(scale, N * C).transpose())
-            .rowwise() +
-        ConstEigenVectorArrayMap<float>(bias, N * C).transpose();
-  }
+      T* Y);
 
   void GroupNormForwardNHWC(
       const int N,
@@ -164,16 +176,7 @@ class GroupNormOp final : public Operator<Context> {
       const T* X,
       const T* scale,
       const T* bias,
-      T* Y) {
-    const int stride = HxW * C;
-    for (int i = 0; i < N; ++i) {
-      EigenArrayMap<float>(Y + i * stride, C, HxW) =
-          (ConstEigenArrayMap<float>(X + i * stride, C, HxW).colwise() *
-           ConstEigenVectorArrayMap<float>(scale + i * C, C))
-              .colwise() +
-          ConstEigenVectorArrayMap<float>(bias + i * C, C);
-    }
-  }
+      T* Y);
 
   const int group_;
   const float epsilon_;
@@ -223,32 +226,61 @@ class GroupNormGradientOp final : public Operator<Context> {
     CAFFE_ENFORCE_EQ(gamma.numel(), C);
     CAFFE_ENFORCE_EQ(beta.numel(), C);
     const int G = group_;
-    const int D = C / G;
-
+    const int K = C / G;
     auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
     auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype<T>());
     auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype<T>());
-    return RunOnDeviceImpl(
-        N,
-        G,
-        D,
-        HxW,
-        dY.template data<T>(),
-        X.template data<T>(),
-        mu.template data<T>(),
-        rsig.template data<T>(),
-        gamma.template data<T>(),
-        dX->template mutable_data<T>(),
-        dgamma->template mutable_data<T>(),
-        dbeta->template mutable_data<T>());
+    if (order_ == StorageOrder::NCHW) {
+      return RunOnDeviceWithOrderNCHW(
+          N,
+          G,
+          K,
+          HxW,
+          dY.template data<T>(),
+          X.template data<T>(),
+          mu.template data<T>(),
+          rsig.template data<T>(),
+          gamma.template data<T>(),
+          dX->template mutable_data<T>(),
+          dgamma->template mutable_data<T>(),
+          dbeta->template mutable_data<T>());
+    } else {
+      return RunOnDeviceWithOrderNHWC(
+          N,
+          G,
+          K,
+          HxW,
+          dY.template data<T>(),
+          X.template data<T>(),
+          mu.template data<T>(),
+          rsig.template data<T>(),
+          gamma.template data<T>(),
+          dX->template mutable_data<T>(),
+          dgamma->template mutable_data<T>(),
+          dbeta->template mutable_data<T>());
+    }
   }
 
  protected:
-  bool RunOnDeviceImpl(
-      const int N,
-      const int G,
-      const int D,
-      const int HxW,
+  bool RunOnDeviceWithOrderNCHW(
+      int N,
+      int G,
+      int K,
+      int HxW,
+      const T* dY_data,
+      const T* X_data,
+      const T* mu_data,
+      const T* rsig_data,
+      const T* gamma_data,
+      T* dX_data,
+      T* dgamma_data,
+      T* dbeta_data);
+
+  bool RunOnDeviceWithOrderNHWC(
+      int N,
+      int G,
+      int K,
+      int HxW,
       const T* dY_data,
       const T* X_data,
       const T* mu_data,
@@ -263,6 +295,10 @@ class GroupNormGradientOp final : public Operator<Context> {
 
   Tensor ds_;
   Tensor db_;
+  Tensor dY_scale_;
+  Tensor X_scale_;
+  Tensor bias_;
+  Tensor ones_;
 
   // Input: dY, X, gamma, beta, mu, inv_sig
   // Output: dX, dgamma, dbeta