Rename `btriunpack` to `lu_unpack` (#18529)
authorVishwak Srinivasan <cs15btech11043@iith.ac.in>
Fri, 29 Mar 2019 19:58:23 +0000 (12:58 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 20:01:30 +0000 (13:01 -0700)
Summary:
Changelog:
- Renames `btriunpack` to `lu_unpack` to remain consistent with the `lu` function interface.
- Rename all relevant tests, fix callsites
- Create a tentative alias for `lu_unpack` under the name `btriunpack` and add a deprecation warning to not promote usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18529

Differential Revision: D14683161

Pulled By: soumith

fbshipit-source-id: 994287eaa15c50fd74c2f1c7646edfc61e8099b1

docs/source/torch.rst
test/test_cuda.py
test/test_torch.py
torch/functional.py

index 30a09dc..b34f188 100644 (file)
@@ -316,6 +316,7 @@ BLAS and LAPACK Operations
 .. autofunction:: logdet
 .. autofunction:: slogdet
 .. autofunction:: lu
+.. autofunction:: lu_unpack
 .. autofunction:: matmul
 .. autofunction:: matrix_power
 .. autofunction:: matrix_rank
index e919cf9..3d42369 100644 (file)
@@ -2371,8 +2371,8 @@ class TestCuda(TestCase):
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_btriunpack(self):
-        _TestTorchMixin._test_btriunpack(self, lambda t: t.cuda())
+    def test_lu_unpack(self):
+        _TestTorchMixin._test_lu_unpack(self, lambda t: t.cuda())
 
     def test_dim_reduction(self):
         _TestTorchMixin._test_dim_reduction(self, lambda t: t.cuda())
index bd0ac1a..7ae8627 100644 (file)
@@ -1763,7 +1763,7 @@ class _TestTorchMixin(object):
                 a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True)
                 self.assertEqual(nopiv, cast(torch.zeros(a.shape[:-1], dtype=torch.int32)))
                 self.assertEqual(info_, info_nopiv)
-            P, L, U = torch.btriunpack(a_LU, pivots)
+            P, L, U = torch.lu_unpack(a_LU, pivots)
             self.assertEqual(P.matmul(L.matmul(U)), a)
 
         for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]):
@@ -1807,11 +1807,11 @@ class _TestTorchMixin(object):
         self._test_btrisolve(self, lambda t: t)
 
     @staticmethod
-    def _test_btriunpack(self, cast):
+    def _test_lu_unpack(self, cast):
         def run_test(shape, cast):
             a = cast(torch.randn(*shape))
             a_lu, p = torch.lu(a)
-            p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
+            p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p)
             self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
 
         run_test((3, 3), cast)
@@ -1820,8 +1820,8 @@ class _TestTorchMixin(object):
         run_test((7, 5, 3, 3, 3), cast)
 
     @skipIfNoLapack
-    def test_btriunpack(self):
-        self._test_btriunpack(self, lambda t: t)
+    def test_lu_unpack(self):
+        self._test_lu_unpack(self, lambda t: t)
 
     def test_bmm(self):
         num_batches = 10
@@ -7451,10 +7451,10 @@ class _TestTorchMixin(object):
         self.assertEqual(ret1.V, ret[2])
 
         # test gesv
-        ret = a.gesv(a)
+        ret = a.solve(a)
         self.assertEqual(ret.solution, ret[0])
         self.assertEqual(ret.LU, ret[1])
-        ret1 = torch.gesv(a, a, out=tuple(ret))
+        ret1 = torch.solve(a, a, out=tuple(ret))
         self.assertEqual(ret1.solution, ret1[0])
         self.assertEqual(ret1.LU, ret1[1])
         self.assertEqual(ret1.solution, ret[0])
index 2fcb46a..3676ce1 100644 (file)
@@ -16,6 +16,7 @@ __all__ = [
     'isfinite',
     'isinf',
     'lu',
+    'lu_unpack',
     'norm',
     'meshgrid',
     'potrf',
@@ -83,7 +84,7 @@ def split(tensor, split_size_or_sections, dim=0):
     return tensor.split(split_size_or_sections, dim)
 
 
-def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
+def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     r"""Unpacks the data and pivots from a LU factorization of a tensor.
 
     Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
@@ -98,7 +99,7 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
 
         >>> A = torch.randn(2, 3, 3)
         >>> A_LU, pivots = A.lu()
-        >>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
+        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
         >>>
         >>> # can recover A from factorization
         >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
@@ -805,6 +806,21 @@ def btrifact_with_info(A, pivot=True, out=None):
     return lu(A, pivot=pivot, get_infos=True, out=out)
 
 
+def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
+    r"""Unpacks the data and pivots from a LU factorization of a tensor.
+
+    For more information regarding :func:`torch.btriunpack`, please check :func:`torch.lu_unpack`.
+
+    .. warning::
+        :func:`torch.btriunpack` is deprecated in favour of :func:`torch.lu_unpack` and will be
+        removed in the next release. Please use :func:`torch.lu_unpack` instead.
+    """
+    warnings.warn("torch.btriunpack is deprecated in favour of torch.lu_unpack and will be "
+                  "removed in the next release. Please use torch.lu_unpack instead.", stacklevel=2)
+    return lu_unpack(LU_data=LU_data, LU_pivots=LU_pivots,
+                     unpack_data=unpack_data, unpack_pivots=unpack_pivots)
+
+
 def lu(A, pivot=True, get_infos=False, out=None):
     r"""Computes the LU factorization of a square matrix or batches of square matrices
     :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.