From dd8f6ac59784472d499d1594ca4e348f92337861 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 8 Sep 2021 09:34:46 -0700 Subject: [PATCH] Add forward mode differentiation for torch.linalg.cholesky and transpose (#62159) Summary: This PR adds forward mode differentiation for `torch.linalg.cholesky`, `torch.linalg.cholesky_ex`, and `transpose` functions. Complex tests for Cholesky fail because for some reason the gradcheck sends matrices full of zeros to `cholesky_jvp` function. cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry heitorschueroff walterddr IvanYashchuk xwang233 Pull Request resolved: https://github.com/pytorch/pytorch/pull/62159 Reviewed By: mrshenli Differential Revision: D30776829 Pulled By: albanD fbshipit-source-id: 32e5539ed6423eed8c18cce16271330ab0ea8d5e --- tools/autograd/derivatives.yaml | 1 + torch/csrc/autograd/FunctionsManual.cpp | 14 ++++++++++++++ torch/csrc/autograd/FunctionsManual.h | 1 + torch/testing/_internal/common_methods_invocations.py | 12 +++++++++--- 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 641471e..4bdb565 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -339,6 +339,7 @@ - name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) self: cholesky_backward(grad, upper, L) + L: cholesky_jvp(self_t, L, upper) - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor self, input2: cholesky_solve_backward(grad, self, input2, result, upper) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 95170f0..9ccfbd1 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1006,6 +1006,20 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra return mask_selected.view(sizes); } +Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& L, bool upper) { + // Differentiation of the Cholesky decomposition, Iain Murray + // https://arxiv.org/abs/1602.07527 + // equation 8 + auto input_tangent_ = upper ? input_tangent.transpose(-1, -2).conj() : input_tangent; + auto L_ = upper ? L.transpose(-1, -2).conj() : L; + + auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L_, /*upper=*/false)); + auto phi = at::matmul(at::matmul(L_inverse, input_tangent_), L_inverse.transpose(-2, -1).conj()); + phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5); + auto L_tangent = L_.matmul(phi); + return upper ? L_tangent.transpose(-1, -2).conj() : L_tangent; +} + Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { // cf. Iain Murray (2016); arXiv 1602.07527 // This gradient is symmetric, and not triangular. diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 31a972e..6684bcb 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -101,6 +101,7 @@ at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, int64_t n at::Tensor var_std_mean_backward(const variable_list& grads, const at::Tensor& self, const at::Tensor& r1, const at::Tensor& r2, c10::optional dim, c10::optional correction, bool keepdim, bool is_std); at::Tensor masked_scatter_backward(const at::Tensor & grad, const at::Tensor & mask, at::IntArrayRef sizes); at::Tensor cholesky_backward(at::Tensor grad, bool upper, at::Tensor L); +at::Tensor cholesky_jvp(const at::Tensor& input_tangent, const at::Tensor& L, bool upper); at::Tensor cholesky_inverse_backward(at::Tensor grad, at::Tensor L, bool upper, at::Tensor inverse); at::Tensor split_with_sizes_backward(const std::vector &grads, IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index bccac91..b1fec69 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3532,10 +3532,10 @@ def sample_inputs_linalg_cholesky(op_info, device, dtype, requires_grad=False, * batches = [(), (0, ), (2, ), (1, 1)] ns = [5, 0] out = [] - for batch, n in product(batches, ns): + for batch, n, upper in product(batches, ns, [True, False]): a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device) a.requires_grad = requires_grad - out.append(SampleInput(a)) + out.append(SampleInput(a, kwargs={"upper": upper})) return out def sample_inputs_symeig(op_info, device, dtype, requires_grad=False): @@ -6840,6 +6840,7 @@ op_db: List[OpInfo] = [ # got: vmap: Calling Tensor.as_strided is not supported # unless the batch dims being vmapped over are at the front of the tensor (in memory layout). check_batched_gradgrad=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_linalg_cholesky, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack], @@ -6853,9 +6854,14 @@ op_db: List[OpInfo] = [ aten_name='linalg_cholesky_ex', dtypes=floating_and_complex_types(), check_batched_gradgrad=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_linalg_cholesky, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]), + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck for complex generates invalid inputs for this function + SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),), + ), OpInfo('linalg.cond', aten_name='linalg_cond', dtypes=floating_and_complex_types(), -- 2.7.4