std::tuple<Tensor,Tensor> geqrf() const;
Tensor orgqr(const Tensor & input2) const;
Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
- Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
+ Tensor lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const;
Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
Tensor lgamma() const;
Tensor digamma() const;
inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const {
return dispatch_type().ormqr(*this, input2, input3, left, transpose);
}
-inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const {
- return dispatch_type().btrisolve(*this, LU_data, LU_pivots);
+inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const {
+ return dispatch_type().lu_solve(*this, LU_data, LU_pivots);
}
inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generator * generator) const {
return dispatch_type().multinomial(*this, num_samples, replacement, generator);
virtual std::tuple<Tensor,Tensor> geqrf(const Tensor & self) const = 0;
virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0;
virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0;
- virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
+ virtual Tensor lu_solve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) const = 0;
virtual Tensor lgamma(const Tensor & self) const = 0;
virtual Tensor digamma(const Tensor & self) const = 0;
_(aten, blackman_window) \
_(aten, bmm) \
_(aten, broadcast_tensors) \
-_(aten, btrisolve) \
_(aten, cartesian_prod) \
_(aten, cat) \
_(aten, cauchy) \
_(aten, lstm) \
_(aten, lstm_cell) \
_(aten, lt) \
+_(aten, lu_solve) \
_(aten, margin_ranking_loss) \
_(aten, masked_fill) \
_(aten, masked_scatter) \
return at::legacy::th::_th_ormqr(self, input2, input3, left, transpose);
}
-std::tuple<Tensor,Tensor> _multinomial_alias_setup(const Tensor & probs) {
- return at::legacy::th::_th_multinomial_alias_setup(probs);
+Tensor & lu_solve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
+ return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots);
}
-Tensor _multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, Generator * generator) {
- return at::legacy::th::_th_multinomial_alias_draw(q, J, num_samples, generator);
+Tensor lu_solve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
+ return at::legacy::th::_th_btrisolve(self, LU_data, LU_pivots);
}
-Tensor & btrisolve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
- return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots);
+std::tuple<Tensor,Tensor> _multinomial_alias_setup(const Tensor & probs) {
+ return at::legacy::th::_th_multinomial_alias_setup(probs);
}
-Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
- return at::legacy::th::_th_btrisolve(self, LU_data, LU_pivots);
+Tensor _multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, Generator * generator) {
+ return at::legacy::th::_th_multinomial_alias_draw(q, J, num_samples, generator);
}
Tensor & multinomial_out(Tensor & result, const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) {
CPU: _lu_with_info_cpu
CUDA: _lu_with_info_cuda
-- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
+- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
-- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
+- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
matches_jit_signature: True
variants: method, function
.. automethod:: lt
.. automethod:: lt_
.. automethod:: lu
+ .. automethod:: lu_solve
.. automethod:: map_
.. automethod:: masked_scatter_
.. automethod:: masked_scatter
.. autofunction:: logdet
.. autofunction:: slogdet
.. autofunction:: lu
+.. autofunction:: lu_solve
.. autofunction:: lu_unpack
.. autofunction:: matmul
.. autofunction:: matrix_power
@skipIfRocm
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
- def test_btrisolve(self):
- _TestTorchMixin._test_btrisolve(self, lambda t: t.cuda())
+ def test_lu_solve(self):
+ _TestTorchMixin._test_lu_solve(self, lambda t: t.cuda())
@skipIfRocm
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
self._test_lu(self, lambda t: t)
@staticmethod
- def _test_btrisolve(self, cast):
+ def _test_lu_solve(self, cast):
a = torch.FloatTensor((((1.3722, -0.9020),
(1.8849, 1.9169)),
((0.7187, -1.1695),
a, b = cast(a), cast(b)
LU_data, pivots, info = a.lu(get_infos=True)
self.assertEqual(info.abs().sum(), 0)
- x = torch.btrisolve(b, LU_data, pivots)
+ x = torch.lu_solve(b, LU_data, pivots)
b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
self.assertEqual(b_, b)
@skipIfNoLapack
- def test_btrisolve(self):
- self._test_btrisolve(self, lambda t: t)
+ def test_lu_solve(self):
+ self._test_lu_solve(self, lambda t: t)
@staticmethod
def _test_lu_unpack(self, cast):
self: grad.bmm(mat2.transpose(1, 2))
mat2: self.transpose(1, 2).bmm(grad)
-- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
- self: not_implemented("btrisolve")
-
- name: cat(TensorList tensors, int64_t dim)
tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)
- name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
self: not_implemented("lu_with_info")
+- name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots)
+ self: not_implemented("lu_solve")
+
- name: masked_fill_(Tensor self, Tensor mask, Scalar value)
self: grad.clone().masked_fill_(mask, 0)
See :func:`torch.bmm`
""")
-add_docstr_all('btrisolve',
- r"""
-btrisolve(LU_data, LU_pivots) -> Tensor
-
-See :func:`torch.btrisolve`
-""")
-
add_docstr_all('cauchy_',
r"""
cauchy_(median=0, sigma=1, *, generator=None) -> Tensor
In-place version of :meth:`~Tensor.lt`
""")
+add_docstr_all('lu_solve',
+ r"""
+lu_solve(LU_data, LU_pivots) -> Tensor
+
+See :func:`torch.lu_solve`
+""")
+
add_docstr_all('map_',
r"""
map_(tensor, callable)
[ 1, 0]], dtype=torch.uint8)
""")
+add_docstr(torch.lu_solve,
+ r"""
+lu_solve(b, LU_data, LU_pivots, out=None) -> Tensor
+
+Batch LU solve.
+
+Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted
+LU factorization of A from :meth:`torch.lu`.
+
+Arguments:
+ b (Tensor): the RHS tensor
+ LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
+ LU_pivots (IntTensor): the pivots of the LU factorization
+ out (Tensor, optional): the optional output tensor
+
+Example::
+
+ >>> A = torch.randn(2, 3, 3)
+ >>> b = torch.randn(2, 3)
+ >>> A_LU = torch.lu(A)
+ >>> x = torch.lu_solve(b, *A_LU)
+ >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
+ tensor(1.00000e-07 *
+ 2.8312)
+""")
+
add_docstr(torch.masked_select,
r"""
masked_select(input, mask, out=None) -> Tensor
[ 0., 0., 0.]])
""".format(**factory_like_common_args))
-add_docstr(torch.btrisolve,
- r"""
-btrisolve(b, LU_data, LU_pivots) -> Tensor
-
-Batch LU solve.
-
-Returns the LU solve of the linear system :math:`Ax = b`.
-
-Arguments:
- b (Tensor): the RHS tensor
- LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
- LU_pivots (IntTensor): the pivots of the LU factorization
-
-Example::
-
- >>> A = torch.randn(2, 3, 3)
- >>> b = torch.randn(2, 3)
- >>> A_LU = torch.lu(A)
- >>> x = torch.btrisolve(b, *A_LU)
- >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
- tensor(1.00000e-07 *
- 2.8312)
-""")
-
add_docstr(torch.empty,
r"""
empty(*sizes, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
import warnings
__all__ = [
- 'btriunpack',
'broadcast_tensors',
'btrifact',
'btrifact_with_info',
+ 'btrisolve',
+ 'btriunpack',
'cartesian_prod',
'chain_matmul',
'einsum',
unpack_data=unpack_data, unpack_pivots=unpack_pivots)
+def btrisolve(b, LU_data, LU_pivots, out=None):
+ r"""Solves the system of equations :math:`Ax = b` using the partially pivoted LU
+ factorization of :math:`A` given by :attr:`LU_data` and :attr:`LU_pivots`.
+
+ For more information regarding :func:`torch.btrisolve`, please check
+ :func:`torch.lu_solve`.
+
+ .. warning::
+ :func:`torch.btrisolve` is deprecated in favour of :func:`torch.lu_solve` and will be
+ removed in the next release. Please use :func:`torch.lu_solve` instead.
+ """
+ warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+ "removed in the next release. Please use torch.lu_solve instead.", stacklevel=2)
+ return torch.lu_solve(b, LU_data=LU_data, LU_pivots=LU_pivots, out=out)
+
+
def lu(A, pivot=True, get_infos=False, out=None):
r"""Computes the LU factorization of a square matrix or batches of square matrices
:attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
def btrifact_with_info(self, pivot=True):
r"""See :func:`torch.lu`"""
warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
- "and will be removed in the next release. Please use torch.lu with the "
- "get_infos argument set to True instead.", stacklevel=2)
+ "get_infos argument and will be removed in the next release. Please use "
+ "torch.lu with the get_infos argument set to True instead.", stacklevel=2)
return torch._lu_with_info(self, pivot=pivot, check_errors=False)
+ def btrisolve(self, LU_data, LU_pivots):
+ r"""See :func:`torch.lu_solve`"""
+ warnings.warn("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
+ "removed in the next release. Please use torch.lu_solve instead.",
+ stacklevel=2)
+ return super(Tensor, self).lu_solve(LU_data=LU_data, LU_pivots=LU_pivots)
+
def lu(self, pivot=True, get_infos=False):
r"""See :func:`torch.lu`"""
# If get_infos is True, then we don't need to check for errors and vice versa