From: Nikita Vedeneev
Date: Thu, 9 Sep 2021 15:56:29 +0000 (-0700)
Subject: `torch.lu_solve`: forward AD support (#64646)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~336
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dc535466557f190e3014f9a94b24dfa500119c02;p=platform%2Fupstream%2Fpytorch.git

`torch.lu_solve`: forward AD support (#64646)

Summary:
As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64646

Reviewed By: VitalyFedyunin

Differential Revision: D30807898

Pulled By: albanD

fbshipit-source-id: 1f943c22357dd1b3662cfe0d2a26af68e3a2df4c
---

diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 505130f..79ef447 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -866,6 +866,7 @@
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
+  result: lu_solve_forward_AD(self_t, LU_data_t, LU_data_p, LU_pivots, result)
 
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
   LU_data: lu_unpack_backward(grads, LU_data, unpack_data)
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 9ccfbd1..6f05646 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -3746,6 +3746,42 @@ std::tuple<Tensor, Tensor> lu_solve_backward(
   return std::make_tuple(self_grad, LU_data_grad);
 }
 
+Tensor lu_solve_forward_AD(
+  const Tensor& dB,
+  const Tensor& dLU_data,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots,
+  const Tensor& X
+) {
+  auto dL = dLU_data.tril(-1);
+  auto dU = dLU_data.triu();
+
+  // From the derivations from above we have that:
+  // dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB,
+  // or, using that X = (LU)^{-1} P^T B,
+  // dX = -U^{-1} dU X - (LU)^{-1} dL U X + (LU)^{-1} P^T dB
+
+  // -U^{-1} dU X
+  auto U = LU_data.triu();
+  auto dU_part = -std::get<0>(at::triangular_solve(
+    dU.matmul(X),
+    U,
+    /*upper=*/true
+  ));
+
+  // (LU)^{-1} dL U X,
+  // we use lu_solve to solve this system which requires pivots which are returned by the lu routine.
+  // Since no pivoting is required for the system, we create a tensor of identity permutations
+  // which are 1-based because of the Fortran-like LAPACK interfaces.
+  auto identity_pivots = at::arange(1, LU_data.size(-1) + 1, LU_pivots.options()).expand(LU_pivots.sizes());
+  auto dL_part = at::lu_solve(dL.matmul(U).matmul(X), LU_data, identity_pivots);
+
+  // (LU)^{-1} P^T dB
+  auto dB_part = at::lu_solve(dB, LU_data, LU_pivots);
+
+  return dU_part - dL_part + dB_part;
+}
+
 Tensor lu_unpack_backward(
   const variable_list& grads,
   const Tensor& LU_data,
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 6684bcb..1aa1062 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -253,6 +253,13 @@ std::tuple<Tensor, Tensor> lu_solve_backward(
   const Tensor& LU_data,
   const Tensor& LU_pivots
 );
+Tensor lu_solve_forward_AD(
+  const Tensor& B_t,
+  const Tensor& LU_data_t,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots,
+  const Tensor& X
+);
 Tensor lu_unpack_backward(
   const variable_list& grads,
   const Tensor& LU_data,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b1fec69..40ae6b4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7101,6 +7101,7 @@ op_db: List[OpInfo] = [
            op=torch.lu_solve,
            dtypes=floating_and_complex_types(),
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_lu_solve,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
    OpInfo('lu_unpack',