From 88fff22023b201ee237ab0856d53a154cc1784bb Mon Sep 17 00:00:00 2001
From: Nikita Vedeneev
Date: Fri, 10 Sep 2021 07:17:30 -0700
Subject: [PATCH] `torch.lu`: forward AD support (#64742)

Summary:
As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64742

Reviewed By: H-Huang

Differential Revision: D30841227

Pulled By: albanD

fbshipit-source-id: dc4d043ab94358594adb110fbbbb60750c98262a
---
 tools/autograd/derivatives.yaml               |  1 +
 torch/csrc/autograd/FunctionsManual.cpp       | 85 ++++++++++++++++++++++
 torch/csrc/autograd/FunctionsManual.h         |  5 ++
 .../_internal/common_methods_invocations.py   |  1 +
 4 files changed, 92 insertions(+)

diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 79ef447..660c188 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -863,6 +863,7 @@
 - name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
   self: _lu_with_info_backward(grad, self, LU, pivots)
+  LU: _lu_with_info_jvp(self_t, LU, pivots)
 
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
 
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 6f05646..e53a0c5 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -4051,6 +4051,91 @@ Tensor _lu_with_info_backward(
   return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U);
 }
 
+Tensor _lu_with_info_jvp(
+  const Tensor& dX,
+  const Tensor& LU,
+  const Tensor& pivs
+) {
+  // This function is based on the forward AD derivations outlined
+  // in the description to the plu_backward_base function.
+
+  Tensor P, L, U;
+  std::tie(P, L, U) = at::lu_unpack(LU, pivs);
+
+  auto m = LU.size(-2);
+  auto n = LU.size(-1);
+  auto k = std::min(m, n);
+
+  auto pdX = P.transpose(-1, -2).matmul(dX);
+
+  // similar to the backward implementation, we also consider block structures such as:
+  // for a matrix A of size m x n we decompose it as
+  // A = (A1 | A2) with A1 of size m x m if m <= n and
+  // A = (A1^T | A2^T)^T with A1 of size n x n if m > n.
+  auto pdX1 = pdX.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k);
+
+  // dK = L1^{-1} pdX1
+  auto dK = std::get<0>(at::triangular_solve(
+    pdX1,
+    L1,
+    /*upper=*/false,
+    /*transpose=*/false,
+    /*unitriangular=*/true
+  ));
+  // dK <- dK U1^{-1}
+  dK = std::get<0>(at::triangular_solve(
+    dK.transpose(-1, -2),
+    U1,
+    /*upper=*/true,
+    /*transpose=*/true
+  )).transpose(-1, -2);
+
+  auto dL1 = L1.matmul(dK.tril(-1));
+  auto dU1 = dK.triu().matmul(U1);
+
+  // since LU = L + U - I, we have that dLU = dL + dU
+  // if LU is of size m x n, we always have
+  // dLU1 = dL1 + dU1, where the block indexing follows the rules
+  // outlined above.
+  if (m == n) {
+    return dL1 + dU1;
+  }
+  else {
+    auto dLU = at::zeros_like(LU);
+    dLU.narrow(-2, 0, k).narrow(-1, 0, k).copy_(dL1 + dU1);
+
+    if (m < n) {
+      // we only need to update dU2 defined as
+      // dU2 := L1^{-1} (pdX2 - dL1 U2)
+      auto pdX2 = pdX.narrow(-1, k, n - k);
+      auto U2 = U.narrow(-1, k, n - k);
+      dLU.narrow(-1, k, n - k).copy_(std::get<0>(at::triangular_solve(
+        pdX2 - dL1.matmul(U2),
+        L1,
+        /*upper=*/false,
+        /*transpose=*/false,
+        /*unitriangular=*/true
+      )));
+    }
+    else {
+      // we only need to update dL2 defined as
+      // dL2 := (pdX2 - L2 dU1) U1^{-1}
+      auto pdX2 = pdX.narrow(-2, k, m - k);
+      auto L2 = L.narrow(-2, k, m - k);
+      dLU.narrow(-2, k, m - k).copy_(std::get<0>(at::triangular_solve(
+        (pdX2 - L2.matmul(dU1)).transpose(-1, -2),
+        U1,
+        /*upper=*/true,
+        /*transpose=*/true
+      )).transpose(-1, -2));
+    }
+
+    return dLU;
+  }
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 1aa1062..b24dce7 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -287,6 +287,11 @@ Tensor _lu_with_info_backward(
   const Tensor& LU,
   const Tensor& pivs
 );
+Tensor _lu_with_info_jvp(
+  const Tensor& dX,
+  const Tensor& LU,
+  const Tensor& pivs
+);
 
 Tensor cat_jvp(at::TensorList tensors, int64_t dim);
 Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim);
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index f0c163c..8451791 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7090,6 +7090,7 @@ op_db: List[OpInfo] = [
            # This causes vmap failures, hence we skip batched gradient checks
            check_batched_grad=False,
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            supports_out=False,
            sample_inputs_func=sample_inputs_lu,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
-- 
2.7.4
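
The JVP formula implemented in _lu_with_info_jvp can be sanity-checked from Python against the first-order identity dA = P (dL U + L dU). The snippet below is a minimal sketch for the square case only, written against torch.lu, torch.lu_unpack and torch.triangular_solve as they existed around this release; the helper name check_lu_jvp is made up for illustration and is not part of the patch.

import torch

# Sketch: verify the square-case JVP formula used by _lu_with_info_jvp.
# `check_lu_jvp` is a hypothetical helper, not part of the patch.
def check_lu_jvp(A, dA):
    LU, pivots = torch.lu(A)
    P, L, U = torch.lu_unpack(LU, pivots)

    # dK = L^{-1} (P^T dA) U^{-1}, mirroring the two triangular solves in the C++ code.
    pdA = P.transpose(-1, -2) @ dA
    dK = torch.triangular_solve(pdA, L, upper=False, unitriangular=True).solution
    dK = torch.triangular_solve(
        dK.transpose(-1, -2), U, upper=True, transpose=True
    ).solution.transpose(-1, -2)

    dL = L @ dK.tril(-1)  # strictly lower part of dK feeds dL
    dU = dK.triu() @ U    # upper part (with diagonal) feeds dU

    # Differentiating A = P L U gives dA = P (dL U + L dU).
    assert torch.allclose(P @ (dL @ U + L @ dU), dA)
    # Since LU = L + U - I, the tangent of the packed LU output is dL + dU.
    return dL + dU

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.double)
dA = torch.randn(4, 4, dtype=torch.double)
dLU = check_lu_jvp(A, dA)

The returned dL + dU corresponds to the m == n branch of _lu_with_info_jvp; the non-square branches apply the same identity block-wise to fill in the remaining dU2 or dL2 block.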
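
With the new derivatives.yaml entry and supports_forward_ad=True in the OpInfo, the tangent should also be observable through the public forward AD API. A possible usage sketch, assuming torch.autograd.forward_ad and torch.lu as shipped around the PyTorch release that contains this change:

import torch
import torch.autograd.forward_ad as fwAD

A = torch.randn(4, 4, dtype=torch.double)
dA = torch.randn(4, 4, dtype=torch.double)  # tangent direction for A

with fwAD.dual_level():
    # Attach the tangent dA to A and run the decomposition.
    LU, pivots = torch.lu(fwAD.make_dual(A, dA))
    # The packed LU factors carry the tangent computed by _lu_with_info_jvp;
    # pivots is an integer tensor and has no tangent.
    LU_primal, LU_tangent = fwAD.unpack_dual(LU)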