Rename btrisolve to lu_solve (#18726)
authorVishwak Srinivasan <cs15btech11043@iith.ac.in>
Tue, 9 Apr 2019 22:15:06 +0000 (15:15 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 9 Apr 2019 22:21:24 +0000 (15:21 -0700)
Summary:
Changelog:
- Rename `btrisolve` to `lu_solve` to remain consistent with names of solve methods (`cholesky_solve`, `triangular_solve`, `solve`)
- Fix all callsites
- Rename all tests
- Create a tentative alias for `lu_solve` under the name `btrisolve` and add a deprecation warning to not promote usage
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18726

Differential Revision: D14726237

Pulled By: zou3519

fbshipit-source-id: bf25f6c79062183a4153015e0ec7ebab2c8b986b

15 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
test/test_cuda.py
test/test_torch.py
tools/autograd/derivatives.yaml
torch/_tensor_docs.py
torch/_torch_docs.py
torch/functional.py
torch/tensor.py

index 9236074..009ad2e 100644 (file)
@@ -712,7 +712,7 @@ class CAFFE2_API Tensor {
   std::tuple<Tensor,Tensor> geqrf() const;
   Tensor orgqr(const Tensor & input2) const;
   Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
-  Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
+  Tensor lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const;
   Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
   Tensor lgamma() const;
   Tensor digamma() const;
index 7509f59..065af04 100644 (file)
@@ -1192,8 +1192,8 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const {
 inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const {
     return dispatch_type().ormqr(*this, input2, input3, left, transpose);
 }
-inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const {
-    return dispatch_type().btrisolve(*this, LU_data, LU_pivots);
+inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const {
+    return dispatch_type().lu_solve(*this, LU_data, LU_pivots);
 }
 inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generator * generator) const {
     return dispatch_type().multinomial(*this, num_samples, replacement, generator);
index ad6000e..97871be 100644 (file)
@@ -587,7 +587,7 @@ struct CAFFE2_API Type {
   virtual std::tuple<Tensor,Tensor> geqrf(const Tensor & self) const = 0;
   virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0;
   virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0;
-  virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
+  virtual Tensor lu_solve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
   virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) const = 0;
   virtual Tensor lgamma(const Tensor & self) const = 0;
   virtual Tensor digamma(const Tensor & self) const = 0;
index ca3fe63..7297689 100644 (file)
@@ -225,7 +225,6 @@ _(aten, bincount) \
 _(aten, blackman_window) \
 _(aten, bmm) \
 _(aten, broadcast_tensors) \
-_(aten, btrisolve) \
 _(aten, cartesian_prod) \
 _(aten, cat) \
 _(aten, cauchy) \
@@ -416,6 +415,7 @@ _(aten, logsumexp) \
 _(aten, lstm) \
 _(aten, lstm_cell) \
 _(aten, lt) \
+_(aten, lu_solve) \
 _(aten, margin_ranking_loss) \
 _(aten, masked_fill) \
 _(aten, masked_scatter) \
index 6315d39..60dce0b 100644 (file)
@@ -508,20 +508,20 @@ Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3,
   return at::legacy::th::_th_ormqr(self, input2, input3, left, transpose);
 }
 
-std::tuple<Tensor,Tensor> _multinomial_alias_setup(const Tensor & probs) {
-  return at::legacy::th::_th_multinomial_alias_setup(probs);
+Tensor & lu_solve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
+  return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots);
 }
 
-Tensor _multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, Generator * generator) {
-  return at::legacy::th::_th_multinomial_alias_draw(q, J, num_samples, generator);
+Tensor lu_solve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
+  return at::legacy::th::_th_btrisolve(self, LU_data, LU_pivots);
 }
 
-Tensor & btrisolve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
-  return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots);
+std::tuple<Tensor,Tensor> _multinomial_alias_setup(const Tensor & probs) {
+  return at::legacy::th::_th_multinomial_alias_setup(probs);
 }
 
-Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
-  return at::legacy::th::_th_btrisolve(self, LU_data, LU_pivots);
+Tensor _multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, Generator * generator) {
+  return at::legacy::th::_th_multinomial_alias_draw(q, J, num_samples, generator);
 }
 
 Tensor & multinomial_out(Tensor & result, const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) {
index ed3226d..a3406a6 100644 (file)
     CPU: _lu_with_info_cpu
     CUDA: _lu_with_info_cuda
 
-- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
+- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
+- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   matches_jit_signature: True
   variants: method, function
 
index de5a28d..af10e8e 100644 (file)
@@ -308,6 +308,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: lt
    .. automethod:: lt_
    .. automethod:: lu
+   .. automethod:: lu_solve
    .. automethod:: map_
    .. automethod:: masked_scatter_
    .. automethod:: masked_scatter
index 85bf6d5..039fd0f 100644 (file)
@@ -317,6 +317,7 @@ BLAS and LAPACK Operations
 .. autofunction:: logdet
 .. autofunction:: slogdet
 .. autofunction:: lu
+.. autofunction:: lu_solve
 .. autofunction:: lu_unpack
 .. autofunction:: matmul
 .. autofunction:: matrix_power
index f10ebde..056652d 100644 (file)
@@ -2362,8 +2362,8 @@ class TestCuda(TestCase):
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_btrisolve(self):
-        _TestTorchMixin._test_btrisolve(self, lambda t: t.cuda())
+    def test_lu_solve(self):
+        _TestTorchMixin._test_lu_solve(self, lambda t: t.cuda())
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
index 97bb227..6903fdd 100644 (file)
@@ -1773,7 +1773,7 @@ class _TestTorchMixin(object):
         self._test_lu(self, lambda t: t)
 
     @staticmethod
-    def _test_btrisolve(self, cast):
+    def _test_lu_solve(self, cast):
         a = torch.FloatTensor((((1.3722, -0.9020),
                                 (1.8849, 1.9169)),
                                ((0.7187, -1.1695),
@@ -1786,13 +1786,13 @@ class _TestTorchMixin(object):
         a, b = cast(a), cast(b)
         LU_data, pivots, info = a.lu(get_infos=True)
         self.assertEqual(info.abs().sum(), 0)
-        x = torch.btrisolve(b, LU_data, pivots)
+        x = torch.lu_solve(b, LU_data, pivots)
         b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
         self.assertEqual(b_, b)
 
     @skipIfNoLapack
-    def test_btrisolve(self):
-        self._test_btrisolve(self, lambda t: t)
+    def test_lu_solve(self):
+        self._test_lu_solve(self, lambda t: t)
 
     @staticmethod
     def _test_lu_unpack(self, cast):
index 1ba68ce..ab6bb84 100644 (file)
   self: grad.bmm(mat2.transpose(1, 2))
   mat2: self.transpose(1, 2).bmm(grad)
 
-- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
-  self: not_implemented("btrisolve")
-
 - name: cat(TensorList tensors, int64_t dim)
   tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)
 
 - name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
   self: not_implemented("lu_with_info")
 
+- name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots)
+  self: not_implemented("lu_solve")
+
 - name: masked_fill_(Tensor self, Tensor mask, Scalar value)
   self: grad.clone().masked_fill_(mask, 0)
 
index e9f5ba6..262998c 100644 (file)
@@ -486,13 +486,6 @@ bmm(batch2) -> Tensor
 See :func:`torch.bmm`
 """)
 
-add_docstr_all('btrisolve',
-               r"""
-btrisolve(LU_data, LU_pivots) -> Tensor
-
-See :func:`torch.btrisolve`
-""")
-
 add_docstr_all('cauchy_',
                r"""
 cauchy_(median=0, sigma=1, *, generator=None) -> Tensor
@@ -1419,6 +1412,13 @@ lt_(other) -> Tensor
 In-place version of :meth:`~Tensor.lt`
 """)
 
+add_docstr_all('lu_solve',
+               r"""
+lu_solve(LU_data, LU_pivots) -> Tensor
+
+See :func:`torch.lu_solve`
+""")
+
 add_docstr_all('map_',
                r"""
 map_(tensor, callable)
index 9c3bb21..522fcca 100644 (file)
@@ -2564,6 +2564,32 @@ Example::
             [ 1,  0]], dtype=torch.uint8)
 """)
 
+add_docstr(torch.lu_solve,
+           r"""
+lu_solve(b, LU_data, LU_pivots, out=None) -> Tensor
+
+Batch LU solve.
+
+Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
+LU factorization of A from :meth:`torch.lu`.
+
+Arguments:
+    b (Tensor): the RHS tensor
+    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
+    LU_pivots (IntTensor): the pivots of the LU factorization
+    out (Tensor, optional): the optional output tensor
+
+Example::
+
+    >>> A = torch.randn(2, 3, 3)
+    >>> b = torch.randn(2, 3)
+    >>> A_LU = torch.lu(A)
+    >>> x = torch.lu_solve(b, *A_LU)
+    >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
+    tensor(1.00000e-07 *
+           2.8312)
+""")
+
 add_docstr(torch.masked_select,
            r"""
 masked_select(input, mask, out=None) -> Tensor
@@ -5522,30 +5548,6 @@ Example::
             [ 0.,  0.,  0.]])
 """.format(**factory_like_common_args))
 
-add_docstr(torch.btrisolve,
-           r"""
-btrisolve(b, LU_data, LU_pivots) -> Tensor
-
-Batch LU solve.
-
-Returns the LU solve of the linear system :math:`Ax = b`.
-
-Arguments:
-    b (Tensor): the RHS tensor
-    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
-    LU_pivots (IntTensor): the pivots of the LU factorization
-
-Example::
-
-    >>> A = torch.randn(2, 3, 3)
-    >>> b = torch.randn(2, 3)
-    >>> A_LU = torch.lu(A)
-    >>> x = torch.btrisolve(b, *A_LU)
-    >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
-    tensor(1.00000e-07 *
-           2.8312)
-""")
-
 add_docstr(torch.empty,
            r"""
 empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
index 3676ce1..4a13658 100644 (file)
@@ -5,10 +5,11 @@ from itertools import product
 import warnings
 
 __all__ = [
-    'btriunpack',
     'broadcast_tensors',
     'btrifact',
     'btrifact_with_info',
+    'btrisolve',
+    'btriunpack',
     'cartesian_prod',
     'chain_matmul',
     'einsum',
@@ -821,6 +822,22 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
                      unpack_data=unpack_data, unpack_pivots=unpack_pivots)
 
 
+def btrisolve(b, LU_data, LU_pivots, out=None):
+    r"""Solves the system of equations :math:`Ax = b` using the partially pivoted LU
+    factorization of :math:`A` given by :attr:`LU_data` and :attr:`LU_pivots`.
+
+    For more information regarding :func:`torch.btrisolve`, please check
+    :func:`torch.lu_solve`.
+
+    .. warning::
+        :func:`torch.btrisolve` is deprecated in favour of :func:`torch.lu_solve` and will be
+        removed in the next release. Please use :func:`torch.lu_solve` instead.
+    """
+    warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+                  "removed in the next release. Please use torch.lu_solve instead.", stacklevel=2)
+    return torch.lu_solve(b, LU_data=LU_data, LU_pivots=LU_pivots, out=out)
+
+
 def lu(A, pivot=True, get_infos=False, out=None):
     r"""Computes the LU factorization of a square matrix or batches of square matrices
     :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
index 4ac4c6b..a788c50 100644 (file)
@@ -296,10 +296,17 @@ class Tensor(torch._C._TensorBase):
     def btrifact_with_info(self, pivot=True):
         r"""See :func:`torch.lu`"""
         warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
-                      "and will be removed in the next release. Please use torch.lu with the "
-                      "get_infos argument set to True instead.", stacklevel=2)
+                      "get_infos argument and will be removed in the next release. Please use "
+                      "torch.lu with the get_infos argument set to True instead.", stacklevel=2)
         return torch._lu_with_info(self, pivot=pivot, check_errors=False)
 
+    def btrisolve(self, LU_data, LU_pivots):
+        r"""See :func:`torch.lu_solve`"""
+        warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+                      "removed in the next release. Please use torch.lu_solve instead.",
+                      stacklevel=2)
+        return super(Tensor, self).lu_solve(LU_data=LU_data, LU_pivots=LU_pivots)
+
     def lu(self, pivot=True, get_infos=False):
         r"""See :func:`torch.lu`"""
         # If get_infos is True, then we don't need to check for errors and vice versa