AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
"LayerNormBackwardKernelImpl", [&]() {
LayerNormBackwardKernelImplInternal<scalar_t>(
- dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+ dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
});
}
"LayerNormBackwardKernelImpl",
[&]() {
LayerNormBackwardKernelImplInternal<scalar_t>(
- dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+ dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
});
}
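The .contiguous() calls above (CPU and CUDA dispatch sites) guard against a non-contiguous grad_out reaching LayerNormBackwardKernelImplInternal. A minimal sketch of such a case, with an illustrative shape and a transposed upstream gradient (not taken from this patch):

    import torch

    x = torch.randn(4, 8, requires_grad=True)
    out = torch.nn.functional.layer_norm(x, [8])
    grad_out = torch.randn(8, 4).t()   # non-contiguous upstream gradient
    assert not grad_out.is_contiguous()
    # reaches the layer-norm backward kernel with a non-contiguous dY
    (gx,) = torch.autograd.grad(out, x, grad_outputs=grad_out)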
save_invstd: not_implemented("native_batch_norm_backward save_invstd")
- name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
- input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_layer_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, normalized_shape, eps, grad_input_mask) : (grads[0].defined() ? native_layer_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
+ input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+
+- name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+ input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
+ bias: Tensor()
+ mean: not_implemented("native_layer_norm_backward mean")
+ rstd: not_implemented("native_layer_norm_backward rstd")
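With native_layer_norm_backward itself now differentiable (routed through layer_norm_double_backward above), second-order gradients of layer_norm can be checked numerically. A hedged sketch using double-precision inputs; the sizes and eps are illustrative only:

    import torch
    from torch.autograd import gradgradcheck

    x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
    w = torch.randn(4, dtype=torch.double, requires_grad=True)
    b = torch.randn(4, dtype=torch.double, requires_grad=True)

    def fn(x, w, b):
        # second-order gradients flow through native_layer_norm_backward
        return torch.nn.functional.layer_norm(x, [4], w, b, 1e-5)

    assert gradgradcheck(fn, (x, w, b))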
- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
}
-std::tuple<Tensor, Tensor, Tensor>
-infinitely_differentiable_native_layer_norm_backward(
- const Tensor& dY,
- const Tensor& dmean,
- const Tensor& drstd,
- const Tensor& X,
- const Tensor& mean,
- const Tensor& rstd,
+std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
+ const Tensor& input_t,
const c10::optional<Tensor>& gamma,
+ const Tensor& ggI,
+ const Tensor& ggG,
+ const Tensor& ggB,
+ const Tensor& gO_t,
+ const Tensor& save_mean_t,
+ const Tensor& save_invstd_t,
IntArrayRef normalized_shape,
- double eps,
- std::array<bool, 3> grad_input_mask) {
+ std::array<bool, 3> output_mask) {
const int normalized_ndim = normalized_shape.size();
- const auto input_shape = X.sizes();
- const auto input_ndim = X.dim();
+ const auto input_shape = input_t.sizes();
+ const auto input_ndim = input_t.dim();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int axis = input_ndim - normalized_ndim;
const int64_t M =
c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
const int64_t N =
c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
+ //printf("M: %ld, N: %ld", M, N);
- Tensor dX;
- Tensor dgamma;
- Tensor dbeta;
+ auto input = input_t.reshape({M, N});
+ auto gO = gO_t.reshape({M, N});
+ auto save_mean = save_mean_t.reshape({M, 1});
+ auto save_invstd = save_invstd_t.reshape({M, 1});
- const Tensor X_tensor = X.reshape({M, N});
- const Tensor mean_tensor = mean.reshape({M, 1});
- const Tensor rstd_tensor = rstd.reshape({M, 1});
- const double s = 1.0 / static_cast<double>(N);
+ bool affine = isDefined(gamma);
+ Tensor gamma_expanded;
+ Tensor ggG_expanded, ggB_expanded;
+ if (affine) {
+ gamma_expanded = gamma->reshape({1, N});
+ if (ggG.defined()) {
+ ggG_expanded = ggG.reshape({1, N});
+ }
+ if (ggB.defined()) {
+ ggB_expanded = ggB.reshape({1, N});
+ }
+ } else {
+ gamma_expanded = at::ones({1}, input.options());
+ }
- Tensor dY_tensor;
- if (dY.defined()) {
- dY_tensor = dY.reshape({M, N});
+ Tensor ggI_expanded;
+ if (ggI.defined()) {
+ ggI_expanded = ggI.reshape({M, N});
}
- if (grad_input_mask[0]) {
- Tensor gamma_tensor;
- if (isDefined(gamma)) {
- gamma_tensor = gamma->reshape({1, N});
- }
- Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
- Tensor var;
- Tensor dvar;
- if (drstd.defined()) {
- var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
- dvar = -0.5 * rstd_cube * drstd.view({M, 1});
- }
- Tensor ds;
- Tensor db;
- if (dY.defined()) {
- ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor
- : dY_tensor * X_tensor)
- .sum(1)
- .unsqueeze_(-1);
- db = (isDefined(gamma) ? dY_tensor * gamma_tensor : dY_tensor)
- .sum(1)
- .unsqueeze_(-1);
- const Tensor& a = rstd_tensor;
- const Tensor b = (db * mean_tensor - ds) * rstd_cube * s;
- const Tensor c = -b * mean_tensor - db * rstd_tensor * s;
- if (isDefined(gamma)) {
- dX = a * dY_tensor * gamma_tensor + b * X_tensor + c;
- } else {
- dX = a * dY_tensor + b * X_tensor + c;
- }
- if (dmean.defined() && drstd.defined()) {
- dX += var_std_mean_backward(
- {dvar, dmean.view({M, 1})},
- X_tensor,
- var,
- mean_tensor,
- /*dim=*/IntArrayRef{1},
- /*correction=*/0,
- /*keepdim=*/true,
- /*is_std=*/false);
- }
- dX = dX.reshape_as(X);
- } else if (dmean.defined() && drstd.defined()) {
- dX = var_std_mean_backward(
- {dvar, dmean.view({M, 1})},
- X_tensor,
- var,
- mean_tensor,
- /*dim=*/IntArrayRef{1},
- /*correction=*/0,
- /*keepdim=*/true,
- /*is_std=*/false)
- .reshape_as(X);
- }
+ // for half inputs, save_mean, save_invstd are float
+ // (ideally, we would cast everything else, but not now)
+ auto mu = save_mean.to(input.scalar_type());
+ auto input_sub_mu = input - mu;
+ auto sigma2_eps_neg_1_2 = save_invstd.to(input.scalar_type());
+ auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
+ auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
+
+ Tensor gI;
+ // calculate gI
+ auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
+
+ if (ggI.defined()) {
+ auto gxhat = gO * gamma_expanded;
+ auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true);
+ auto gxhat_sum = gxhat.sum(1, true);
+
+ auto ggI_sum = ggI_expanded.sum(1, true);
+ auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);
+
+ auto all_sub = ((ggI_sum * gxhat_sum).div_(N)).sub_((ggI_expanded * gxhat).sum(1, true)).add_(
+ (sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N));
+ auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
+ auto gI_1t = (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
+ auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) * (ggI_sum.div(N) - ggI_expanded);
+
+ gI = (gI_0t.add_(gI_1t).add_(gI_2t));
}
- if (grad_input_mask[1] && dY.defined()) {
- dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor)
- .sum(0)
- .reshape_as(toNonOptTensor(gamma));
+ // add contribution of gamma term to gI
+ if (affine && ggG.defined()) {
+ auto t0 = gO * ggG_expanded * sigma2_eps_neg_1_2;
+ auto t1 = (sigma2_eps_neg_1_2 * (gO * ggG_expanded).sum(1, true)).div_(-N);
+ auto t2 = (input_mu_sigma2_neg_3_2 * (gO * ggG_expanded * input_sub_mu).sum(1,true)).div_(-N);
+ auto gI_G_term = t0.add_(t1).add_(t2);
+ gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
}
- if (grad_input_mask[2] && dY.defined()) {
- dbeta = dY_tensor.sum(0).reshape_as(toNonOptTensor(gamma));
+
+ if (gI.defined()) {
+ //printf("=== computing gI\n");
+ gI = gI.reshape_as(input_t);
}
- return std::make_tuple(dX, dgamma, dbeta);
+ // grad_input of the first backward pass, parameterized by the upstream gradient and gamma; reused below for gG and ggO
+ auto first_bwd_fn_grad_input = [&](const Tensor& gO_local, const Tensor& gamma_local) -> Tensor {
+ auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N);
+ auto h1 = (N * gO_local).sub_(gO_local.sum(1,true)).sub_(
+ input_sub_mu.mul(sigma2_eps_neg_1) * (gO_local * input_sub_mu).sum(1,true));
+ return h0 * h1;
+ };
+
+ // calculate gG
+ Tensor gG;
+ if (affine && ggI.defined()) {
+ gG = first_bwd_fn_grad_input(ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
+ gG = (gO * gG).sum(0);
+ gG = gG.reshape_as(*gamma);
+ }
+
+ // calculate ggO
+ Tensor ggO;
+ // contribution of input term
+ if (ggI.defined()) {
+ ggO = first_bwd_fn_grad_input(ggI_expanded, gamma_expanded);
+ }
+ if (ggG.defined()) {
+ auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
+ ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
+ }
+ if (ggB.defined()) {
+ auto ggO_B_term = ggB_expanded;
+ ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
+ }
+ if (ggO.defined()) {
+ ggO = ggO.expand({M, N}).reshape_as(input_t);
+ }
+
+ if (output_mask[1] && !gG.defined()) {
+ AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
+ }
+
+ return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}
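For reference, first_bwd_fn_grad_input encodes the (symmetric) Jacobian of the normalization step: with mean \mu, rstd r = (\sigma^2 + \epsilon)^{-1/2}, row size N, and an upstream gradient g, it returns

    dX_i = \frac{\gamma_i r}{N} \left( N g_i - \sum_j g_j - (x_i - \mu)\, r^2 \sum_j g_j (x_j - \mu) \right)

which is why the same helper is reused both for gG (with ggI as the upstream gradient) and for the input-term contribution to ggO.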
std::tuple<Tensor, Tensor, Tensor>
const Tensor & weight_);
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
-std::tuple<Tensor, Tensor, Tensor>
-infinitely_differentiable_native_layer_norm_backward(
- const Tensor& dY,
- const Tensor& dmean,
- const Tensor& drstd,
- const Tensor& X,
- const Tensor& mean,
- const Tensor& rstd,
- const c10::optional<Tensor>& gamma,
+std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
+ const Tensor & input,
+ const c10::optional<Tensor> & gamma,
+ const Tensor & ggI,
+ const Tensor & ggG,
+ const Tensor & ggB,
+ const Tensor & gO,
+ const Tensor & save_mean,
+ const Tensor & save_invstd,
IntArrayRef normalized_shape,
- double eps,
- std::array<bool, 3> grad_input_mask);
+ std::array<bool,3> output_mask);
+
std::tuple<Tensor, Tensor> householder_product_backward(const Tensor& grad, const Tensor& input, const Tensor& tau);
std::tuple<Tensor, Tensor> polar_backward(
const Tensor& grad,
return output, backward
- # disable the layernorm AD temporarily because of bug in https://github.com/pytorch/pytorch/issues/19769
- def layer_norm_disabled(input : Tensor,
+ def layer_norm(input : Tensor,
normalized_shape : List[int],
weight : Optional[Tensor],
bias : Optional[Tensor],
eps : float,
cudnn_enable : bool):
- input_ndim = input.dim()
- normalized_ndim = len(normalized_shape)
- n = 1
- for i in range(input_ndim - normalized_ndim):
- n *= input.size(i)
-
- input_reshape = input.contiguous().view(1, n, -1)
-
- bn_out, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
- input_reshape, None, None, None, None, True,
- 0.0, eps, cudnn_enable)
-
- bn_out = bn_out.view(input.size())
- if weight is not None and bias is not None:
- output = bias.addcmul(bn_out, weight, value=1)
- elif weight is not None:
- output = bn_out.mul(weight)
- elif bias is not None:
- output = bn_out.add(bias)
- else:
- output = bn_out
-
- def backward(grad_output):
- if weight is not None and bias is not None:
- grad_bn_out = grad_output * weight
- grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
- grad_bias = grad_output._grad_sum_to_size(bias.size())
- elif weight is not None:
- grad_bn_out = grad_output * weight
- grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
- grad_bias = None
- elif bias is not None:
- grad_bn_out = grad_output
- grad_weight= None
- grad_bias = grad_output._grad_sum_to_size(bias.size())
- else:
- grad_bn_out = grad_output
- grad_weight= None
- grad_bias = None
-
-
- grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
+ output, mean, rstd = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
- grad_input, _, _ = torch._batch_norm_impl_index_backward(
- impl_idx, input_reshape, grad_bn_out, None, None, None,
- save1, save2, True, eps, [True, False, False], reserve)
-
- grad_input = grad_input.view(input.size())
+ def backward(grad_output):
+ output_mask = [True, weight is not None, bias is not None]
+ grad_input, grad_weight, grad_bias = torch.native_layer_norm_backward(grad_output, input, normalized_shape, mean, rstd, weight, bias, output_mask)
return grad_input, None, grad_weight, grad_bias, None, None
-
return output, backward
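With the symbolic above, scripted layer_norm differentiates through native_layer_norm / native_layer_norm_backward instead of the old batch-norm reformulation. A hedged usage sketch (illustrative sizes; whether the call is fused into a differentiable graph depends on the JIT executor configuration):

    import torch

    @torch.jit.script
    def ln(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.layer_norm(x, [4], w, b, 1e-5)

    x = torch.randn(3, 4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)
    b = torch.randn(4, requires_grad=True)
    ln(x, w, b).sum().backward()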
def AD_fused_dropout_backward(grad,
'', (False, 'aten::_batch_norm_impl_index')),
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
('layer_norm', (S, S, S, S), ([5],), '',
- (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+ (True, ['aten::native_layer_norm'])),
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
- (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+ (True, ['aten::native_layer_norm'])),
('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
- (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+ (True, ['aten::native_layer_norm'])),
('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
non_differentiable(torch.rand(S))), 'with_weight_and_bias',
- (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
+ (True, ['aten::native_layer_norm'])),
('group_norm', (S, S, S), (1, torch.rand(5),),),
('local_response_norm', (S, S, S), (2, ),),
('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),