From 487388d8ad07e8e5630edf6887b8246466fe4156 Mon Sep 17 00:00:00 2001
From: Vishwak Srinivasan
Date: Tue, 9 Apr 2019 15:15:06 -0700
Subject: [PATCH] Rename btrisolve to lu_solve (#18726)

Summary:
Changelog:
- Rename `btrisolve` to `lu_solve` to remain consistent with names of solve methods (`cholesky_solve`, `triangular_solve`, `solve`)
- Fix all callsites
- Rename all tests
- Create a tentative alias for `lu_solve` under the name `btrisolve` and add a deprecation warning to not promote usage
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18726

Differential Revision: D14726237

Pulled By: zou3519

fbshipit-source-id: bf25f6c79062183a4153015e0ec7ebab2c8b986b
---
 aten/src/ATen/core/Tensor.h                |  2 +-
 aten/src/ATen/core/TensorMethods.h         |  4 +--
 aten/src/ATen/core/Type.h                  |  2 +-
 aten/src/ATen/core/aten_interned_strings.h |  2 +-
 aten/src/ATen/native/LegacyDefinitions.cpp | 16 +++++-----
 aten/src/ATen/native/native_functions.yaml |  4 +--
 docs/source/tensors.rst                    |  1 +
 docs/source/torch.rst                      |  1 +
 test/test_cuda.py                          |  4 +--
 test/test_torch.py                         |  8 ++---
 tools/autograd/derivatives.yaml            |  6 ++--
 torch/_tensor_docs.py                      | 14 ++++-----
 torch/_torch_docs.py                       | 50 ++++++++++++++++--------------
 torch/functional.py                        | 19 +++++++++++-
 torch/tensor.py                            | 11 +++++--
 15 files changed, 86 insertions(+), 58 deletions(-)

diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index 9236074..009ad2e 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -712,7 +712,7 @@ class CAFFE2_API Tensor {
   std::tuple<Tensor,Tensor> geqrf() const;
   Tensor orgqr(const Tensor & input2) const;
   Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
-  Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
+  Tensor lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const;
   Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
   Tensor lgamma() const;
   Tensor digamma() const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 7509f59..065af04 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -1192,8 +1192,8 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const {
 inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const {
     return dispatch_type().ormqr(*this, input2, input3, left, transpose);
 }
-inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const {
-    return dispatch_type().btrisolve(*this, LU_data, LU_pivots);
+inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const {
+    return dispatch_type().lu_solve(*this, LU_data, LU_pivots);
 }
 inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generator * generator) const {
     return dispatch_type().multinomial(*this, num_samples, replacement, generator);
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index ad6000e..97871be 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -587,7 +587,7 @@ struct CAFFE2_API Type {
   virtual std::tuple<Tensor,Tensor> geqrf(const Tensor & self) const = 0;
   virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0;
   virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0;
-  virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
+  virtual Tensor lu_solve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
   virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) const = 0;
   virtual Tensor lgamma(const Tensor & self) const = 0;
   virtual Tensor digamma(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index ca3fe63..7297689 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -225,7 +225,6 @@ _(aten, bincount) \
 _(aten, blackman_window) \
 _(aten, bmm) \
 _(aten, broadcast_tensors) \
-_(aten, btrisolve) \
 _(aten, cartesian_prod) \
 _(aten, cat) \
 _(aten, cauchy) \
@@ -416,6 +415,7 @@ _(aten, logsumexp) \
 _(aten, lstm) \
 _(aten, lstm_cell) \
 _(aten, lt) \
+_(aten, lu_solve) \
 _(aten, margin_ranking_loss) \
 _(aten, masked_fill) \
 _(aten, masked_scatter) \
diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp
index 6315d39..60dce0b 100644
--- a/aten/src/ATen/native/LegacyDefinitions.cpp
+++ b/aten/src/ATen/native/LegacyDefinitions.cpp
@@ -508,20 +508,20 @@ Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) {
   return at::legacy::th::_th_ormqr(self, input2, input3, left, transpose);
 }
 
-std::tuple<Tensor,Tensor> _multinomial_alias_setup(const Tensor & probs) {
-  return at::legacy::th::_th_multinomial_alias_setup(probs);
+Tensor & lu_solve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
+  return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots);
 }
 
-Tensor _multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, Generator * generator) {
-  return at::legacy::th::_th_multinomial_alias_draw(q, J, num_samples, generator);
+Tensor lu_solve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
+  return at::legacy::th::_th_btrisolve(self, LU_data, LU_pivots);
 }
 
-Tensor & btrisolve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
-  return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots);
+std::tuple<Tensor,Tensor> _multinomial_alias_setup(const Tensor & probs) {
+  return at::legacy::th::_th_multinomial_alias_setup(probs);
 }
 
-Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
-  return at::legacy::th::_th_btrisolve(self, LU_data, LU_pivots);
+Tensor _multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, Generator * generator) {
+  return at::legacy::th::_th_multinomial_alias_draw(q, J, num_samples, generator);
 }
 
 Tensor & multinomial_out(Tensor & result, const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index ed3226d..a3406a6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3902,10 +3902,10 @@
     CPU: _lu_with_info_cpu
     CUDA: _lu_with_info_cuda
 
-- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
+- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
+- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   matches_jit_signature: True
   variants: method, function
 
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index de5a28d..af10e8e 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -308,6 +308,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: lt
    .. automethod:: lt_
    .. automethod:: lu
+   .. automethod:: lu_solve
    .. automethod:: map_
    .. automethod:: masked_scatter_
    .. automethod:: masked_scatter
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 85bf6d5..039fd0f 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -317,6 +317,7 @@ BLAS and LAPACK Operations
 .. autofunction:: logdet
 .. autofunction:: slogdet
 .. autofunction:: lu
+.. autofunction:: lu_solve
 .. autofunction:: lu_unpack
 .. autofunction:: matmul
 .. autofunction:: matrix_power
diff --git a/test/test_cuda.py b/test/test_cuda.py
index f10ebde..056652d 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2362,8 +2362,8 @@ class TestCuda(TestCase):
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_btrisolve(self):
-        _TestTorchMixin._test_btrisolve(self, lambda t: t.cuda())
+    def test_lu_solve(self):
+        _TestTorchMixin._test_lu_solve(self, lambda t: t.cuda())
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
diff --git a/test/test_torch.py b/test/test_torch.py
index 97bb227..6903fdd 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1773,7 +1773,7 @@ class _TestTorchMixin(object):
         self._test_lu(self, lambda t: t)
 
     @staticmethod
-    def _test_btrisolve(self, cast):
+    def _test_lu_solve(self, cast):
         a = torch.FloatTensor((((1.3722, -0.9020),
                                 (1.8849, 1.9169)),
                                ((0.7187, -1.1695),
@@ -1786,13 +1786,13 @@ class _TestTorchMixin(object):
         a, b = cast(a), cast(b)
         LU_data, pivots, info = a.lu(get_infos=True)
         self.assertEqual(info.abs().sum(), 0)
-        x = torch.btrisolve(b, LU_data, pivots)
+        x = torch.lu_solve(b, LU_data, pivots)
         b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
         self.assertEqual(b_, b)
 
     @skipIfNoLapack
-    def test_btrisolve(self):
-        self._test_btrisolve(self, lambda t: t)
+    def test_lu_solve(self):
+        self._test_lu_solve(self, lambda t: t)
 
     @staticmethod
     def _test_lu_unpack(self, cast):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 1ba68ce..ab6bb84 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -192,9 +192,6 @@
   self: grad.bmm(mat2.transpose(1, 2))
   mat2: self.transpose(1, 2).bmm(grad)
 
-- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
-  self: not_implemented("btrisolve")
-
 - name: cat(TensorList tensors, int64_t dim)
   tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)
 
@@ -467,6 +464,9 @@
 - name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
   self: not_implemented("lu_with_info")
 
+- name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots)
+  self: not_implemented("lu_solve")
+
 - name: masked_fill_(Tensor self, Tensor mask, Scalar value)
   self: grad.clone().masked_fill_(mask, 0)
 
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index e9f5ba6..262998c 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -486,13 +486,6 @@ bmm(batch2) -> Tensor
 See :func:`torch.bmm`
 """)
 
-add_docstr_all('btrisolve',
-               r"""
-btrisolve(LU_data, LU_pivots) -> Tensor
-
-See :func:`torch.btrisolve`
-""")
-
 add_docstr_all('cauchy_',
                r"""
 cauchy_(median=0, sigma=1, *, generator=None) -> Tensor
@@ -1419,6 +1412,13 @@ lt_(other) -> Tensor
 In-place version of :meth:`~Tensor.lt`
 """)
 
+add_docstr_all('lu_solve',
+               r"""
+lu_solve(LU_data, LU_pivots) -> Tensor
+
+See :func:`torch.lu_solve`
+""")
+
 add_docstr_all('map_',
                r"""
 map_(tensor, callable)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 9c3bb21..522fcca 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2564,6 +2564,32 @@ Example::
         [ 1,  0]], dtype=torch.uint8)
 """)
 
+add_docstr(torch.lu_solve,
+           r"""
+lu_solve(b, LU_data, LU_pivots, out=None) -> Tensor
+
+Batch LU solve.
+
+Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
+LU factorization of A from :meth:`torch.lu`.
+
+Arguments:
+    b (Tensor): the RHS tensor
+    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
+    LU_pivots (IntTensor): the pivots of the LU factorization
+    out (Tensor, optional): the optional output tensor
+
+Example::
+
+    >>> A = torch.randn(2, 3, 3)
+    >>> b = torch.randn(2, 3)
+    >>> A_LU = torch.lu(A)
+    >>> x = torch.lu_solve(b, *A_LU)
+    >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
+    tensor(1.00000e-07 *
+           2.8312)
+""")
+
 add_docstr(torch.masked_select,
            r"""
 masked_select(input, mask, out=None) -> Tensor
@@ -5522,30 +5548,6 @@ Example::
         [ 0.,  0.,  0.]])
 """.format(**factory_like_common_args))
 
-add_docstr(torch.btrisolve,
-           r"""
-btrisolve(b, LU_data, LU_pivots) -> Tensor
-
-Batch LU solve.
-
-Returns the LU solve of the linear system :math:`Ax = b`.
-
-Arguments:
-    b (Tensor): the RHS tensor
-    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
-    LU_pivots (IntTensor): the pivots of the LU factorization
-
-Example::
-
-    >>> A = torch.randn(2, 3, 3)
-    >>> b = torch.randn(2, 3)
-    >>> A_LU = torch.lu(A)
-    >>> x = torch.btrisolve(b, *A_LU)
-    >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
-    tensor(1.00000e-07 *
-           2.8312)
-""")
-
 add_docstr(torch.empty,
            r"""
 empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
diff --git a/torch/functional.py b/torch/functional.py
index 3676ce1..4a13658 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -5,10 +5,11 @@ from itertools import product
 import warnings
 
 __all__ = [
-    'btriunpack',
     'broadcast_tensors',
     'btrifact',
     'btrifact_with_info',
+    'btrisolve',
+    'btriunpack',
     'cartesian_prod',
     'chain_matmul',
     'einsum',
@@ -821,6 +822,22 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
                             unpack_data=unpack_data, unpack_pivots=unpack_pivots)
 
 
+def btrisolve(b, LU_data, LU_pivots, out=None):
+    r"""Solves the system of equations :math:`Ax = b` using the partially pivoted LU
+    factorization of :math:`A` given by :attr:`LU_data` and :attr:`LU_pivots`.
+
+    For more information regarding :func:`torch.btrisolve`, please check
+    :func:`torch.lu_solve`.
+
+    .. warning::
+        :func:`torch.btrisolve` is deprecated in favour of :func:`torch.lu_solve` and will be
+        removed in the next release. Please use :func:`torch.lu_solve` instead.
+    """
+    warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+                  "removed in the next release. Please use torch.lu_solve instead.", stacklevel=2)
+    return torch.lu_solve(b, LU_data=LU_data, LU_pivots=LU_pivots, out=out)
+
+
 def lu(A, pivot=True, get_infos=False, out=None):
     r"""Computes the LU factorization of a square matrix or batches of square
     matrices :attr:`A`. Returns a tuple containing the LU factorization and
     pivots of :attr:`A`.
diff --git a/torch/tensor.py b/torch/tensor.py
index 4ac4c6b..a788c50 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -296,10 +296,17 @@ class Tensor(torch._C._TensorBase):
     def btrifact_with_info(self, pivot=True):
         r"""See :func:`torch.lu`"""
         warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
-                      "and will be removed in the next release. Please use torch.lu with the "
-                      "get_infos argument set to True instead.", stacklevel=2)
+                      "get_infos argument and will be removed in the next release. Please use "
+                      "torch.lu with the get_infos argument set to True instead.", stacklevel=2)
         return torch._lu_with_info(self, pivot=pivot, check_errors=False)
 
+    def btrisolve(self, LU_data, LU_pivots):
+        r"""See :func:`torch.lu_solve`"""
+        warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+                      "removed in the next release. Please use torch.lu_solve instead.",
+                      stacklevel=2)
+        return super(Tensor, self).lu_solve(LU_data=LU_data, LU_pivots=LU_pivots)
+
     def lu(self, pivot=True, get_infos=False):
         r"""See :func:`torch.lu`"""
        # If get_infos is True, then we don't need to check for errors and vice versa
-- 
2.7.4
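Note for downstream callers (illustrative only, not part of the applied diff): a minimal migration sketch, assuming the `torch` Python API exactly as it stands after this patch — `torch.lu` returning `(LU_data, LU_pivots)` when `get_infos=False`, and `btrisolve` surviving only as the deprecated alias defined above.

```python
import warnings

import torch

# Factor a small batch of square systems A x = b, then solve with the new name.
A = torch.randn(2, 3, 3)
b = torch.randn(2, 3)

# torch.lu with get_infos=False (the default) returns (LU_data, LU_pivots).
LU_data, LU_pivots = torch.lu(A)
x = torch.lu_solve(b, LU_data, LU_pivots)

# The residual A x - b should be near zero.
print(torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2)))

# The old entry point still works, but now routes through lu_solve and warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    x_old = torch.btrisolve(b, LU_data, LU_pivots)

assert torch.equal(x, x_old)
assert any("deprecated" in str(w.message) for w in caught)
```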