Make `torch.lu` differentiable for wide/tall inputs + jit (#61564)
authorNikita Vedeneev <nik@quansight.com>
Mon, 16 Aug 2021 18:39:04 +0000 (11:39 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 18:40:57 +0000 (11:40 -0700)
Summary:
As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61564

Reviewed By: astaff

Differential Revision: D30338136

Pulled By: mruberry

fbshipit-source-id: f01436fc90980544cdfa270feee16bb3dda21b93

12 files changed:
aten/src/ATen/native/native_functions.yaml
test/backward_compatibility/check_backward_compatibility.py
test/test_fx.py
test/test_namedtuple_return_api.py
tools/autograd/derivatives.yaml
tools/autograd/gen_variable_type.py
torch/_autograd_functions.py [deleted file]
torch/_tensor.py
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/functional.py
torch/testing/_internal/common_methods_invocations.py

index 663e8df..40245cc 100644 (file)
   dispatch:
     CPU, CUDA: ormqr
 
-- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
+- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
   variants: function
   dispatch:
     CPU, CUDA: _lu_with_info
index 326e038..e1dde92 100644 (file)
@@ -51,6 +51,7 @@ ALLOW_LIST = [
     ("aten::_svd_helper", datetime.date(2021, 1, 31)),
     ("aten::_syevd_helper", datetime.date(9999, 1, 1)),
     ("aten::_lu_solve_helper", datetime.date(9999, 1, 1)),
+    ("aten::_lu_with_info", datetime.date(9999, 1, 1)),
     ("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)),
     ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)),
     ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)),
index 2573572..cf69143 100644 (file)
@@ -2865,6 +2865,7 @@ class TestOperatorSignatures(JitTestCase):
                            'fill_',
                            'hstack',
                            'linalg.multi_dot',
+                           'lu',
                            'norm',
                            'polygamma',
                            'special.polygamma',
index 92152fa..345431d 100644 (file)
@@ -19,6 +19,7 @@ all_operators_with_namedtuple_return = {
     'frexp', 'lu_unpack', 'histogram', '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams',
     '_fused_moving_avg_obs_fq_helper',
     '_det_lu_based_helper',
+    '_lu_with_info',
 }
 
 
@@ -99,6 +100,8 @@ class TestNamedTupleAPI(TestCase):
             op(operators=['_det_lu_based_helper'],
                input=(), names=('det', 'lu', 'pivs'), hasout=False),
             op(operators=['aminmax'], input=(), names=('min', 'max'), hasout=True),
+            op(operators=['_lu_with_info'],
+               input=(), names=('LU', 'pivots', 'info'), hasout=False),
         ]
 
         def get_func(f):
index 4dbe9b4..b52b690 100644 (file)
   self: zeros_like(self)
   other: zeros_like(other)
 
-- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
-  self: not_implemented("lu_with_info")
+- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
+  self: _lu_with_info_backward(grad, self, LU, pivots)
 
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
index f28033d..a64f734 100644 (file)
@@ -102,7 +102,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
     'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
-    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve',
+    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
diff --git a/torch/_autograd_functions.py b/torch/_autograd_functions.py
deleted file mode 100644 (file)
index 1d809be..0000000
+++ /dev/null
@@ -1,93 +0,0 @@
-import torch
-
-class _LU(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, self, pivot=True, get_infos=False):
-        LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
-        ctx.save_for_backward(LU, pivots)
-        ctx.mark_non_differentiable(pivots, infos)
-        return LU, pivots, infos
-
-    @staticmethod
-    def backward(ctx, LU_grad, pivots_grad, infors_grad):
-        """
-        Here we derive the gradients for the LU decomposition.
-        LIMITATIONS: square inputs of full rank.
-        If not stated otherwise, for tensors A and B,
-        `A B` means the matrix product of A and B.
-
-        Let A^H = (A^T).conj()
-
-        Forward AD:
-        Note that PyTorch returns packed LU, it is a mapping
-        A -> (B:= L + U - I, P), such that A = P L U, and
-        P is a permutation matrix, and is non-differentiable.
-
-        Using B = L + U - I, A = P L U, we get
-
-        dB = dL + dU and     (*)
-        P^T dA = dL U + L dU (**)
-
-        By left/right multiplication of (**) with L^{-1}/U^{-1} we get:
-        L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}.
-
-        Note that L^{-1} dL is lower-triangular with zero diagonal,
-        and dU U^{-1} is upper-triangular.
-        Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so
-
-        L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}),
-        dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}), where * denotes the Hadamard product.
-
-        Hence we finally get:
-        dL = L 1_L * (L^{-1} P^T dA U^{-1}),
-        dU = 1_U * (L^{-1} P^T dA U^{-1}) U
-
-        Backward AD:
-        The backward sensitivity is then:
-        Tr(B_grad^H dB) = Tr(B_grad^H dL) + Tr(B_grad^H dU) = [1] + [2].
-
-        [1] = Tr(B_grad^H dL) = Tr(B_grad^H L 1_L * (L^{-1} P^T dA U^{-1}))
-            = [using Tr(A (B * C)) = Tr((A * B^T) C)]
-            = Tr((B_grad^H L * 1_L^T) L^{-1} P^T dA U^{-1})
-            = [cyclic property of trace]
-            = Tr(U^{-1} (B_grad^H L * 1_L^T) L^{-1} P^T dA)
-            = Tr((P L^{-H} (L^H B_grad * 1_L) U^{-H})^H dA).
-        Similar, [2] can be rewritten as:
-        [2] = Tr(P L^{-H} (B_grad U^H * 1_U) U^{-H})^H dA, hence
-        Tr(A_grad^H dA) = [1] + [2]
-                        = Tr((P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H})^H dA), so
-        A_grad = P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H}.
-
-        In the code below we use the name `LU` instead of `B`, so that there is no confusion
-        in the derivation above between the matrix product and a two-letter variable name.
-        """
-        LU, pivots = ctx.saved_tensors
-        P, L, U = torch.lu_unpack(LU, pivots)
-
-        # To make sure MyPy infers types right
-        assert (L is not None) and (U is not None) and (P is not None)
-
-        # phi_L = L^H B_grad * 1_L
-        phi_L = (L.transpose(-1, -2).conj() @ LU_grad).tril_()
-        phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0)
-        # phi_U = B_grad U^H * 1_U
-        phi_U = (LU_grad @ U.transpose(-1, -2).conj()).triu_()
-        phi = phi_L + phi_U
-
-        # using the notation from above plus the variable names, note
-        # A_grad = P L^{-H} phi U^{-H}.
-        # Instead of inverting L and U, we solve two systems of equations, i.e.,
-        # the above expression could be rewritten as
-        # L^H P^T A_grad U^H = phi.
-        # Let X = P^T A_grad U_H, then
-        # X = L^{-H} phi, where L^{-H} is upper triangular, or
-        # X = torch.triangular_solve(phi, L^H)
-        # using the definition of X we see:
-        # X = P^T A_grad U_H => P X = A_grad U_H => U A_grad^H = X^H P^T, so
-        # A_grad = (U^{-1} X^H P^T)^H, or
-        # A_grad = torch.triangular_solve(X^H P^T, U)^H
-        X = torch.triangular_solve(phi, L.transpose(-1, -2).conj(), upper=True).solution
-        A_grad = torch.triangular_solve(X.transpose(-1, -2).conj() @ P.transpose(-1, -2), U, upper=True) \
-            .solution.transpose(-1, -2).conj()
-
-        return A_grad, None, None
index 24811da..2bd617d 100644 (file)
@@ -453,28 +453,6 @@ class Tensor(torch._C._TensorBase):
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos)
 
-        if not torch._jit_internal.is_scripting():
-            if self.requires_grad:
-                if not (self.size(-2) == self.size(-1) and (self.dtype.is_floating_point) or self.is_complex):
-                    raise ValueError(
-                        'lu.backward works only with batches of squared full-rank matrices'
-                        ' of floating or complex types.'
-                    )
-
-                from torch._autograd_functions import _LU
-                LU, pivots, infos = _LU.apply(self, pivot, get_infos)
-                if get_infos:
-                    return LU, pivots, infos
-                else:
-                    return LU, pivots
-        else:
-            if self.requires_grad:
-                raise RuntimeError(
-                    'Script and require gradients is not supported at the moment.'
-                    'If you just want to do the forward, use .detach()'
-                    'on the input before calling the function.'
-                )
-
         LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
         if get_infos:
             return LU, pivots, infos
index 8fc2223..86639c1 100644 (file)
@@ -3831,6 +3831,178 @@ Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Ten
   return out_fw_grad;
 }
 
+// Let X in \C^{m \times n}, then its pivoted LU decomposition is
+// X = P L U, where P is a permutation matrix.
+//
+// Useful notation:
+// Let o denote the elementwise, or Hadamard, product.
+// k := min(m, n)
+// 1 := ones(k, k),
+// 1_U = 1.tril();
+// 1_L = 1 - 1_U (note the diagonal is zero)
+// For a matrix A, A^H := A.transpose(-2, -1).conj()
+//
+// Below we derive the backward algorithm for the case when m <= n.
+// The case m > n could be obtained using the same idea.
+// Since we assume m <= n, the LU decomposition of X could be written as
+// X = (X1 | X2) = P L (U1 | U2) where X1, U1 in \C^{m \times m}, X2, U2 in \C^{m, n - m}
+//
+// Forward AD:
+//
+// dX = P dL U + P L dU => [left-multiply P^T]
+// (P^T dX1 | P^T dX2) = (dL U1 + L dU1 | dL U2 + L dU2) (*)
+// From (*):
+// P^T dX1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by U1^{-1}]
+// L^{-1} P^T dX1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**).
+// Note, L is lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular.
+// Also, since the diagonal of L (all ones) is never exposed explicity (packed representation),
+// the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0.
+// Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular.
+// Combining these observations we conclude:
+//
+// L^{-1} dL = (L^{-1} P^T dX1 U1^{-1}) o 1_L,
+// dU1 U1^{-1} = (L^{-1} P^T dX1 U1^{-1}) o 1_U.
+//
+// Hence,
+// dL = L [(L^{-1} P^T dX1 U1^{-1}) o 1_L],
+// dU1 = [(L^{-1} P^T dX1 U1^{-1}) o 1_U] U1.
+// As for dU2, from (*) it follows
+// P^T dX2 = dL U2 + L dU2 =>
+// dU2 = L^{-1} (P^T dX2 - dL U2).
+//
+// Backward AD:
+//
+// The following equality comes very handy:
+// Tr(A (B o C)) = Tr((A o B^T) C) (!)
+//
+// Tr(X_grad^H dX) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then
+//
+// Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dX1 U1^{-1}) o 1_L] = [using (!)]
+//                 = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dX1 U1^{-1}) = [using the cyclic property of Tr]
+//                 = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dX1)
+//
+// Similar, using (!) and the cyclic property of the trace operator:
+// Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2)
+//                 = Tr(U1^{-1} (U1 U1_grad^H o 1_U^T) L^{-1} P^T dX1)
+//                 + Tr(U2_grad^H L^{-1} P^T dX2)
+//                 - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dX1)
+//
+// By combining the matrices to the left from dX1 and dX2 and then applying conjugate transposition,
+// we finally arrive at:
+//
+// X1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H},
+// X2_grad = P L^{-H} U2_grad
+Tensor plu_backward_base(
+  const variable_list& grads,
+  const Tensor& self,
+  const Tensor& P,
+  const Tensor& L,
+  const Tensor& U) {
+  auto L_grad = grads[0];
+  auto U_grad = grads[1];
+
+  auto m = self.size(-2);
+  auto n = self.size(-1);
+  auto k = std::min(m, n);
+
+  auto L_principal = L.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto L_principal_H = L_principal.transpose(-2, -1).conj();
+  auto L_grad_principal = L_grad.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U_principal = U.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U_principal_H = U_principal.transpose(-2, -1).conj();
+  auto U_grad_principal = U_grad.narrow(-2, 0, k).narrow(-1, 0, k);
+
+  auto phi_L = L_principal_H.matmul(L_grad_principal).tril_(-1);
+  auto phi_U = U_grad_principal.matmul(U_principal_H).triu_();
+
+  auto phi = phi_L + phi_U;
+  auto psi = at::zeros_like(self);
+
+  Tensor self_grad;
+  if (m <= n) {
+    auto U_complement = U.narrow(-2, 0, k).narrow(-1, k, n - k);
+    auto U_grad_complement = U_grad.narrow(-2, 0, k).narrow(-1, k, n - k);
+
+    auto phi_complement = U_grad_complement.matmul(U_complement.transpose(-2, -1).conj()).tril_(-1);
+    phi.sub_(phi_complement);
+
+    // recall the result for X1_grad and X2_grad from above.
+    // It can be rewritten as
+    // (X1_grad | X2_grad) = P L^{-H} psi, where
+    // psi = (psi1 | psi2)
+    //     = ([L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H} | U2_grad),
+    // so it is filled in parts.
+    //
+    // fill psi2 in
+    psi.narrow(-2, 0, k).narrow(-1, k, n - k).copy_(U_grad_complement);
+
+    // solve for psi1 to avoid the inversion of U1^H
+    auto psi_principal = std::get<0>(at::triangular_solve(
+      phi.transpose(-2, -1).conj(),
+      U_principal,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1).conj();
+    psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal);
+
+    // solve for the grad to avoid the inversion of L1^H
+    self_grad = P.matmul(
+      std::get<0>(at::triangular_solve(
+        psi,
+        L_principal_H,
+        /*upper=*/true,
+        /*transpose=*/false,
+        /*unitriangular=*/true
+      ))
+    );
+  }
+  else {
+    // variables psi and phi carry the same meaning as in the case (m <= n),
+    // albeit they are differently defined.
+    auto L_complement = L.narrow(-2, k, m - k).narrow(-1, 0, k);
+    auto L_grad_complement = L_grad.narrow(-2, k, m - k).narrow(-1, 0, k);
+
+    auto phi_complement = L_complement.transpose(-2, -1).conj().matmul(L_grad_complement).triu_();
+    phi.sub_(phi_complement);
+
+    psi.narrow(-2, k, m - k).narrow(-1, 0, k).copy_(L_grad_complement);
+
+    auto psi_principal = std::get<0>(at::triangular_solve(
+      phi,
+      L_principal_H,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/true
+    ));
+    psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal);
+
+    self_grad = std::get<0>(at::triangular_solve(
+      P.matmul(psi).transpose(-2, -1),
+      U_principal.conj(),
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1);
+  }
+
+  return self_grad;
+}
+
+Tensor _lu_with_info_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU,
+  const Tensor& pivs) {
+  Tensor P, L, U;
+  std::tie(P, L, U) = at::lu_unpack(LU, pivs);
+  // Note that packed LU could be represented as
+  // LU = L + U - I, hence
+  // L_grad = LU_grad,
+  // U_grad = LU_grad.
+  return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U);
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd
index 91ddec3..d397f55 100644 (file)
@@ -258,6 +258,7 @@ Tensor lu_unpack_backward(
   const Tensor& LU_data,
   bool unpack_data
 );
+
 Tensor _det_lu_based_helper_backward(
   const Tensor& det_grad,
   const Tensor& det,
@@ -266,6 +267,20 @@ Tensor _det_lu_based_helper_backward(
   const Tensor& pivs
 );
 
+Tensor lu_backward_base(
+  const variable_list& grads,
+  const Tensor& self,
+  const Tensor& P,
+  const Tensor& L,
+  const Tensor& U
+);
+Tensor _lu_with_info_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU,
+  const Tensor& pivs
+);
+
 Tensor cat_jvp(at::TensorList tensors, int64_t dim);
 Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim);
 Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim);
index 8898fc9..5d74664 100644 (file)
@@ -10,7 +10,6 @@ from .overrides import (
     handle_torch_function)
 from ._jit_internal import boolean_dispatch, List
 from ._jit_internal import _overload as overload
-from torch._autograd_functions import _LU
 
 Tensor = torch.Tensor
 from torch import _VF
@@ -1459,8 +1458,10 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
         * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.
 
     .. warning::
-        The LU factorization does have backward support,
-        but only for square inputs of full rank.
+        The gradients of this function will only be finite when :attr:`A` is full rank.
+        This is because the LU decomposition is just differentiable at full rank matrices.
+        Furthermore, if :attr:`A` is close to not being full rank,
+        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.
 
     Args:
         A (Tensor): the tensor to factor of size :math:`(*, m, n)`
@@ -1508,23 +1509,6 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
         ...   print('LU factorization succeeded for all samples!')
         LU factorization succeeded for all samples!
     """
-    if not torch._jit_internal.is_scripting():
-        if A.requires_grad:
-            if not (A.size(-2) == A.size(-1) and (A.dtype.is_floating_point or A.is_complex)):
-                raise ValueError(
-                    'lu.backward works only with batches of squared full-rank matrices'
-                    ' of floating or complex types.'
-                )
-
-            return _LU.apply(A, pivot, get_infos)
-    else:
-        if A.requires_grad:
-            raise RuntimeError(
-                'Script and require gradients is not supported at the moment.'
-                'If you just want to do the forward, use .detach()'
-                'on the input before calling the function.'
-            )
-
     # If get_infos is True, then we don't need to check for errors and vice versa
     return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
 
index 1c22ffa..7e923be 100644 (file)
@@ -3204,8 +3204,8 @@ def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs):
     # not needed once OpInfo tests support Iterables
     def generate_samples():
         batch_shapes = ((), (3,), (3, 3))
-        for batch_shape, get_infos in product(batch_shapes, (True, False)):
-            shape = batch_shape + (S, S)
+        for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)):
+            shape = batch_shape + (S + size_delta, S)
             input = make_tensor(shape, device, dtype, requires_grad=requires_grad, low=None, high=None)
             yield SampleInput(input, args=(True, get_infos))
 
@@ -6533,16 +6533,16 @@ op_db: List[OpInfo] = [
            op=torch.lu,
            dtypes=floating_and_complex_types(),
            supports_inplace_autograd=False,
+           # we use in-place operations which cannot be avoided.
+           # This causes vmap failures, hence we skip batched gradient checks
+           check_batched_grad=False,
            check_batched_gradgrad=False,
            supports_out=False,
            sample_inputs_func=sample_inputs_lu,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
            skips=(
-               # we skip jit tests because lu_backward is impelemented as autograd.Function,
-               # which does not support autograd with scripting
+               # we skip jit tests because `lu` is a torch function
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
-               # Skip operator schema test because this is a functional and not an operator
-               SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            )),
     OpInfo('lu_solve',
            op=torch.lu_solve,
@@ -6555,7 +6555,7 @@ op_db: List[OpInfo] = [
            dtypes=floating_and_complex_types(),
            supports_inplace_autograd=False,
            # we use in-place operations which cannot be avoided.
-           # This cases vmap failures, hence we skip batched gradient checks
+           # This causes vmap failures, hence we skip batched gradient checks
            check_batched_grad=False,
            supports_out=True,
            sample_inputs_func=sample_inputs_lu_unpack,