Enable upper for torch.linalg.cholesky (#62434)
authorRong Rong (AI Infra) <rongr@fb.com>
Mon, 9 Aug 2021 16:26:47 +0000 (09:26 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 9 Aug 2021 16:28:33 +0000 (09:28 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61988

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62434

Reviewed By: seemethere, tktrungna

Differential Revision: D30079806

Pulled By: walterddr

fbshipit-source-id: 044efb96525155c9bc7953ac4ad47c1b7c12fb20

aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/native_functions.yaml
test/backward_compatibility/check_backward_compatibility.py
test/test_linalg.py
tools/autograd/derivatives.yaml
torch/linalg/__init__.py

index b8b3c68..d80f918 100644 (file)
@@ -1322,7 +1322,7 @@ Tensor& cholesky_out(const Tensor &self, bool upper, Tensor &result) {
   return result;
 }
 
-void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const Tensor& info) {
+void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const Tensor& info, bool upper) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.size(-1) == input.size(-2));
 
@@ -1356,12 +1356,16 @@ void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const T
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.sizes().equals(expected_info_shape));
   info.fill_(0);
 
-  cholesky_stub(result.device().type(), result, info, /*upper=*/false);
+  cholesky_stub(result.device().type(), result, info, upper);
 
-  result.tril_();
+  if (upper) {
+    result.triu_();
+  } else {
+    result.tril_();
+  }
 }
 
-std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool check_errors, Tensor& L, Tensor& info) {
+std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool upper, bool check_errors, Tensor& L, Tensor& info) {
   squareCheckInputs(input);
   checkSameDevice("torch.linalg.cholesky_ex", L, input, "L");
   checkLinalgCompatibleDtype("torch.linalg.cholesky_ex", L, input, "L");
@@ -1397,14 +1401,14 @@ std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool ch
   if (copy_needed) {
     Tensor L_tmp = at::empty({0}, input.options());
     Tensor info_tmp = at::empty({0}, input.options().dtype(kInt));
-    linalg_cholesky_out_info(input, L_tmp, info_tmp);
+    linalg_cholesky_out_info(input, L_tmp, info_tmp, upper);
     at::native::resize_output(L, L_tmp.sizes());
     L.copy_(L_tmp);
     at::native::resize_output(info, info_tmp.sizes());
     info.copy_(info_tmp);
   } else {
     // use "out" tensors' memory directly
-    linalg_cholesky_out_info(input, L, info);
+    linalg_cholesky_out_info(input, L, info, upper);
   }
 
   if (check_errors) {
@@ -1418,16 +1422,16 @@ std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool ch
   return std::tuple<Tensor&, Tensor&>(L, info);
 }
 
-std::tuple<Tensor, Tensor> linalg_cholesky_ex(const Tensor& input, bool check_errors) {
+std::tuple<Tensor, Tensor> linalg_cholesky_ex(const Tensor& input, bool upper, bool check_errors) {
   Tensor L = at::empty({0}, input.options());
   Tensor info = at::empty({0}, input.options().dtype(kInt));
-  std::tie(L, info) = at::native::linalg_cholesky_ex_out(input, check_errors, L, info);
+  std::tie(L, info) = at::native::linalg_cholesky_ex_out(input, upper, check_errors, L, info);
   return std::make_tuple(L, info);
 }
 
-Tensor linalg_cholesky(const Tensor &self) {
+Tensor linalg_cholesky(const Tensor &self, bool upper) {
   Tensor result, info;
-  std::tie(result, info) = at::linalg_cholesky_ex(self, /*check_errors=*/false);
+  std::tie(result, info) = at::linalg_cholesky_ex(self, upper, /*check_errors=*/false);
 
   // we pass check_errors=false above and do the check here
   // so that the name of the function is correct in the error message
@@ -1440,14 +1444,14 @@ Tensor linalg_cholesky(const Tensor &self) {
   return result;
 }
 
-Tensor& linalg_cholesky_out(const Tensor &self, Tensor &result) {
+Tensor& linalg_cholesky_out(const Tensor &self, bool upper, Tensor &result) {
   // linalg_cholesky_ex_outf includes these checks, but we do it here
   // so that the name of the function is correct in the error message
   checkSameDevice("torch.linalg.cholesky", result, self);
   checkLinalgCompatibleDtype("torch.linalg.cholesky", result, self);
 
   Tensor info = at::empty({0}, self.options().dtype(kInt));
-  std::tie(result, info) = at::linalg_cholesky_ex_outf(self, /*check_errors=*/false, result, info);
+  std::tie(result, info) = at::linalg_cholesky_ex_outf(self, upper, /*check_errors=*/false, result, info);
 
   // we pass check_errors=false above and do the check here
   // so that the name of the function is correct in the error message
index 3e829ae..6c6a77a 100644 (file)
 # See linalg_det as an example.
 
 # "_ex" stands for experimental
-- func: linalg_cholesky_ex(Tensor self, *, bool check_errors=False) -> (Tensor L, Tensor info)
+- func: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
   python_module: linalg
   variants: function
   dispatch:
     CPU, CUDA: linalg_cholesky_ex
 
-- func: linalg_cholesky_ex.L(Tensor self, *, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
+- func: linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
   python_module: linalg
   variants: function
   dispatch:
     CPU, CUDA: linalg_cholesky_ex_out
 
-- func: linalg_cholesky(Tensor self) -> Tensor
+- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
   python_module: linalg
   variants: function
 
-- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
   variants: function
 
index ce23f90..326e038 100644 (file)
@@ -42,6 +42,9 @@ ALLOW_LIST = [
     ("aten::irfft", datetime.date(2021, 1, 31)),
     ("aten::rfft", datetime.date(2021, 1, 31)),
     ("aten::linalg_svd", datetime.date(2021, 5, 15)),
+    ("aten::linalg_cholesky.out", datetime.date(2021, 8, 30)),
+    ("aten::linalg_cholesky_ex", datetime.date(2021, 8, 30)),
+    ("aten::linalg_cholesky_ex.L", datetime.date(2021, 8, 30)),
     ("aten::_cholesky_helper", datetime.date(9999, 1, 1)),
     ("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
     ("aten::linalg_lstsq", datetime.date(2021, 5, 1)),
index 44d2d14..119da49 100644 (file)
@@ -454,6 +454,11 @@ class TestLinalg(TestCase):
         expected = torch.linalg.cholesky(A)
         self.assertEqual(expected, out)
 
+        # check the upper= variant
+        expected = torch.linalg.cholesky(A).transpose(-2, -1).conj()
+        actual = torch.linalg.cholesky(A, upper=True)
+        self.assertEqual(expected, actual)
+
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
index ea37133..3aee1ff 100644 (file)
 - name: cholesky(Tensor self, bool upper=False) -> Tensor
   self: cholesky_backward(grad, upper, result)
 
-- name: linalg_cholesky_ex(Tensor self, *, bool check_errors=False) -> (Tensor L, Tensor info)
-  self: cholesky_backward(grad, false, L)
+- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
+  self: cholesky_backward(grad, upper, L)
 
 - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
   self, input2: cholesky_solve_backward(grad, self, input2, result, upper)
index a9bfaf6..df3507f 100644 (file)
@@ -15,7 +15,7 @@ common_notes = {
 # also connects the torch.linalg Python namespace to the torch._C._linalg builtins.
 
 cholesky = _add_docstr(_linalg.linalg_cholesky, r"""
-linalg.cholesky(A, *, out=None) -> Tensor
+linalg.cholesky(A, *, upper=False, out=None) -> Tensor
 
 Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
 
@@ -54,6 +54,9 @@ Args:
                 consisting of symmetric or Hermitian positive-definite matrices.
 
 Keyword args:
+    upper (bool, optional): whether to return an upper triangular matrix.
+        The tensor returned with upper=True is the conjugate transpose of the tensor
+        returned with upper=False.
     out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
 
 Raises:
@@ -84,7 +87,7 @@ Examples::
 """)
 
 cholesky_ex = _add_docstr(_linalg.linalg_cholesky_ex, r"""
-linalg.cholesky_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor)
+linalg.cholesky_ex(A, *, upper=False, check_errors=False, out=None) -> (Tensor, Tensor)
 
 Computes the Cholesky decomposition of a complex Hermitian or real
 symmetric positive-definite matrix.
@@ -118,9 +121,12 @@ If ``check_errors=True`` and ``info`` contains positive integers, then a Runtime
 Args:
     A (Tensor): the Hermitian `n \times n` matrix or the batch of such matrices of size
                     `(*, n, n)` where `*` is one or more batch dimensions.
-    check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`.
 
 Keyword args:
+    upper (bool, optional): whether to return an upper triangular matrix.
+        The tensor returned with upper=True is the conjugate transpose of the tensor
+        returned with upper=False.
+    check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`.
     out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`.
 
 Examples::