From e73be58ff7f9df478cdb57fad6102a2dcf5dc181 Mon Sep 17 00:00:00 2001
From: Vishwak Srinivasan
Date: Fri, 29 Mar 2019 12:58:23 -0700
Subject: [PATCH] Rename `btriunpack` to `lu_unpack` (#18529)

Summary:
Changelog:
- Renames `btriunpack` to `lu_unpack` to remain consistent with the `lu` function interface.
- Renames all relevant tests and fixes call sites.
- Creates a tentative alias for `lu_unpack` under the name `btriunpack` and adds a deprecation warning to discourage its use.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18529

Differential Revision: D14683161

Pulled By: soumith

fbshipit-source-id: 994287eaa15c50fd74c2f1c7646edfc61e8099b1
---
 docs/source/torch.rst |  1 +
 test/test_cuda.py     |  4 ++--
 test/test_torch.py    | 14 +++++++-------
 torch/functional.py   | 20 ++++++++++++++++++--
 4 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 30a09dc..b34f188 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -316,6 +316,7 @@ BLAS and LAPACK Operations
 .. autofunction:: logdet
 .. autofunction:: slogdet
 .. autofunction:: lu
+.. autofunction:: lu_unpack
 .. autofunction:: matmul
 .. autofunction:: matrix_power
 .. autofunction:: matrix_rank
diff --git a/test/test_cuda.py b/test/test_cuda.py
index e919cf9..3d42369 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2371,8 +2371,8 @@ class TestCuda(TestCase):
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_btriunpack(self):
-        _TestTorchMixin._test_btriunpack(self, lambda t: t.cuda())
+    def test_lu_unpack(self):
+        _TestTorchMixin._test_lu_unpack(self, lambda t: t.cuda())
 
     def test_dim_reduction(self):
         _TestTorchMixin._test_dim_reduction(self, lambda t: t.cuda())
diff --git a/test/test_torch.py b/test/test_torch.py
index bd0ac1a..7ae8627 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1763,7 +1763,7 @@ class _TestTorchMixin(object):
             a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True)
             self.assertEqual(nopiv, cast(torch.zeros(a.shape[:-1], dtype=torch.int32)))
             self.assertEqual(info_, info_nopiv)
-            P, L, U = torch.btriunpack(a_LU, pivots)
+            P, L, U = torch.lu_unpack(a_LU, pivots)
             self.assertEqual(P.matmul(L.matmul(U)), a)
 
         for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]):
@@ -1807,11 +1807,11 @@ class _TestTorchMixin(object):
         self._test_btrisolve(self, lambda t: t)
 
     @staticmethod
-    def _test_btriunpack(self, cast):
+    def _test_lu_unpack(self, cast):
         def run_test(shape, cast):
             a = cast(torch.randn(*shape))
             a_lu, p = torch.lu(a)
-            p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
+            p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p)
             self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
 
         run_test((3, 3), cast)
@@ -1820,8 +1820,8 @@ class _TestTorchMixin(object):
         run_test((7, 5, 3, 3, 3), cast)
 
     @skipIfNoLapack
-    def test_btriunpack(self):
-        self._test_btriunpack(self, lambda t: t)
+    def test_lu_unpack(self):
+        self._test_lu_unpack(self, lambda t: t)
 
     def test_bmm(self):
         num_batches = 10
@@ -7451,10 +7451,10 @@ class _TestTorchMixin(object):
         self.assertEqual(ret1.V, ret[2])
 
         # test gesv
-        ret = a.gesv(a)
+        ret = a.solve(a)
         self.assertEqual(ret.solution, ret[0])
         self.assertEqual(ret.LU, ret[1])
-        ret1 = torch.gesv(a, a, out=tuple(ret))
+        ret1 = torch.solve(a, a, out=tuple(ret))
         self.assertEqual(ret1.solution, ret1[0])
         self.assertEqual(ret1.LU, ret1[1])
         self.assertEqual(ret1.solution, ret[0])
diff --git a/torch/functional.py b/torch/functional.py
index 2fcb46a..3676ce1 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -16,6 +16,7 @@ __all__ = [
     'isfinite',
     'isinf',
     'lu',
+    'lu_unpack',
     'norm',
     'meshgrid',
     'potrf',
@@ -83,7 +84,7 @@ def split(tensor, split_size_or_sections, dim=0):
     return tensor.split(split_size_or_sections, dim)
 
 
-def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
+def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     r"""Unpacks the data and pivots from a LU factorization of a tensor.
 
     Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
@@ -98,7 +99,7 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
 
         >>> A = torch.randn(2, 3, 3)
         >>> A_LU, pivots = A.lu()
-        >>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
+        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
         >>>
         >>> # can recover A from factorization
         >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
@@ -805,6 +806,21 @@ def btrifact_with_info(A, pivot=True, out=None):
     return lu(A, pivot=pivot, get_infos=True, out=out)
 
 
+def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
+    r"""Unpacks the data and pivots from a LU factorization of a tensor.
+
+    For more information regarding :func:`torch.btriunpack`, please check :func:`torch.lu_unpack`.
+
+    .. warning::
+        :func:`torch.btriunpack` is deprecated in favour of :func:`torch.lu_unpack` and will be
+        removed in the next release. Please use :func:`torch.lu_unpack` instead.
+    """
+    warnings.warn("torch.btriunpack is deprecated in favour of torch.lu_unpack and will be "
+                  "removed in the next release. Please use torch.lu_unpack instead.", stacklevel=2)
+    return lu_unpack(LU_data=LU_data, LU_pivots=LU_pivots,
+                     unpack_data=unpack_data, unpack_pivots=unpack_pivots)
+
+
 def lu(A, pivot=True, get_infos=False, out=None):
     r"""Computes the LU factorization of a square matrix or batches of square matrices
     :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
-- 
2.7.4
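
For quick reference, the sketch below exercises the renamed API end to end once this patch is applied: factor, unpack, reconstruct, and call the deprecated alias. The shapes, tolerance, and variable names are illustrative only, not taken from the patch; the snippet assumes a build that includes this change.

    import warnings

    import torch

    # Batched LU factorization, then unpacking into explicit P, L, U factors.
    A = torch.randn(2, 3, 3)
    A_LU, pivots = torch.lu(A)
    P, L, U = torch.lu_unpack(A_LU, pivots)

    # P @ L @ U reconstructs A up to floating-point error.
    assert torch.allclose(P.matmul(L.matmul(U)), A, atol=1e-6)

    # The old name survives as a thin alias: it emits a warning and
    # forwards every argument to torch.lu_unpack.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        P2, L2, U2 = torch.btriunpack(A_LU, pivots)
    assert any("deprecated" in str(w.message) for w in caught)

Forwarding the alias to `lu_unpack` rather than duplicating its body keeps the two names behaviorally identical until `btriunpack` is removed.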