From: Rong Rong (AI Infra)
Date: Mon, 9 Aug 2021 16:26:47 +0000 (-0700)
Subject: Enable upper for torch.linalg.cholesky (#62434)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~1165
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3782f3ecedad516ca2694cc05bb2287a17fafcd1;p=platform%2Fupstream%2Fpytorch.git

Enable upper for torch.linalg.cholesky (#62434)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/61988

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62434

Reviewed By: seemethere, tktrungna

Differential Revision: D30079806

Pulled By: walterddr

fbshipit-source-id: 044efb96525155c9bc7953ac4ad47c1b7c12fb20
---
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index b8b3c68..d80f918 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1322,7 +1322,7 @@ Tensor& cholesky_out(const Tensor &self, bool upper, Tensor &result) {
   return result;
 }
 
-void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const Tensor& info) {
+void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const Tensor& info, bool upper) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.size(-1) == input.size(-2));
 
@@ -1356,12 +1356,16 @@ void linalg_cholesky_out_info(const Tensor& input, const Tensor& result, const T
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info.sizes().equals(expected_info_shape));
   info.fill_(0);
 
-  cholesky_stub(result.device().type(), result, info, /*upper=*/false);
+  cholesky_stub(result.device().type(), result, info, upper);
 
-  result.tril_();
+  if (upper) {
+    result.triu_();
+  } else {
+    result.tril_();
+  }
 }
 
-std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool check_errors, Tensor& L, Tensor& info) {
+std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool upper, bool check_errors, Tensor& L, Tensor& info) {
   squareCheckInputs(input);
   checkSameDevice("torch.linalg.cholesky_ex", L, input, "L");
   checkLinalgCompatibleDtype("torch.linalg.cholesky_ex", L, input, "L");
@@ -1397,14 +1401,14 @@ std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool ch
   if (copy_needed) {
     Tensor L_tmp = at::empty({0}, input.options());
     Tensor info_tmp = at::empty({0}, input.options().dtype(kInt));
-    linalg_cholesky_out_info(input, L_tmp, info_tmp);
+    linalg_cholesky_out_info(input, L_tmp, info_tmp, upper);
     at::native::resize_output(L, L_tmp.sizes());
     L.copy_(L_tmp);
     at::native::resize_output(info, info_tmp.sizes());
     info.copy_(info_tmp);
   } else {
     // use "out" tensors' memory directly
-    linalg_cholesky_out_info(input, L, info);
+    linalg_cholesky_out_info(input, L, info, upper);
   }
 
   if (check_errors) {
@@ -1418,16 +1422,16 @@ std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool ch
   return std::tuple<Tensor&, Tensor&>(L, info);
 }
 
-std::tuple<Tensor, Tensor> linalg_cholesky_ex(const Tensor& input, bool check_errors) {
+std::tuple<Tensor, Tensor> linalg_cholesky_ex(const Tensor& input, bool upper, bool check_errors) {
   Tensor L = at::empty({0}, input.options());
   Tensor info = at::empty({0}, input.options().dtype(kInt));
-  std::tie(L, info) = at::native::linalg_cholesky_ex_out(input, check_errors, L, info);
+  std::tie(L, info) = at::native::linalg_cholesky_ex_out(input, upper, check_errors, L, info);
   return std::make_tuple(L, info);
 }
 
-Tensor linalg_cholesky(const Tensor &self) {
+Tensor linalg_cholesky(const Tensor &self, bool upper) {
   Tensor result, info;
-  std::tie(result, info) = at::linalg_cholesky_ex(self, /*check_errors=*/false);
+  std::tie(result, info) = at::linalg_cholesky_ex(self, upper, /*check_errors=*/false);
   // we pass check_errors=false above and do the check here
   // so that the name of the function is correct in the error message
 
@@ -1440,14 +1444,14 @@ Tensor linalg_cholesky(const Tensor &self) {
   return result;
 }
 
-Tensor& linalg_cholesky_out(const Tensor &self, Tensor &result) {
+Tensor& linalg_cholesky_out(const Tensor &self, bool upper, Tensor &result) {
   // linalg_cholesky_ex_outf includes these checks, but we do it here
   // so that the name of the function is correct in the error message
   checkSameDevice("torch.linalg.cholesky", result, self);
   checkLinalgCompatibleDtype("torch.linalg.cholesky", result, self);
 
   Tensor info = at::empty({0}, self.options().dtype(kInt));
-  std::tie(result, info) = at::linalg_cholesky_ex_outf(self, /*check_errors=*/false, result, info);
+  std::tie(result, info) = at::linalg_cholesky_ex_outf(self, upper, /*check_errors=*/false, result, info);
 
   // we pass check_errors=false above and do the check here
   // so that the name of the function is correct in the error message
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3e829ae..6c6a77a 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -10170,23 +10170,23 @@
 # See linalg_det as an example.
 
 # "_ex" stands for experimental
-- func: linalg_cholesky_ex(Tensor self, *, bool check_errors=False) -> (Tensor L, Tensor info)
+- func: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
   python_module: linalg
   variants: function
   dispatch:
     CPU, CUDA: linalg_cholesky_ex
 
-- func: linalg_cholesky_ex.L(Tensor self, *, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
+- func: linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
   python_module: linalg
   variants: function
   dispatch:
     CPU, CUDA: linalg_cholesky_ex_out
 
-- func: linalg_cholesky(Tensor self) -> Tensor
+- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
   python_module: linalg
   variants: function
 
-- func: linalg_cholesky.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
   variants: function
 
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index ce23f90..326e038 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -42,6 +42,9 @@ ALLOW_LIST = [
     ("aten::irfft", datetime.date(2021, 1, 31)),
     ("aten::rfft", datetime.date(2021, 1, 31)),
     ("aten::linalg_svd", datetime.date(2021, 5, 15)),
+    ("aten::linalg_cholesky.out", datetime.date(2021, 8, 30)),
+    ("aten::linalg_cholesky_ex", datetime.date(2021, 8, 30)),
+    ("aten::linalg_cholesky_ex.L", datetime.date(2021, 8, 30)),
     ("aten::_cholesky_helper", datetime.date(9999, 1, 1)),
     ("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
     ("aten::linalg_lstsq", datetime.date(2021, 5, 1)),
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 44d2d14..119da49 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -454,6 +454,11 @@ class TestLinalg(TestCase):
         expected = torch.linalg.cholesky(A)
         self.assertEqual(expected, out)
 
+        # check the upper= variant
+        expected = torch.linalg.cholesky(A).transpose(-2, -1).conj()
+        actual = torch.linalg.cholesky(A, upper=True)
+        self.assertEqual(expected, actual)
+
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index ea37133..3aee1ff 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -337,8 +337,8 @@
 - name: cholesky(Tensor self, bool upper=False) -> Tensor
   self: cholesky_backward(grad, upper, result)
 
-- name: linalg_cholesky_ex(Tensor self, *, bool check_errors=False) -> (Tensor L, Tensor info)
-  self: cholesky_backward(grad, false, L)
+- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
+  self: cholesky_backward(grad, upper, L)
 
 - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
   self, input2: cholesky_solve_backward(grad, self, input2, result, upper)
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index a9bfaf6..df3507f 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -15,7 +15,7 @@ common_notes = {
 # also connects the torch.linalg Python namespace to the torch._C._linalg builtins.
 
 cholesky = _add_docstr(_linalg.linalg_cholesky, r"""
-linalg.cholesky(A, *, out=None) -> Tensor
+linalg.cholesky(A, *, upper=False, out=None) -> Tensor
 
 Computes the Cholesky decomposition of a complex Hermitian or real
 symmetric positive-definite matrix.
@@ -54,6 +54,9 @@ Args:
               consisting of symmetric or Hermitian positive-definite matrices.
 
 Keyword args:
+    upper (bool, optional): whether to return an upper triangular matrix.
+        The tensor returned with upper=True is the conjugate transpose of the tensor
+        returned with upper=False.
     out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
 
 Raises:
@@ -84,7 +87,7 @@ Examples::
 """)
 
 cholesky_ex = _add_docstr(_linalg.linalg_cholesky_ex, r"""
-linalg.cholesky_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor)
+linalg.cholesky_ex(A, *, upper=False, check_errors=False, out=None) -> (Tensor, Tensor)
 
 Computes the Cholesky decomposition of a complex Hermitian or real
 symmetric positive-definite matrix.
@@ -118,9 +121,12 @@ If ``check_errors=True`` and ``info`` contains positive integers, then a Runtime
 Args:
     A (Tensor): the Hermitian `n \times n` matrix or the batch of such matrices of size
                 `(*, n, n)` where `*` is one or more batch dimensions.
-    check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`.
 
 Keyword args:
+    upper (bool, optional): whether to return an upper triangular matrix.
+        The tensor returned with upper=True is the conjugate transpose of the tensor
+        returned with upper=False.
+    check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`.
     out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`.
 
 Examples::
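---

A minimal usage sketch of the new `upper=` keyword, mirroring the test added in
test_linalg.py above. The matrix construction is an arbitrary illustration (any
Hermitian positive-definite input behaves the same way)::

    import torch

    # An arbitrary Hermitian positive-definite matrix: B @ B^H + 3*I.
    B = torch.randn(3, 3, dtype=torch.complex128)
    A = B @ B.transpose(-2, -1).conj() + 3 * torch.eye(3, dtype=torch.complex128)

    L = torch.linalg.cholesky(A)               # lower triangular, A = L @ L^H
    U = torch.linalg.cholesky(A, upper=True)   # upper triangular, A = U^H @ U

    # upper=True returns the conjugate transpose of the lower factor.
    assert torch.allclose(U, L.transpose(-2, -1).conj())
    assert torch.allclose(U.transpose(-2, -1).conj() @ U, A)

    # linalg.cholesky_ex accepts upper= as well; info is zero when the
    # decomposition succeeds.
    U2, info = torch.linalg.cholesky_ex(A, upper=True)
    assert info.item() == 0
    assert torch.allclose(U2, U)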