Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19621
Comments for group_norm_op is not accurate (i.e., the math part), this diff will fix it.
Reviewed By: BIT-silence
Differential Revision:
D15048695
fbshipit-source-id:
27d41d3ae21054257967815254134849944d56ca
// Math:
// Y = gamma * (X - mu) * rsig + beta
// let s = gamma * rsig
-// let b = beta - mu * rsig
+// let b = beta - gamma * mu * rsig
// Y = s * X + b
// let n = K * HxW
// dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX)
// d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX
-// db/dX = -u * drsig/dX - rsig * dmu/dX
+// db/dX = -gamma * u * drsig/dX - gamma * rsig * dmu/dX
// drsig/dX = -rsig^3 * (X - mu) / n
// dmu/dX = 1 / n
// Math:
// Y = gamma * (X - mu) * rsig + beta
// let s = gamma * rsig
-// let b = beta - mu * rsig
+// let b = beta - gamma * mu * rsig
// Y = s * X + b
// let n = K * HxW
// dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX)
// d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX
-// db/dX = -u * drsig/dX - rsig * dmu/dX
+// db/dX = -gamma * u * drsig/dX - gamma * rsig * dmu/dX
// drsig/dX = -rsig^3 * (X - mu) / n
// dmu/dX = 1 / n
template <typename T>