Add forward mode differentiation for torch.linalg.cholesky and transpose (#62159)
author Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Wed, 8 Sep 2021 16:34:46 +0000 (09:34 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 16:44:30 +0000 (09:44 -0700)
Summary:
This PR adds forward-mode differentiation for the `torch.linalg.cholesky`, `torch.linalg.cholesky_ex`, and `transpose` functions.
The forward-mode tests for Cholesky with complex dtypes fail because, for a reason that is not yet understood, gradcheck sends matrices full of zeros to the `cholesky_jvp` function; those tests are therefore skipped for complex types.
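
The JVP follows equation 8 of Iain Murray, "Differentiation of the Cholesky decomposition" (arXiv:1602.07527): for A = L L^H with input tangent dA, the output tangent is dL = L * Phi(L^{-1} dA L^{-H}), where Phi keeps the lower triangle and halves the diagonal.

As a rough usage sketch (not part of this PR), the new support can be exercised from Python through the forward-AD API; the matrices below are only illustrative:

```python
import torch
import torch.autograd.forward_ad as fwAD

# Illustrative symmetric positive-definite input and a symmetric tangent direction.
a = torch.randn(3, 3, dtype=torch.float64)
primal = a @ a.transpose(-2, -1) + 3 * torch.eye(3, dtype=torch.float64)
t = torch.randn(3, 3, dtype=torch.float64)
tangent = (t + t.transpose(-2, -1)) / 2

with fwAD.dual_level():
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = torch.linalg.cholesky(dual_input)
    # unpack_dual returns the primal output and its forward-mode derivative (the JVP).
    L, L_tangent = fwAD.unpack_dual(dual_output)
```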

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry heitorschueroff walterddr IvanYashchuk xwang233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62159

Reviewed By: mrshenli

Differential Revision: D30776829

Pulled By: albanD

fbshipit-source-id: 32e5539ed6423eed8c18cce16271330ab0ea8d5e

tools/autograd/derivatives.yaml
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/testing/_internal/common_methods_invocations.py

tools/autograd/derivatives.yaml
index 641471e..4bdb565 100644
 
 - name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
   self: cholesky_backward(grad, upper, L)
+  L: cholesky_jvp(self_t, L, upper)
 
 - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
   self, input2: cholesky_solve_backward(grad, self, input2, result, upper)
torch/csrc/autograd/FunctionsManual.cpp
index 95170f0..9ccfbd1 100644
@@ -1006,6 +1006,20 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) {
   return mask_selected.view(sizes);
 }
 
+Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& L, bool upper) {
+  // Differentiation of the Cholesky decomposition, Iain Murray
+  // https://arxiv.org/abs/1602.07527
+  // equation 8
+  auto input_tangent_ = upper ? input_tangent.transpose(-1, -2).conj() : input_tangent;
+  auto L_ = upper ? L.transpose(-1, -2).conj() : L;
+
+  auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L_, /*upper=*/false));
+  auto phi = at::matmul(at::matmul(L_inverse, input_tangent_), L_inverse.transpose(-2, -1).conj());
+  phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);
+  auto L_tangent = L_.matmul(phi);
+  return upper ? L_tangent.transpose(-1, -2).conj() : L_tangent;
+}
+
 Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
   // cf. Iain Murray (2016); arXiv 1602.07527
   // This gradient is symmetric, and not triangular.
torch/csrc/autograd/FunctionsManual.h
index 31a972e..6684bcb 100644
@@ -101,6 +101,7 @@ at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, int64_t n
 at::Tensor var_std_mean_backward(const variable_list& grads, const at::Tensor& self, const at::Tensor& r1, const at::Tensor& r2, c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction, bool keepdim, bool is_std);
 at::Tensor masked_scatter_backward(const at::Tensor & grad, const at::Tensor & mask, at::IntArrayRef sizes);
 at::Tensor cholesky_backward(at::Tensor grad, bool upper, at::Tensor L);
+at::Tensor cholesky_jvp(const at::Tensor& input_tangent, const at::Tensor& L, bool upper);
 at::Tensor cholesky_inverse_backward(at::Tensor grad, at::Tensor L, bool upper, at::Tensor inverse);
 at::Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
                                      IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options);
torch/testing/_internal/common_methods_invocations.py
index bccac91..b1fec69 100644
@@ -3532,10 +3532,10 @@ def sample_inputs_linalg_cholesky(op_info, device, dtype, requires_grad=False, *
     batches = [(), (0, ), (2, ), (1, 1)]
     ns = [5, 0]
     out = []
-    for batch, n in product(batches, ns):
+    for batch, n, upper in product(batches, ns, [True, False]):
         a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
         a.requires_grad = requires_grad
-        out.append(SampleInput(a))
+        out.append(SampleInput(a, kwargs={"upper": upper}))
     return out
 
 def sample_inputs_symeig(op_info, device, dtype, requires_grad=False):
@@ -6840,6 +6840,7 @@ op_db: List[OpInfo] = [
            # got: vmap: Calling Tensor.as_strided is not supported
            # unless the batch dims being vmapped over are at the front of the tensor (in memory layout).
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_linalg_cholesky,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
@@ -6853,9 +6854,14 @@ op_db: List[OpInfo] = [
            aten_name='linalg_cholesky_ex',
            dtypes=floating_and_complex_types(),
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_linalg_cholesky,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # Gradcheck for complex generates invalid inputs for this function
+               SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
+           ),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),