LayerNorm Support in autodiff (#50467)
author jiej <jiej@nvidia.com>
Thu, 12 Aug 2021 18:03:32 +0000 (11:03 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 12 Aug 2021 18:05:53 +0000 (11:05 -0700)
Summary:
1. Extend autodiff by adding an entry for layer_norm in the symbolic script; it now uses native_layer_norm_backward.
2. Add a backward function `layer_norm_double_backward` for `native_layer_norm_backward`, preserving double-backward support for LayerNorm in autodiff/ScriptModule.
3. Add a Python test verifying autodiff on layer_norm with various configurations of the optional tensors (verifies the fix for https://github.com/pytorch/pytorch/issues/49430); a sketch of this kind of check follows below.
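As a rough illustration (a hedged sketch under assumed names and shapes, not
the exact test added here), autodiff plus the new double backward can be
exercised by scripting F.layer_norm and running gradcheck/gradgradcheck over
every configuration of the optional tensors:

    from typing import Optional

    import torch
    import torch.nn.functional as F
    from torch.autograd import gradcheck, gradgradcheck

    @torch.jit.script
    def ln(x: torch.Tensor,
           weight: Optional[torch.Tensor],
           bias: Optional[torch.Tensor]) -> torch.Tensor:
        return F.layer_norm(x, [5], weight, bias, 1e-5)

    x = torch.randn(3, 4, 5, dtype=torch.double, requires_grad=True)
    w = torch.rand(5, dtype=torch.double, requires_grad=True)
    b = torch.rand(5, dtype=torch.double, requires_grad=True)

    # all four combinations of the optional weight/bias tensors
    for weight, bias in [(w, b), (w, None), (None, b), (None, None)]:
        assert gradcheck(lambda t: ln(t, weight, bias), (x,))
        # second derivatives route through layer_norm_double_backward
        assert gradgradcheck(lambda t: ln(t, weight, bias), (x,))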

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50467

Reviewed By: eellison

Differential Revision: D30232864

Pulled By: jansel

fbshipit-source-id: b9c33075386aff96afff7415df9f94388bfb474a

Co-authored-by: Ryan Spring <rspring@nvidia.com>
Co-authored-by: Jie <jiej@nvidia.com>
aten/src/ATen/native/cpu/layer_norm_kernel.cpp
aten/src/ATen/native/cuda/layer_norm_kernel.cu
tools/autograd/derivatives.yaml
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/csrc/jit/runtime/symbolic_script.cpp
torch/testing/_internal/jit_metaprogramming_utils.py

diff --git a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
index d395277..a065d3b 100644
@@ -288,7 +288,7 @@ void LayerNormBackwardKernelImpl(
   AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
       "LayerNormBackwardKernelImpl", [&]() {
     LayerNormBackwardKernelImplInternal<scalar_t>(
-        dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+        dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
   });
 }
 
diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index 464d6b8..8039d82 100644
@@ -442,7 +442,7 @@ void LayerNormBackwardKernelImpl(
       "LayerNormBackwardKernelImpl",
       [&]() {
         LayerNormBackwardKernelImplInternal<scalar_t>(
-            dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
       });
 }
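
Both kernel tweaks above force dY contiguous before dispatch. With the
symbolic-script AD below handing grad_output directly to
native_layer_norm_backward, the incoming gradient is no longer guaranteed
contiguous. A hypothetical repro of the guarded case (assumed shapes; the
same path exists on CPU and CUDA):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8, requires_grad=True)
    y = F.layer_norm(x, [8])
    g = torch.randn(8, 4).t()  # transposed view: a non-contiguous gradient
    y.backward(g)              # backward kernel now calls dY.contiguous()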
 
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 3aee1ff..4224e41 100644
   save_invstd: not_implemented("native_batch_norm_backward save_invstd")
 
 - name: native_layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
-  input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_layer_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, normalized_shape, eps, grad_input_mask) : (grads[0].defined() ? native_layer_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
+  input, weight, bias: "grad.defined() ? native_layer_norm_backward(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
+
+- name: native_layer_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
+  bias: Tensor()
+  mean: not_implemented("native_layer_norm_backward mean")
+  rstd: not_implemented("native_layer_norm_backward rstd")
 
 - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, int N, int C, int HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
   input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward(grads[0].is_contiguous() ? grads[0] : grads[0].contiguous(), input.is_contiguous() ? input : input.contiguous(), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 8fe4820..4ac1f50 100644
@@ -3125,109 +3125,138 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
 
 }
 
-std::tuple<Tensor, Tensor, Tensor>
-infinitely_differentiable_native_layer_norm_backward(
-    const Tensor& dY,
-    const Tensor& dmean,
-    const Tensor& drstd,
-    const Tensor& X,
-    const Tensor& mean,
-    const Tensor& rstd,
+std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
+    const Tensor& input_t,
     const c10::optional<Tensor>& gamma,
+    const Tensor& ggI,
+    const Tensor& ggG,
+    const Tensor& ggB,
+    const Tensor& gO_t,
+    const Tensor& save_mean_t,
+    const Tensor& save_invstd_t,
     IntArrayRef normalized_shape,
-    double eps,
-    std::array<bool, 3> grad_input_mask) {
+    std::array<bool, 3> output_mask) {
 
   const int normalized_ndim = normalized_shape.size();
-  const auto input_shape = X.sizes();
-  const auto input_ndim = X.dim();
+  const auto input_shape = input_t.sizes();
+  const auto input_ndim = input_t.dim();
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   const int axis = input_ndim - normalized_ndim;
   const int64_t M =
       c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
   const int64_t N =
       c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
 
-  Tensor dX;
-  Tensor dgamma;
-  Tensor dbeta;
+  auto input = input_t.reshape({M, N});
+  auto gO = gO_t.reshape({M, N});
+  auto save_mean = save_mean_t.reshape({M, 1});
+  auto save_invstd = save_invstd_t.reshape({M, 1});
 
-  const Tensor X_tensor = X.reshape({M, N});
-  const Tensor mean_tensor = mean.reshape({M, 1});
-  const Tensor rstd_tensor = rstd.reshape({M, 1});
-  const double s = 1.0 / static_cast<double>(N);
+  bool affine = isDefined(gamma);
+  Tensor gamma_expanded;
+  Tensor ggG_expanded, ggB_expanded;
+  if (affine) {
+    gamma_expanded = gamma->reshape({1, N});
+    if (ggG.defined()) {
+      ggG_expanded = ggG.reshape({1, N});
+    }
+    if (ggB.defined()) {
+      ggB_expanded = ggB.reshape({1, N});
+    }
+  } else {
+    gamma_expanded = at::ones({1}, input.options());
+  }
 
-  Tensor dY_tensor;
-  if (dY.defined()) {
-    dY_tensor = dY.reshape({M, N});
+  Tensor ggI_expanded;
+  if (ggI.defined()) {
+    ggI_expanded = ggI.reshape({M, N});
   }
 
-  if (grad_input_mask[0]) {
-    Tensor gamma_tensor;
-    if (isDefined(gamma)) {
-      gamma_tensor = gamma->reshape({1, N});
-    }
-    Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
-    Tensor var;
-    Tensor dvar;
-    if (drstd.defined()) {
-      var = ((rstd_tensor * rstd_tensor).reciprocal_() - eps).clamp_min(0);
-      dvar = -0.5 * rstd_cube * drstd.view({M, 1});
-    }
-    Tensor ds;
-    Tensor db;
-    if (dY.defined()) {
-      ds = (isDefined(gamma) ? dY_tensor * X_tensor * gamma_tensor
-                            : dY_tensor * X_tensor)
-               .sum(1)
-               .unsqueeze_(-1);
-      db = (isDefined(gamma) ? dY_tensor * gamma_tensor : dY_tensor)
-               .sum(1)
-               .unsqueeze_(-1);
-      const Tensor& a = rstd_tensor;
-      const Tensor b = (db * mean_tensor - ds) * rstd_cube * s;
-      const Tensor c = -b * mean_tensor - db * rstd_tensor * s;
-      if (isDefined(gamma)) {
-        dX = a * dY_tensor * gamma_tensor + b * X_tensor + c;
-      } else {
-        dX = a * dY_tensor + b * X_tensor + c;
-      }
-      if (dmean.defined() && drstd.defined()) {
-        dX += var_std_mean_backward(
-            {dvar, dmean.view({M, 1})},
-            X_tensor,
-            var,
-            mean_tensor,
-            /*dim=*/IntArrayRef{1},
-            /*correction=*/0,
-            /*keepdim=*/true,
-            /*is_std=*/false);
-      }
-      dX = dX.reshape_as(X);
-    } else if (dmean.defined() && drstd.defined()) {
-      dX = var_std_mean_backward(
-               {dvar, dmean.view({M, 1})},
-               X_tensor,
-               var,
-               mean_tensor,
-               /*dim=*/IntArrayRef{1},
-               /*correction=*/0,
-               /*keepdim=*/true,
-               /*is_std=*/false)
-               .reshape_as(X);
-    }
+  // for half inputs, save_mean, save_invstd are float
+  // (ideally, we would cast everything else, but not now)
+  auto mu = save_mean.to(input.scalar_type());
+  auto input_sub_mu = input - mu;
+  auto sigma2_eps_neg_1_2 = save_invstd.to(input.scalar_type());
+  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
+  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);
+
+  Tensor gI;
+  // calculate gI
+  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
+
+  if (ggI.defined()) {
+
+    auto gxhat = gO * gamma_expanded;
+    auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true);
+    auto gxhat_sum = gxhat.sum(1, true);
+
+    auto ggI_sum = ggI_expanded.sum(1, true);
+    auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);
+
+    auto all_sub = ((ggI_sum * gxhat_sum).div_(N)).sub_((ggI_expanded * gxhat).sum(1, true)).add_(
+                    (sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N));
+    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
+    auto gI_1t = (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
+    auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) * (ggI_sum.div(N) - ggI_expanded);
+
+    gI = (gI_0t.add_(gI_1t).add_(gI_2t));
   }
 
-  if (grad_input_mask[1] && dY.defined()) {
-    dgamma = (dY_tensor * (X_tensor - mean_tensor) * rstd_tensor)
-                 .sum(0)
-                 .reshape_as(toNonOptTensor(gamma));
+  // add contribution of gamma term to gI
+  if (affine && ggG.defined()) {
+    auto t0 = gO * ggG_expanded * sigma2_eps_neg_1_2;
+    auto t1 = (sigma2_eps_neg_1_2 * (gO * ggG_expanded).sum(1, true)).div_(-N);
+    auto t2 = (input_mu_sigma2_neg_3_2 * (gO * ggG_expanded * input_sub_mu).sum(1,true)).div_(-N);
+    auto gI_G_term = t0.add_(t1).add_(t2);
+    gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
   }
-  if (grad_input_mask[2] && dY.defined()) {
-    dbeta = dY_tensor.sum(0).reshape_as(toNonOptTensor(gamma));
+
+  if (gI.defined()) {
+    gI = gI.reshape_as(input_t);
   }
 
-  return std::make_tuple(dX, dgamma, dbeta);
+  // this is the grad_input for the first backward function
+  auto first_bwd_fn_grad_input = [&](const Tensor& gO_local, const Tensor& gamma_local) -> Tensor {
+    auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N);
+    auto h1 = (N * gO_local).sub_(gO_local.sum(1,true)).sub_(
+                input_sub_mu.mul(sigma2_eps_neg_1) * (gO_local * input_sub_mu).sum(1,true));
+    return h0 * h1;
+  };
+
+  // calculate gG
+  Tensor gG;
+  if (affine && ggI.defined()) {
+    gG = first_bwd_fn_grad_input(ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
+    gG = (gO * gG).sum(0);
+    gG = gG.reshape_as(*gamma);
+  }
+
+  // calculate ggO
+  Tensor ggO;
+  // contribution of input term
+  if (ggI.defined()) {
+    ggO = first_bwd_fn_grad_input(ggI_expanded, gamma_expanded);
+  }
+  if (ggG.defined()) {
+    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
+    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
+  }
+  if (ggB.defined()) {
+    auto ggO_B_term = ggB_expanded;
+    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
+  }
+  if (ggO.defined()) {
+    ggO = ggO.expand({M, N}).reshape_as(input_t);
+  }
+
+  if (output_mask[1] && !gG.defined()) {
+    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
+  }
+
+  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
 }
 
 std::tuple<Tensor, Tensor, Tensor>
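
For reference, the helper first_bwd_fn_grad_input above restates the first
backward of layer norm over one normalized row of size N. In LaTeX (g is the
incoming gradient, \sigma^{-1} the saved rstd, i.e. (\mathrm{var}+\epsilon)^{-1/2}):

    dX_i = \frac{\gamma_i \, \sigma^{-1}}{N}
           \Bigl( N g_i - \sum_j g_j
                  - (x_i - \mu) \, \sigma^{-2} \sum_j g_j (x_j - \mu) \Bigr)

Differentiating this map with respect to x, gamma, and g is what yields the
gI, gG, and ggO terms assembled above.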
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 7f4c3ad..288ae11 100644
@@ -223,18 +223,18 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
     const Tensor & weight_);
 Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntArrayRef sizes, IntArrayRef strides, optional<int64_t> storage_offset_);
 std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask);
-std::tuple<Tensor, Tensor, Tensor>
-infinitely_differentiable_native_layer_norm_backward(
-    const Tensor& dY,
-    const Tensor& dmean,
-    const Tensor& drstd,
-    const Tensor& X,
-    const Tensor& mean,
-    const Tensor& rstd,
-    const c10::optional<Tensor>& gamma,
+std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
+    const Tensor & input,
+    const c10::optional<Tensor> & gamma,
+    const Tensor & ggI,
+    const Tensor & ggG,
+    const Tensor & ggB,
+    const Tensor & gO,
+    const Tensor & save_mean,
+    const Tensor & save_invstd,
     IntArrayRef normalized_shape,
-    double eps,
-    std::array<bool, 3> grad_input_mask);
+    std::array<bool,3> output_mask);
+
 std::tuple<Tensor, Tensor> householder_product_backward(const Tensor& grad, const Tensor& input, const Tensor& tau);
 std::tuple<Tensor, Tensor> polar_backward(
     const Tensor& grad,
diff --git a/torch/csrc/jit/runtime/symbolic_script.cpp b/torch/csrc/jit/runtime/symbolic_script.cpp
index 9af1fd8..453a83c 100644
@@ -1141,64 +1141,19 @@ const std::vector<std::string> functions = {
 
             return output, backward
 
-        # disable the layernorm AD temporarily because of bug in https://github.com/pytorch/pytorch/issues/19769
-        def layer_norm_disabled(input : Tensor,
+        def layer_norm(input : Tensor,
                        normalized_shape : List[int],
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],
                        eps : float,
                        cudnn_enable : bool):
 
-            input_ndim = input.dim()
-            normalized_ndim = len(normalized_shape)
-            n = 1
-            for i in range(input_ndim - normalized_ndim):
-                n *= input.size(i)
-
-            input_reshape = input.contiguous().view(1, n, -1)
-
-            bn_out, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
-                input_reshape, None, None, None, None, True,
-                0.0, eps, cudnn_enable)
-
-            bn_out = bn_out.view(input.size())
-            if weight is not None and bias is not None:
-                output = bias.addcmul(bn_out, weight, value=1)
-            elif weight is not None:
-                output = bn_out.mul(weight)
-            elif bias is not None:
-                output = bn_out.add(bias)
-            else:
-                output = bn_out
-
-            def backward(grad_output):
-                if weight is not None and bias is not None:
-                    grad_bn_out = grad_output * weight
-                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
-                    grad_bias = grad_output._grad_sum_to_size(bias.size())
-                elif weight is not None:
-                    grad_bn_out = grad_output * weight
-                    grad_weight = (grad_output * bn_out)._grad_sum_to_size(weight.size())
-                    grad_bias = None
-                elif bias is not None:
-                    grad_bn_out = grad_output
-                    grad_weight= None
-                    grad_bias = grad_output._grad_sum_to_size(bias.size())
-                else:
-                    grad_bn_out = grad_output
-                    grad_weight= None
-                    grad_bias = None
-
-
-                grad_bn_out = grad_bn_out.contiguous().view(1, n, -1)
+            output, mean, rstd = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
 
-                grad_input, _, _ = torch._batch_norm_impl_index_backward(
-                    impl_idx, input_reshape, grad_bn_out, None, None, None,
-                    save1, save2, True, eps, [True, False, False], reserve)
-
-                grad_input = grad_input.view(input.size())
+            def backward(grad_output):
+                output_mask = [True, weight is not None, bias is not None]
+                grad_input, grad_weight, grad_bias = torch.native_layer_norm_backward(grad_output, input, normalized_shape, mean, rstd, weight, bias, output_mask)
                 return grad_input, None, grad_weight, grad_bias, None, None
-
             return output, backward
 
         def AD_fused_dropout_backward(grad,
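
With the rewritten symbolic-script entry above, the scripted backward should
agree with eager autograd. A minimal sanity sketch (assumed shapes; warm-up
iterations let the profiling executor build the differentiable graph):

    import torch
    import torch.nn.functional as F

    def eager(x, w, b):
        return F.layer_norm(x, [8], w, b, 1e-5)

    scripted = torch.jit.script(eager)

    x = torch.randn(4, 8, requires_grad=True)
    w = torch.rand(8, requires_grad=True)
    b = torch.rand(8, requires_grad=True)

    for _ in range(3):  # profiling-executor warm-up
        scripted(x, w, b).sum().backward()
        x.grad = w.grad = b.grad = None

    g_eager, = torch.autograd.grad(eager(x, w, b).sum(), x)
    g_script, = torch.autograd.grad(scripted(x, w, b).sum(), x)
    print(torch.allclose(g_eager, g_script))  # expect True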
diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py
index 28b9eb8..a21717b 100644
@@ -113,14 +113,14 @@ nn_functional_tests = [
         '', (False, 'aten::_batch_norm_impl_index')),
     ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
     ('layer_norm', (S, S, S, S), ([5],), '',
-     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+     (True, ['aten::native_layer_norm'])),
     ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),), 'with_only_weight',
-     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+     (True, ['aten::native_layer_norm'])),
     ('layer_norm', (S, S, S, S), ([5], None, non_differentiable(torch.rand(S)),), 'with_only_bias',
-     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index'])),
+     (True, ['aten::native_layer_norm'])),
     ('layer_norm', (S, S, S, S), ([5], non_differentiable(torch.rand(S)),
                                   non_differentiable(torch.rand(S))), 'with_weight_and_bias',
-     (False, ['aten::contiguous', 'aten::_batch_norm_impl_index', 'aten::addcmul'])),
+     (True, ['aten::native_layer_norm'])),
     ('group_norm', (S, S, S), (1, torch.rand(5),),),
     ('local_response_norm', (S, S, S), (2, ),),
     ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),), '',),