From: Nikita Vedeneev
Date: Mon, 16 Aug 2021 18:39:04 +0000 (-0700)
Subject: Make `torch.lu` differentiable for wide/tall inputs + jit (#61564)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~998
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dbcfd7739f30fc842b692ab0bf0ce348fee2bde1;p=platform%2Fupstream%2Fpytorch.git

Make `torch.lu` differentiable for wide/tall inputs + jit (#61564)

Summary:
As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61564

Reviewed By: astaff

Differential Revision: D30338136

Pulled By: mruberry

fbshipit-source-id: f01436fc90980544cdfa270feee16bb3dda21b93
---
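Illustration only (not part of the patch, and assuming a build that already contains
this change): the headline behaviour is that `torch.lu` now supports backward for
wide and tall inputs.

    import torch

    # Tall (4 x 3) input; double precision is friendlier to numerical checks.
    A = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

    # get_infos=False (the default) returns (LU, pivots); pivots stay non-differentiable.
    LU, pivots = torch.lu(A)
    LU.sum().backward()    # used to raise ValueError for non-square inputs
    print(A.grad.shape)    # torch.Size([4, 3])
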
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 663e8df..40245cc 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6809,7 +6809,7 @@
   dispatch:
     CPU, CUDA: ormqr

-- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
+- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
   variants: function
   dispatch:
     CPU, CUDA: _lu_with_info
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 326e038..e1dde92 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -51,6 +51,7 @@ ALLOW_LIST = [
     ("aten::_svd_helper", datetime.date(2021, 1, 31)),
     ("aten::_syevd_helper", datetime.date(9999, 1, 1)),
     ("aten::_lu_solve_helper", datetime.date(9999, 1, 1)),
+    ("aten::_lu_with_info", datetime.date(9999, 1, 1)),
     ("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)),
     ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)),
     ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)),
diff --git a/test/test_fx.py b/test/test_fx.py
index 2573572..cf69143 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -2865,6 +2865,7 @@ class TestOperatorSignatures(JitTestCase):
             'fill_',
             'hstack',
             'linalg.multi_dot',
+            'lu',
             'norm',
             'polygamma',
             'special.polygamma',
diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py
index 92152fa..345431d 100644
--- a/test/test_namedtuple_return_api.py
+++ b/test/test_namedtuple_return_api.py
@@ -19,6 +19,7 @@ all_operators_with_namedtuple_return = {
     'frexp', 'lu_unpack', 'histogram', '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams',
     '_fused_moving_avg_obs_fq_helper', '_det_lu_based_helper',
+    '_lu_with_info',
 }
@@ -99,6 +100,8 @@ class TestNamedTupleAPI(TestCase):
         op(operators=['_det_lu_based_helper'],
            input=(), names=('det', 'lu', 'pivs'), hasout=False),
         op(operators=['aminmax'], input=(), names=('min', 'max'), hasout=True),
+        op(operators=['_lu_with_info'],
+           input=(), names=('LU', 'pivots', 'info'), hasout=False),
     ]

     def get_func(f):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 4dbe9b4..b52b690 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -860,8 +860,8 @@
   self: zeros_like(self)
   other: zeros_like(other)

-- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
-  self: not_implemented("lu_with_info")
+- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
+  self: _lu_with_info_backward(grad, self, LU, pivots)

 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index f28033d..a64f734 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -102,7 +102,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
     'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
-    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve',
+    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
 }

 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
diff --git a/torch/_autograd_functions.py b/torch/_autograd_functions.py
deleted file mode 100644
index 1d809be..0000000
--- a/torch/_autograd_functions.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import torch
-
-class _LU(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, self, pivot=True, get_infos=False):
-        LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
-        ctx.save_for_backward(LU, pivots)
-        ctx.mark_non_differentiable(pivots, infos)
-        return LU, pivots, infos
-
-    @staticmethod
-    def backward(ctx, LU_grad, pivots_grad, infors_grad):
-        """
-        Here we derive the gradients for the LU decomposition.
-        LIMITATIONS: square inputs of full rank.
-        If not stated otherwise, for tensors A and B,
-        `A B` means the matrix product of A and B.
-
-        Let A^H = (A^T).conj()
-
-        Forward AD:
-        Note that PyTorch returns packed LU, it is a mapping
-        A -> (B:= L + U - I, P), such that A = P L U, and
-        P is a permutation matrix, and is non-differentiable.
-
-        Using B = L + U - I, A = P L U, we get
-
-        dB = dL + dU and (*)
-        P^T dA = dL U + L dU (**)
-
-        By left/right multiplication of (**) with L^{-1}/U^{-1} we get:
-        L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}.
-
-        Note that L^{-1} dL is lower-triangular with zero diagonal,
-        and dU U^{-1} is upper-triangular.
-        Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so
-
-        L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}),
-        dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}), where * denotes the Hadamard product.
-
-        Hence we finally get:
-        dL = L 1_L * (L^{-1} P^T dA U^{-1}),
-        dU = 1_U * (L^{-1} P^T dA U^{-1}) U
-
-        Backward AD:
-        The backward sensitivity is then:
-        Tr(B_grad^H dB) = Tr(B_grad^H dL) + Tr(B_grad^H dU) = [1] + [2].
-
-        [1] = Tr(B_grad^H dL) = Tr(B_grad^H L 1_L * (L^{-1} P^T dA U^{-1}))
-            = [using Tr(A (B * C)) = Tr((A * B^T) C)]
-            = Tr((B_grad^H L * 1_L^T) L^{-1} P^T dA U^{-1})
-            = [cyclic property of trace]
-            = Tr(U^{-1} (B_grad^H L * 1_L^T) L^{-1} P^T dA)
-            = Tr((P L^{-H} (L^H B_grad * 1_L) U^{-H})^H dA).
-        Similar, [2] can be rewritten as:
-        [2] = Tr(P L^{-H} (B_grad U^H * 1_U) U^{-H})^H dA, hence
-        Tr(A_grad^H dA) = [1] + [2]
-                        = Tr((P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H})^H dA), so
-        A_grad = P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H}.
-
-        In the code below we use the name `LU` instead of `B`, so that there is no confusion
-        in the derivation above between the matrix product and a two-letter variable name.
-        """
-        LU, pivots = ctx.saved_tensors
-        P, L, U = torch.lu_unpack(LU, pivots)
-
-        # To make sure MyPy infers types right
-        assert (L is not None) and (U is not None) and (P is not None)
-
-        # phi_L = L^H B_grad * 1_L
-        phi_L = (L.transpose(-1, -2).conj() @ LU_grad).tril_()
-        phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0)
-        # phi_U = B_grad U^H * 1_U
-        phi_U = (LU_grad @ U.transpose(-1, -2).conj()).triu_()
-        phi = phi_L + phi_U
-
-        # using the notation from above plus the variable names, note
-        # A_grad = P L^{-H} phi U^{-H}.
-        # Instead of inverting L and U, we solve two systems of equations, i.e.,
-        # the above expression could be rewritten as
-        # L^H P^T A_grad U^H = phi.
-        # Let X = P^T A_grad U_H, then
-        # X = L^{-H} phi, where L^{-H} is upper triangular, or
-        # X = torch.triangular_solve(phi, L^H)
-        # using the definition of X we see:
-        # X = P^T A_grad U_H => P X = A_grad U_H => U A_grad^H = X^H P^T, so
-        # A_grad = (U^{-1} X^H P^T)^H, or
-        # A_grad = torch.triangular_solve(X^H P^T, U)^H
-        X = torch.triangular_solve(phi, L.transpose(-1, -2).conj(), upper=True).solution
-        A_grad = torch.triangular_solve(X.transpose(-1, -2).conj() @ P.transpose(-1, -2), U, upper=True) \
-            .solution.transpose(-1, -2).conj()
-
-        return A_grad, None, None
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 24811da..2bd617d 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -453,28 +453,6 @@ class Tensor(torch._C._TensorBase):
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos)

-        if not torch._jit_internal.is_scripting():
-            if self.requires_grad:
-                if not (self.size(-2) == self.size(-1) and (self.dtype.is_floating_point) or self.is_complex):
-                    raise ValueError(
-                        'lu.backward works only with batches of squared full-rank matrices'
-                        ' of floating or complex types.'
-                    )
-
-            from torch._autograd_functions import _LU
-            LU, pivots, infos = _LU.apply(self, pivot, get_infos)
-            if get_infos:
-                return LU, pivots, infos
-            else:
-                return LU, pivots
-        else:
-            if self.requires_grad:
-                raise RuntimeError(
-                    'Script and require gradients is not supported at the moment.'
-                    'If you just want to do the forward, use .detach()'
-                    'on the input before calling the function.'
-                )
-
         LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
         if get_infos:
             return LU, pivots, infos
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 8fc2223..86639c1 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -3831,6 +3831,178 @@ Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Ten
   return out_fw_grad;
 }

+// Let X in \C^{m \times n}, then its pivoted LU decomposition is
+// X = P L U, where P is a permutation matrix.
+//
+// Useful notation:
+// Let o denote the elementwise, or Hadamard, product.
+// k := min(m, n)
+// 1 := ones(k, k),
+// 1_U = 1.triu();
+// 1_L = 1 - 1_U (note the diagonal is zero)
+// For a matrix A, A^H := A.transpose(-2, -1).conj()
+//
+// Below we derive the backward algorithm for the case when m <= n.
+// The case m > n could be obtained using the same idea.
+// Since we assume m <= n, the LU decomposition of X could be written as
+// X = (X1 | X2) = P L (U1 | U2) where X1, U1 in \C^{m \times m}, X2, U2 in \C^{m, n - m}
+//
+// Forward AD:
+//
+// dX = P dL U + P L dU => [left-multiply P^T]
+// (P^T dX1 | P^T dX2) = (dL U1 + L dU1 | dL U2 + L dU2) (*)
+// From (*):
+// P^T dX1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by U1^{-1}]
+// L^{-1} P^T dX1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**).
+// Note, L is lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular.
+// Also, since the diagonal of L (all ones) is never exposed explicitly (packed representation),
+// the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0.
+// Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular.
+// Combining these observations we conclude:
+//
+// L^{-1} dL = (L^{-1} P^T dX1 U1^{-1}) o 1_L,
+// dU1 U1^{-1} = (L^{-1} P^T dX1 U1^{-1}) o 1_U.
+//
+// Hence,
+// dL = L [(L^{-1} P^T dX1 U1^{-1}) o 1_L],
+// dU1 = [(L^{-1} P^T dX1 U1^{-1}) o 1_U] U1.
+// As for dU2, from (*) it follows
+// P^T dX2 = dL U2 + L dU2 =>
+// dU2 = L^{-1} (P^T dX2 - dL U2).
+//
+// Backward AD:
+//
+// The following equality comes very handy:
+// Tr(A (B o C)) = Tr((A o B^T) C) (!)
+//
+// Tr(X_grad^H dX) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then
+//
+// Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dX1 U1^{-1}) o 1_L] = [using (!)]
+//                 = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dX1 U1^{-1}) = [using the cyclic property of Tr]
+//                 = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dX1)
+//
+// Similarly, using (!) and the cyclic property of the trace operator:
+// Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2)
+//                 = Tr(U1^{-1} (U1 U1_grad^H o 1_U^T) L^{-1} P^T dX1)
+//                 + Tr(U2_grad^H L^{-1} P^T dX2)
+//                 - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dX1)
+//
+// By combining the matrices to the left from dX1 and dX2 and then applying conjugate transposition,
+// we finally arrive at:
+//
+// X1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H},
+// X2_grad = P L^{-H} U2_grad
+Tensor plu_backward_base(
+  const variable_list& grads,
+  const Tensor& self,
+  const Tensor& P,
+  const Tensor& L,
+  const Tensor& U) {
+  auto L_grad = grads[0];
+  auto U_grad = grads[1];
+
+  auto m = self.size(-2);
+  auto n = self.size(-1);
+  auto k = std::min(m, n);
+
+  auto L_principal = L.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto L_principal_H = L_principal.transpose(-2, -1).conj();
+  auto L_grad_principal = L_grad.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U_principal = U.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U_principal_H = U_principal.transpose(-2, -1).conj();
+  auto U_grad_principal = U_grad.narrow(-2, 0, k).narrow(-1, 0, k);
+
+  auto phi_L = L_principal_H.matmul(L_grad_principal).tril_(-1);
+  auto phi_U = U_grad_principal.matmul(U_principal_H).triu_();
+
+  auto phi = phi_L + phi_U;
+  auto psi = at::zeros_like(self);
+
+  Tensor self_grad;
+  if (m <= n) {
+    auto U_complement = U.narrow(-2, 0, k).narrow(-1, k, n - k);
+    auto U_grad_complement = U_grad.narrow(-2, 0, k).narrow(-1, k, n - k);
+
+    auto phi_complement = U_grad_complement.matmul(U_complement.transpose(-2, -1).conj()).tril_(-1);
+    phi.sub_(phi_complement);
+
+    // recall the result for X1_grad and X2_grad from above.
+    // It can be rewritten as
+    // (X1_grad | X2_grad) = P L^{-H} psi, where
+    // psi = (psi1 | psi2)
+    //     = ([L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H} | U2_grad),
+    // so it is filled in parts.
+    //
+    // fill psi2 in
+    psi.narrow(-2, 0, k).narrow(-1, k, n - k).copy_(U_grad_complement);
+
+    // solve for psi1 to avoid the inversion of U1^H
+    auto psi_principal = std::get<0>(at::triangular_solve(
+      phi.transpose(-2, -1).conj(),
+      U_principal,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1).conj();
+    psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal);
+
+    // solve for the grad to avoid the inversion of L1^H
+    self_grad = P.matmul(
+      std::get<0>(at::triangular_solve(
+        psi,
+        L_principal_H,
+        /*upper=*/true,
+        /*transpose=*/false,
+        /*unitriangular=*/true
+      ))
+    );
+  }
+  else {
+    // variables psi and phi carry the same meaning as in the case (m <= n),
+    // albeit they are differently defined.
+    auto L_complement = L.narrow(-2, k, m - k).narrow(-1, 0, k);
+    auto L_grad_complement = L_grad.narrow(-2, k, m - k).narrow(-1, 0, k);
+
+    auto phi_complement = L_complement.transpose(-2, -1).conj().matmul(L_grad_complement).triu_();
+    phi.sub_(phi_complement);
+
+    psi.narrow(-2, k, m - k).narrow(-1, 0, k).copy_(L_grad_complement);
+
+    auto psi_principal = std::get<0>(at::triangular_solve(
+      phi,
+      L_principal_H,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/true
+    ));
+    psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal);
+
+    self_grad = std::get<0>(at::triangular_solve(
+      P.matmul(psi).transpose(-2, -1),
+      U_principal.conj(),
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1);
+  }
+
+  return self_grad;
+}
+
+Tensor _lu_with_info_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU,
+  const Tensor& pivs) {
+  Tensor P, L, U;
+  std::tie(P, L, U) = at::lu_unpack(LU, pivs);
+  // Note that packed LU could be represented as
+  // LU = L + U - I, hence
+  // L_grad = LU_grad,
+  // U_grad = LU_grad.
+  return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U);
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 91ddec3..d397f55 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -258,6 +258,7 @@ Tensor lu_unpack_backward(
   const Tensor& LU_data,
   bool unpack_data
 );
+
 Tensor _det_lu_based_helper_backward(
   const Tensor& det_grad,
   const Tensor& det,
@@ -266,6 +267,20 @@ Tensor _det_lu_based_helper_backward(
   const Tensor& pivs
 );

+Tensor lu_backward_base(
+  const variable_list& grads,
+  const Tensor& self,
+  const Tensor& P,
+  const Tensor& L,
+  const Tensor& U
+);
+Tensor _lu_with_info_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU,
+  const Tensor& pivs
+);
+
 Tensor cat_jvp(at::TensorList tensors, int64_t dim);
 Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim);
 Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim);
diff --git a/torch/functional.py b/torch/functional.py
index 8898fc9..5d74664 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -10,7 +10,6 @@ from .overrides import (
     handle_torch_function)
 from ._jit_internal import boolean_dispatch, List
 from ._jit_internal import _overload as overload
-from torch._autograd_functions import _LU

 Tensor = torch.Tensor
 from torch import _VF
@@ -1459,8 +1458,10 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
     * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.

     .. warning::
-        The LU factorization does have backward support,
-        but only for square inputs of full rank.
+        The gradients of this function will only be finite when :attr:`A` is full rank.
+        This is because the LU decomposition is just differentiable at full rank matrices.
+        Furthermore, if :attr:`A` is close to not being full rank,
+        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.

     Args:
         A (Tensor): the tensor to factor of size :math:`(*, m, n)`
@@ -1508,23 +1509,6 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
         ...     print('LU factorization succeeded for all samples!')
         LU factorization succeeded for all samples!
     """
-    if not torch._jit_internal.is_scripting():
-        if A.requires_grad:
-            if not (A.size(-2) == A.size(-1) and (A.dtype.is_floating_point or A.is_complex)):
-                raise ValueError(
-                    'lu.backward works only with batches of squared full-rank matrices'
-                    ' of floating or complex types.'
-                )
-
-        return _LU.apply(A, pivot, get_infos)
-    else:
-        if A.requires_grad:
-            raise RuntimeError(
-                'Script and require gradients is not supported at the moment.'
-                'If you just want to do the forward, use .detach()'
-                'on the input before calling the function.'
-            )
-
     # If get_infos is True, then we don't need to check for errors and vice versa
     return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1c22ffa..7e923be 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3204,8 +3204,8 @@ def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs):
     # not needed once OpInfo tests support Iterables
     def generate_samples():
         batch_shapes = ((), (3,), (3, 3))
-        for batch_shape, get_infos in product(batch_shapes, (True, False)):
-            shape = batch_shape + (S, S)
+        for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)):
+            shape = batch_shape + (S + size_delta, S)
             input = make_tensor(shape, device, dtype, requires_grad=requires_grad, low=None, high=None)
             yield SampleInput(input, args=(True, get_infos))

@@ -6533,16 +6533,16 @@ op_db: List[OpInfo] = [
            op=torch.lu,
            dtypes=floating_and_complex_types(),
            supports_inplace_autograd=False,
+           # we use in-place operations which cannot be avoided.
+           # This causes vmap failures, hence we skip batched gradient checks
+           check_batched_grad=False,
            check_batched_gradgrad=False,
            supports_out=False,
            sample_inputs_func=sample_inputs_lu,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
            skips=(
-               # we skip jit tests because lu_backward is impelemented as autograd.Function,
-               # which does not support autograd with scripting
+               # we skip jit tests because `lu` is a torch function
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
-               # Skip operator schema test because this is a functional and not an operator
-               SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            )),
    OpInfo('lu_solve',
           op=torch.lu_solve,
@@ -6555,7 +6555,7 @@ op_db: List[OpInfo] = [
            dtypes=floating_and_complex_types(),
            supports_inplace_autograd=False,
            # we use in-place operations which cannot be avoided.
-           # This cases vmap failures, hence we skip batched gradient checks
+           # This causes vmap failures, hence we skip batched gradient checks
            check_batched_grad=False,
            supports_out=True,
            sample_inputs_func=sample_inputs_lu_unpack,
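
For reference, a hedged sketch (illustrative only, not part of the patch) of the two
behaviours the new tests exercise: a numerical check of the analytic gradient
implemented by plu_backward_base / _lu_with_info_backward on a wide input, and
calling `torch.lu` under TorchScript, which this change also enables.

    import torch

    # Numerical check of the analytic wide-input gradient.
    A_wide = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(lambda A: torch.lu(A)[0], (A_wide,))

    # `torch.lu` no longer goes through an autograd.Function wrapper,
    # so it can be scripted directly (the "+ jit" part of the title).
    @torch.jit.script
    def lu_sum(A: torch.Tensor) -> torch.Tensor:
        LU, pivots = torch.lu(A, pivot=True)
        return LU.sum()

    print(lu_sum(torch.randn(5, 3, dtype=torch.double)))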