From 291746f11047361100102577ce7d1cfa1833be50 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Thu, 21 Mar 2019 14:18:38 -0700 Subject: [PATCH] Rename trtrs to triangular_solve (#18213) Summary: Changelog: - Renames `trtrs` to `triangular_solve` to remain consistent with `cholesky_solve` and `solve`. - Rename all tests, fix callsites - Create a tentative alias for `triangular_solve` under the name `trtrs`, and add a deprecation warning to not promote usage. - Move `isnan` to _torch_docs.py - Remove unnecessary imports Pull Request resolved: https://github.com/pytorch/pytorch/pull/18213 Differential Revision: D14566902 Pulled By: ezyang fbshipit-source-id: 544f57c29477df391bacd5de700bed1add456d3f --- aten/src/ATen/core/Tensor.h | 2 +- aten/src/ATen/core/TensorMethods.h | 4 +- aten/src/ATen/core/Type.h | 2 +- aten/src/ATen/core/aten_interned_strings.h | 2 +- aten/src/ATen/native/BatchLinearAlgebra.cpp | 53 ++++----- aten/src/ATen/native/cuda/BatchLinearAlgebra.cu | 31 +++--- aten/src/ATen/native/native_functions.yaml | 10 +- docs/source/tensors.rst | 1 + docs/source/torch.rst | 1 + test/test_autograd.py | 4 +- test/test_cuda.py | 12 +- test/test_torch.py | 74 ++++++------- tools/autograd/derivatives.yaml | 6 +- tools/autograd/gen_python_functions.py | 2 +- tools/autograd/templates/Functions.cpp | 4 +- torch/_tensor_docs.py | 14 +-- torch/_torch_docs.py | 122 ++++++++++++--------- torch/distributions/kl.py | 10 +- torch/distributions/lowrank_multivariate_normal.py | 2 +- torch/distributions/multivariate_normal.py | 4 +- torch/functional.py | 40 +++---- torch/tensor.py | 7 ++ 22 files changed, 206 insertions(+), 201 deletions(-) diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index 8ae46dd..24b64b9 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -687,7 +687,7 @@ class CAFFE2_API Tensor { Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const; Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const; std::tuple gels(const Tensor & A) const; - std::tuple trtrs(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const; + std::tuple triangular_solve(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const; std::tuple symeig(bool eigenvectors=false, bool upper=true) const; std::tuple eig(bool eigenvectors=false) const; std::tuple svd(bool some=true, bool compute_uv=true) const; diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index 7d5beab..e153f06 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -1132,8 +1132,8 @@ inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Sc inline std::tuple Tensor::gels(const Tensor & A) const { return type().gels(*this, A); } -inline std::tuple Tensor::trtrs(const Tensor & A, bool upper, bool transpose, bool unitriangular) const { - return type().trtrs(*this, A, upper, transpose, unitriangular); +inline std::tuple Tensor::triangular_solve(const Tensor & A, bool upper, bool transpose, bool unitriangular) const { + return type().triangular_solve(*this, A, upper, transpose, unitriangular); } inline std::tuple Tensor::symeig(bool eigenvectors, bool upper) const { return type().symeig(*this, eigenvectors, upper); diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index 37a784a..5b7a5d7 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -565,7 +565,7 @@ struct 
CAFFE2_API Type { virtual Tensor addcmul(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0; virtual Tensor addcdiv(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0; virtual std::tuple gels(const Tensor & self, const Tensor & A) const = 0; - virtual std::tuple trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const = 0; + virtual std::tuple triangular_solve(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const = 0; virtual std::tuple symeig(const Tensor & self, bool eigenvectors, bool upper) const = 0; virtual std::tuple eig(const Tensor & self, bool eigenvectors) const = 0; virtual std::tuple svd(const Tensor & self, bool some, bool compute_uv) const = 0; diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index ba6b64d..b71ffe8 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -671,10 +671,10 @@ _(aten, to_dense) \ _(aten, topk) \ _(aten, trace) \ _(aten, transpose) \ +_(aten, triangular_solve) \ _(aten, tril) \ _(aten, triplet_margin_loss) \ _(aten, triu) \ -_(aten, trtrs) \ _(aten, trunc) \ _(aten, type_as) \ _(aten, unbind) \ diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 562e154..46f6d47 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -71,8 +71,8 @@ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) { } template -void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) { - AT_ERROR("trtrs only takes float or double Tensors"); +void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) { + AT_ERROR("triangular_solve only takes float or double Tensors"); } #ifdef USE_LAPACK @@ -116,11 +116,11 @@ template<> void lapackCholesky(char uplo, int n, float *a, int lda, int * spotrf_(&uplo, &n, a, &lda, info); } -template<> void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) { +template<> void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) { dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info); } -template<> void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) { +template<> void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) { strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info); } #endif @@ -333,9 +333,6 @@ Tensor cholesky_solve(const Tensor& self, const Tensor& A, bool upper) { } Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) { - AT_CHECK(self.dim() == 2 && A.dim() == 2, - "torch.cholesky_solve() with the `out` keyword does not support batching. 
" - "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2."); Tensor result_tmp; result_tmp = at::_cholesky_solve_helper(self, A, upper); result.resize_as_(result_tmp).copy_(result_tmp); @@ -641,12 +638,12 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) { return result; } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trtrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular, std::vector& infos) { +static void apply_triangular_solve(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) { #ifndef USE_LAPACK - AT_ERROR("trtrs: LAPACK library not found in compilation"); + AT_ERROR("triangular_solve: LAPACK library not found in compilation"); #else char uplo = upper ? 'U' : 'L'; char trans = transpose ? 'T' : 'N'; @@ -659,8 +656,7 @@ static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool u int info; if (b.dim() == 2) { - lapackTrtrs(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info); - infos[0] = info; + lapackTriangularSolve(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info); } else { auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); @@ -668,49 +664,38 @@ static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool u for (int64_t i = 0; i < batch_size; i++) { scalar_t* A_working_ptr = &A_data[i * A_mat_stride]; scalar_t* b_working_ptr = &b_data[i * b_mat_stride]; - lapackTrtrs(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info); - infos[i] = info; - if (info != 0) { - return; - } + lapackTriangularSolve(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info); } } #endif } -std::tuple _trtrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) { +std::tuple _triangular_solve_helper_cpu(const Tensor& self, const Tensor& A, + bool upper, bool transpose, bool unitriangular) { auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - std::vector infos(batchCount(self), 0); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trtrs_cpu", [&]{ - apply_trtrs(self_working_copy, A_working_copy, upper, transpose, unitriangular, infos); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cpu", [&]{ + apply_triangular_solve(self_working_copy, A_working_copy, upper, transpose, unitriangular); }); - if (self.dim() > 2) { - batchCheckErrors(infos, "trtrs_cpu"); - } else { - singleCheckErrors(infos[0], "trtrs_cpu"); - } return std::tuple(self_working_copy, A_working_copy); } // Supports arbitrary batch dimensions for self and A -std::tuple trtrs(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) { +std::tuple triangular_solve(const Tensor& self, const Tensor& A, + bool upper, bool transpose, bool unitriangular) { AT_CHECK(self.dim() >= 2, "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); AT_CHECK(A.dim() >= 2, "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead"); Tensor self_broadcasted, A_broadcasted; std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A); - return at::_trtrs_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular); + return at::_triangular_solve_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular); } 
-std::tuple trtrs_out(Tensor& result, Tensor& clone_A, - const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) { - AT_CHECK(self.dim() == 2 && A.dim() == 2, - "torch.trtrs() with the `out` keyword does not support batching. " - "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2."); +std::tuple triangular_solve_out(Tensor& result, Tensor& clone_A, const Tensor& self, const Tensor& A, + bool upper, bool transpose, bool unitriangular) { Tensor result_tmp, clone_A_tmp; - std::tie(result_tmp, clone_A_tmp) = at::_trtrs_helper(self, A, upper, transpose, unitriangular); + std::tie(result_tmp, clone_A_tmp) = at::_triangular_solve_helper(self, A, upper, transpose, unitriangular); result.resize_as_(result_tmp).copy_(result_tmp); clone_A.resize_as_(clone_A_tmp).copy_(clone_A_tmp); return std::tuple(result, clone_A); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 812f1ea..f401995 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -86,18 +86,18 @@ void magmaCholeskyBatched( } template -void magmaTrsm( +void magmaTriangularSolve( magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb) { - AT_ERROR("trtrs only takes float or double Tensors"); + AT_ERROR("triangular_solve only takes float or double Tensors"); } template -void magmaTrsmBatched( +void magmaTriangularSolveBatched( magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue) { - AT_ERROR("trtrs only takes float or double Tensors"); + AT_ERROR("triangular_solve only takes float or double Tensors"); } template<> @@ -233,21 +233,21 @@ void magmaCholeskyBatched( } template<> -void magmaTrsm( +void magmaTriangularSolve( magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) { magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb); } template<> -void magmaTrsm( +void magmaTriangularSolve( magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) { magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb); } template<> -void magmaTrsmBatched( +void magmaTriangularSolveBatched( magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue) { @@ -255,7 +255,7 @@ void magmaTrsmBatched( } template<> -void magmaTrsmBatched( +void magmaTriangularSolveBatched( magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue) { @@ -684,10 +684,10 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) { return triu_tril_cuda_template(result, self_c, k, "triu"); } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trsm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_trsm(Tensor& b, Tensor& A, bool upper, bool 
transpose, bool unitriangular) { +static void apply_triangular_solve(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) { #ifndef USE_MAGMA AT_ERROR("cholesky_solve: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); @@ -702,7 +702,7 @@ AT_ERROR("cholesky_solve: MAGMA library not found in " magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)"); if (b.dim() == 2) { - magmaTrsm(uplo, trans, diag, n, nrhs, A_data, n, b_data, n); + magmaTriangularSolve(uplo, trans, diag, n, nrhs, A_data, n, b_data, n); } else { auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); @@ -721,18 +721,19 @@ AT_ERROR("cholesky_solve: MAGMA library not found in " } MAGMAQueue magma_queue(b.get_device()); - magmaTrsmBatched( + magmaTriangularSolveBatched( uplo, trans, diag, n, nrhs, A_array, n, b_array, n, batch_size, magma_queue); } #endif } -std::tuple _trsm_helper_cuda(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) { +std::tuple _triangular_solve_helper_cuda(const Tensor& self, const Tensor& A, + bool upper, bool transpose, bool unitriangular) { auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trsm_cuda", [&]{ - apply_trsm(self_working_copy, A_working_copy, upper, transpose, unitriangular); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cuda", [&]{ + apply_triangular_solve(self_working_copy, A_working_copy, upper, transpose, unitriangular); }); return std::tuple(self_working_copy, A_working_copy); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 6a7a658..ca8a5ad 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3687,19 +3687,19 @@ matches_jit_signature: True variants: method, function -- func: trtrs(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!), Tensor(b!)) +- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!), Tensor(b!)) matches_jit_signature: True -- func: trtrs(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor, Tensor) +- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor, Tensor) matches_jit_signature: True variants: method, function -- func: _trtrs_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor) +- func: _triangular_solve_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor) matches_jit_signature: True variants: function dispatch: - CPU: _trtrs_helper_cpu - CUDA: _trsm_helper_cuda + CPU: _triangular_solve_helper_cpu + CUDA: _triangular_solve_helper_cuda - func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) matches_jit_signature: True diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index a7a3128..a98b278 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -401,6 +401,7 @@ view of a storage and defines numeric operations on it. .. automethod:: trace .. automethod:: transpose .. automethod:: transpose_ + .. 
automethod:: triangular_solve .. automethod:: tril .. automethod:: tril_ .. automethod:: triu diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 2bbeec9..782f58e 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -324,6 +324,7 @@ BLAS and LAPACK Operations .. autofunction:: solve .. autofunction:: svd .. autofunction:: symeig +.. autofunction:: triangular_solve .. autofunction:: trtrs Utilities diff --git a/test/test_autograd.py b/test/test_autograd.py index badeee4..aad83df 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2183,14 +2183,14 @@ class TestAutograd(TestCase): run_test(upper, dims) @skipIfNoLapack - def test_trtrs(self): + def test_triangular_solve(self): def _test_with_size(A_dims, B_dims): A = torch.rand(*A_dims).requires_grad_() b = torch.rand(*B_dims).requires_grad_() for upper, transpose, unitriangular in product((True, False), repeat=3): def func(A, b): - return torch.trtrs(b, A, upper, transpose, unitriangular) + return torch.triangular_solve(b, A, upper, transpose, unitriangular) gradcheck(func, [A, b]) gradgradcheck(func, [A, b]) diff --git a/test/test_cuda.py b/test/test_cuda.py index 59c50ce..a34e51f 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2564,16 +2564,16 @@ class TestCuda(TestCase): _TestTorchMixin._test_geqrf(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_trtrs(self): - _TestTorchMixin._test_trtrs(self, lambda t: t.cuda()) + def test_triangular_solve(self): + _TestTorchMixin._test_triangular_solve(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_trtrs_batched(self): - _TestTorchMixin._test_trtrs_batched(self, lambda t: t.cuda()) + def test_triangular_solve_batched(self): + _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_trtrs_batched_dims(self): - _TestTorchMixin._test_trtrs_batched_dims(self, lambda t: t.cuda()) + def test_triangular_solve_batched_dims(self): + _TestTorchMixin._test_triangular_solve_batched_dims(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") def test_get_set_rng_state_all(self): diff --git a/test/test_torch.py b/test/test_torch.py index adcebb1..20563a6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4866,7 +4866,7 @@ class _TestTorchMixin(object): self._test_geqrf(self, lambda t: t) @staticmethod - def _test_trtrs(self, cast): + def _test_triangular_solve(self, cast): a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23), (-6.05, -3.30, 5.36, -4.44, 1.08), (-0.45, 2.58, -2.70, 0.27, 9.04), @@ -4883,54 +4883,54 @@ class _TestTorchMixin(object): L = torch.tril(a) # solve Ux = b - x = torch.trtrs(b, U)[0] + x = torch.triangular_solve(b, U)[0] self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12) - x = torch.trtrs(b, U, True, False, False)[0] + x = torch.triangular_solve(b, U, True, False, False)[0] self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12) # solve Lx = b - x = torch.trtrs(b, L, False)[0] + x = torch.triangular_solve(b, L, False)[0] self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12) - x = torch.trtrs(b, L, False, False, False)[0] + x = torch.triangular_solve(b, L, False, False, False)[0] self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12) # solve U'x = b - x = torch.trtrs(b, U, True, True)[0] + x = torch.triangular_solve(b, U, True, True)[0] self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12) - x = torch.trtrs(b, U, True, True, 
False)[0] + x = torch.triangular_solve(b, U, True, True, False)[0] self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12) # solve U'x = b by manual transposition - y = torch.trtrs(b, U.t(), False, False)[0] + y = torch.triangular_solve(b, U.t(), False, False)[0] self.assertLessEqual(x.dist(y), 1e-12) # solve L'x = b - x = torch.trtrs(b, L, False, True)[0] + x = torch.triangular_solve(b, L, False, True)[0] self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12) - x = torch.trtrs(b, L, False, True, False)[0] + x = torch.triangular_solve(b, L, False, True, False)[0] self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12) # solve L'x = b by manual transposition - y = torch.trtrs(b, L.t(), True, False)[0] + y = torch.triangular_solve(b, L.t(), True, False)[0] self.assertLessEqual(x.dist(y), 1e-12) # test reuse - res1 = torch.trtrs(b, a)[0] + res1 = torch.triangular_solve(b, a)[0] ta = cast(torch.Tensor()) tb = cast(torch.Tensor()) - torch.trtrs(b, a, out=(tb, ta)) + torch.triangular_solve(b, a, out=(tb, ta)) self.assertEqual(res1, tb, 0) tb.zero_() - torch.trtrs(b, a, out=(tb, ta)) + torch.triangular_solve(b, a, out=(tb, ta)) self.assertEqual(res1, tb, 0) @skipIfNoLapack - def test_trtrs(self): - self._test_trtrs(self, lambda t: t) + def test_triangular_solve(self): + self._test_triangular_solve(self, lambda t: t) @staticmethod - def _test_trtrs_batched(self, cast): - def trtrs_test_helper(A_dims, b_dims, cast, upper, unitriangular): + def _test_triangular_solve_batched(self, cast): + def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular): A = cast(torch.randn(*A_dims)) A = A.triu() if upper else A.tril() if unitriangular: @@ -4939,40 +4939,40 @@ class _TestTorchMixin(object): return A, b for upper, transpose, unitriangular in product([True, False], repeat=3): - # test against trtrs: one batch with all possible arguments - A, b = trtrs_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular) - x_exp = torch.trtrs(b.squeeze(0), A.squeeze(0), - upper=upper, unitriangular=unitriangular, transpose=transpose)[0] - x = torch.trtrs(b, A, - upper=upper, unitriangular=unitriangular, transpose=transpose)[0] + # test against triangular_solve: one batch with all possible arguments + A, b = triangular_solve_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular) + x_exp = torch.triangular_solve(b.squeeze(0), A.squeeze(0), + upper=upper, transpose=transpose, unitriangular=unitriangular)[0] + x = torch.triangular_solve(b, A, + upper=upper, transpose=transpose, unitriangular=unitriangular)[0] self.assertEqual(x, x_exp.unsqueeze(0)) - # test against trtrs in a loop: four batches with all possible arguments - A, b = trtrs_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular) + # test against triangular_solve in a loop: four batches with all possible arguments + A, b = triangular_solve_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular) x_exp_list = [] for i in range(4): - x_exp = torch.trtrs(b[i], A[i], - upper=upper, unitriangular=unitriangular, transpose=transpose)[0] + x_exp = torch.triangular_solve(b[i], A[i], + upper=upper, transpose=transpose, unitriangular=unitriangular)[0] x_exp_list.append(x_exp) x_exp = torch.stack(x_exp_list) - x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] + x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] self.assertEqual(x, x_exp) # basic correctness test - A, b = trtrs_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular) - x = 
torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] + A, b = triangular_solve_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular) + x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] if transpose: self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12) else: self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12) @skipIfNoLapack - def test_trtrs_batched(self): - _TestTorchMixin._test_trtrs_batched(self, lambda t: t) + def test_triangular_solve_batched(self): + _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t) @staticmethod - def _test_trtrs_batched_dims(self, cast): + def _test_triangular_solve_batched_dims(self, cast): if not TEST_SCIPY: return @@ -5000,7 +5000,7 @@ class _TestTorchMixin(object): x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(), upper, transpose, unitriangular)) A, b = cast(A), cast(b) - x = torch.trtrs(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] + x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] self.assertEqual(x, cast(x_exp)) @@ -5012,8 +5012,8 @@ class _TestTorchMixin(object): run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular) # broadcasting A & b @skipIfNoLapack - def test_trtrs_batched_dims(self): - self._test_trtrs_batched_dims(self, lambda t: t) + def test_triangular_solve_batched_dims(self): + self._test_triangular_solve_batched_dims(self, lambda t: t) @skipIfNoLapack def test_gels(self): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index d99edff..b7acf82 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -850,15 +850,15 @@ - name: transpose_(Tensor self, int64_t dim0, int64_t dim1) self: grad.transpose(dim0, dim1) +- name: triangular_solve(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) + self, A: triangular_solve_backward(grads[0], grads[1], self, A, result0, upper, transpose, unitriangular, grad_input_mask) + - name: tril(Tensor self, int64_t diagonal) self: grad.tril(diagonal) - name: triu(Tensor self, int64_t diagonal) self: grad.triu(diagonal) -- name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) - self, A: trtrs_backward(grads[0], grads[1], self, A, result0, upper, transpose, unitriangular, grad_input_mask) - - name: trunc(Tensor self) self: zeros_like(grad) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 43ab781..80f7495 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [ '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_.*', '_thnn_.*', 'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*', - '_cholesky.*', '_btrifact.*', + '_cholesky.*', '_btrifact.*', '_triangular_solve.*', 'slice', 'randint(_out)?', 'item', '_local_scalar_dense', 'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to', diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index e71d5fe..c927624 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -1739,14 +1739,14 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet, // Reference: // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf // Sec. 
2.3.1 Matrix inverse product -std::tuple trtrs_backward( +std::tuple triangular_solve_backward( const Tensor & grad_x, const Tensor & grad_m, const Tensor & b, const Tensor & a, const Tensor & x, const bool upper, const bool transpose, const bool unitriangular, std::array output_mask) { Tensor grad_b, grad_a; if (grad_x.defined()) { - grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular)); + grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular)); if (output_mask[1]) { grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2)); if (upper) { diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 0d26f23..d58d83f 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2601,6 +2601,13 @@ transpose_(dim0, dim1) -> Tensor In-place version of :meth:`~Tensor.transpose` """) +add_docstr_all('triangular_solve', + r""" +triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) + +See :func:`torch.triangular_solve` +""") + add_docstr_all('tril', r""" tril(k=0) -> Tensor @@ -2629,13 +2636,6 @@ triu_(k=0) -> Tensor In-place version of :meth:`~Tensor.triu` """) -add_docstr_all('trtrs', - r""" -trtrs(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) - -See :func:`torch.trtrs` -""") - add_docstr_all('trunc', r""" trunc() -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index a088d18..0fd3d5e 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2224,6 +2224,22 @@ Example:: tensor(1.9073e-06) """) +add_docstr(torch.isnan, + r""" +Returns a new tensor with boolean elements representing if each element is `NaN` or not. + +Arguments: + tensor (Tensor): A tensor to check + +Returns: + Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements. + +Example:: + + >>> torch.isnan(torch.tensor([1, float('nan'), 2])) + tensor([ 0, 1, 0], dtype=torch.uint8) +""") + add_docstr(torch.is_floating_point, r""" is_floating_point(tensor) -> (bool) @@ -5073,6 +5089,59 @@ Example:: [ 0.5809, 0.4942]]) """) +add_docstr(torch.triangular_solve, + r""" +triangular_solve(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) + +Solves a system of equations with a triangular coefficient matrix :math:`A` +and multiple right-hand sides :attr:`b`. + +In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular +with the default keyword arguments. + +`torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are +batches of 2D matrices. If the inputs are batches, then returns +batched outputs `X` + +.. note:: + + The :attr:`out` keyword only supports 2D matrix inputs, that is, + `b, A` must be 2D matrices. + +Args: + A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)` + where :math:`*` is zero or more batch dimensions + b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where + :math:`*` is zero of more batch dimensions + upper (bool, optional): whether to solve the upper-triangular system + of equations (default) or the lower-triangular system of equations. Default: ``True``. + transpose (bool, optional): whether :math:`A` should be transposed before + being sent into the solver. Default: ``False``. + unitriangular (bool, optional): whether :math:`A` is unit triangular. + If True, the diagonal elements of :math:`A` are assumed to be + 1 and not referenced from :math:`A`. Default: ``False``. 
+ +Returns: + A tuple :math:`(X, M)` where :math:`M` is a clone of :math:`A` and :math:`X` + is the solution to :math:`AX = b` (or whatever variant of the system of + equations, depending on the keyword arguments.) + +Examples:: + + >>> A = torch.randn(2, 2).triu() + >>> A + tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]]) + >>> b = torch.randn(2, 3) + >>> b + tensor([[-0.0210, 2.3513, -1.5492], + [ 1.5429, 0.7403, -1.0243]]) + >>> torch.triangular_solve(b, A) + (tensor([[ 1.7840, 2.9045, -2.5405], + [ 1.9319, 0.9269, -1.2826]]), tensor([[ 1.1527, -1.0753], + [ 0.0000, 0.7986]])) +""") + add_docstr(torch.tril, r""" tril(input, diagonal=0, out=None) -> Tensor @@ -5291,59 +5360,6 @@ Example:: [1, 2, 2]]) """.format(**factory_common_args)) -add_docstr(torch.trtrs, - r""" -trtrs(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor) - -Solves a system of equations with a triangular coefficient matrix :math:`A` -and multiple right-hand sides :attr:`b`. - -In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular -with the default keyword arguments. - -`torch.trtrs(b, A)` can take in 2D inputs `b, A` or inputs that are -batches of 2D matrices. If the inputs are batches, then returns -batched outputs `X` - -.. note:: - - The :attr:`out` keyword only supports 2D matrix inputs, that is, - `b, A` must be 2D matrices. - -Args: - A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)` - where :math:`*` is zero or more batch dimensions - b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where - :math:`*` is zero of more batch dimensions - upper (bool, optional): whether to solve the upper-triangular system - of equations (default) or the lower-triangular system of equations. Default: ``True``. - transpose (bool, optional): whether :math:`A` should be transposed before - being sent into the solver. Default: ``False``. - unitriangular (bool, optional): whether :math:`A` is unit triangular. - If True, the diagonal elements of :math:`A` are assumed to be - 1 and not referenced from :math:`A`. Default: ``False``. - -Returns: - A tuple :math:`(X, M)` where :math:`M` is a clone of :math:`A` and :math:`X` - is the solution to :math:`AX = b` (or whatever variant of the system of - equations, depending on the keyword arguments.) 
- -Examples:: - - >>> A = torch.randn(2, 2).triu() - >>> A - tensor([[ 1.1527, -1.0753], - [ 0.0000, 0.7986]]) - >>> b = torch.randn(2, 3) - >>> b - tensor([[-0.0210, 2.3513, -1.5492], - [ 1.5429, 0.7403, -1.0243]]) - >>> torch.trtrs(b, A) - (tensor([[ 1.7840, 2.9045, -2.5405], - [ 1.9319, 0.9269, -1.2826]]), tensor([[ 1.1527, -1.0753], - [ 0.0000, 0.7986]])) -""") - add_docstr(torch.trunc, r""" trunc(input, out=None) -> Tensor diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index b6a33dc..9b77d33 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -312,7 +312,7 @@ def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q): # = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T) qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) / q._unbroadcasted_cov_diag.unsqueeze(-2)) - A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0] + A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0] term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1) term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)) @@ -339,7 +339,7 @@ def _kl_multivariatenormal_lowrankmultivariatenormal(p, q): # = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) / q._unbroadcasted_cov_diag.unsqueeze(-2)) - A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0] + A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0] term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)) term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril)) @@ -367,8 +367,8 @@ def _kl_lowrankmultivariatenormal_multivariatenormal(p, q): (n, p.cov_factor.size(-1))) p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()) .expand(combined_batch_shape + (n, n))) - term21 = _batch_trace_XXT(torch.trtrs(p_cov_factor, q_scale_tril, upper=False)[0]) - term22 = _batch_trace_XXT(torch.trtrs(p_cov_diag, q_scale_tril, upper=False)[0]) + term21 = _batch_trace_XXT(torch.triangular_solve(p_cov_factor, q_scale_tril, upper=False)[0]) + term22 = _batch_trace_XXT(torch.triangular_solve(p_cov_diag, q_scale_tril, upper=False)[0]) term2 = term21 + term22 return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) @@ -387,7 +387,7 @@ def _kl_multivariatenormal_multivariatenormal(p, q): n = p.event_shape[0] q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) - term2 = _batch_trace_XXT(torch.trtrs(p_scale_tril, q_scale_tril, upper=False)[0]) + term2 = _batch_trace_XXT(torch.triangular_solve(p_scale_tril, q_scale_tril, upper=False)[0]) term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) return half_term1 + 0.5 * (term2 + term3 - n) diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index 6ba31eb..56822a0 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -162,7 +162,7 @@ class LowRankMultivariateNormal(Distribution): # where :math:`C` is the capacitance matrix. 
Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2) / self._unbroadcasted_cov_diag.unsqueeze(-2)) - A = torch.trtrs(Wt_Dinv, self._capacitance_tril, upper=False)[0] + A = torch.triangular_solve(Wt_Dinv, self._capacitance_tril, upper=False)[0] precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - torch.matmul(A.transpose(-1, -2), A)) return precision_matrix.expand(self._batch_shape + self._event_shape + diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 6efd8c9..0a38dc3 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -32,7 +32,7 @@ def _batch_mahalanobis(bL, bx): bx_batch_shape = bx.shape[:-1] # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n), - # we are going to make bx have shape (..., 1, j, i, 1, n) to apply _batch_trtrs_lower + # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve bx_batch_dims = len(bx_batch_shape) bL_batch_dims = bL.dim() - 2 outer_batch_dims = bx_batch_dims - bL_batch_dims @@ -54,7 +54,7 @@ def _batch_mahalanobis(bL, bx): flat_L = bL.reshape(-1, n, n) # shape = b x n x n flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c - M_swap = torch.trtrs(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2) # shape = b x c + M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2) # shape = b x c M = M_swap.t() # shape = c x b # Now we revert the above reshape and permute operators. diff --git a/torch/functional.py b/torch/functional.py index 8a38e0a..580227c 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1,13 +1,7 @@ import torch import torch.nn.functional as F from torch._six import inf -from torch._C import _add_docstr -from operator import mul -from functools import reduce -from collections import Iterable -from torch._utils import annotate from itertools import product -import math import warnings __all__ = [ @@ -17,7 +11,6 @@ __all__ = [ 'broadcast_tensors', 'isfinite', 'isinf', - 'isnan', 'norm', 'meshgrid', 'potrf', @@ -27,6 +20,7 @@ __all__ = [ 'split', 'stft', 'tensordot', + 'trtrs', 'unique', 'cartesian_prod', ] @@ -380,22 +374,6 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None, return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided) -isnan = _add_docstr(torch.isnan, r""" -Returns a new tensor with boolean elements representing if each element is `NaN` or not. - -Arguments: - tensor (Tensor): A tensor to check - -Returns: - Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements. - -Example:: - - >>> torch.isnan(torch.tensor([1, float('nan'), 2])) - tensor([ 0, 1, 0], dtype=torch.uint8) -""") - - def unique(input, sorted=True, return_inverse=False, dim=None): r"""Returns the unique scalar elements of the input tensor as a 1-D tensor. @@ -764,3 +742,19 @@ def gesv(b, A, out=None): warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the " "next release. Please use torch.solve instead.", stacklevel=2) return torch.solve(b, A, out=out) + + +def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None): + r"""Solves a system of equations with a triangular coefficient matrix :math:`A` + and multiple right-hand sides :attr:`b`. 
+ + In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular + with the default keyword arguments. + + .. warning:: + :func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be + removed in the next release. Please use :func:`torch.triangular_solve` instead. + """ + warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be " + "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2) + return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out) diff --git a/torch/tensor.py b/torch/tensor.py index 8838a71..bf239b3 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -279,6 +279,13 @@ class Tensor(torch._C._TensorBase): "next release. Please use torch.solve instead.", stacklevel=2) return super(Tensor, self).solve(A) + def trtrs(self, A, upper=True, transpose=False, unitriangular=False): + r"""See :func:`torch.triangular_solve`""" + warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be " + "removed in the next release. Please use torch.triangular_solve.", stacklevel=2) + return super(Tensor, self).triangular_solve(A, upper=upper, + transpose=transpose, unitriangular=unitriangular) + def stft(self, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True): r"""See :func:`torch.stft` -- 2.7.4
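For reference, a minimal usage sketch of the renamed API, assuming a PyTorch build that includes this patch. It exercises `torch.triangular_solve` on a single upper-triangular system, on a batch of lower-triangular systems, and through the deprecated `torch.trtrs` alias, which now routes through the deprecation shim above. The shapes, diagonal shift, and tolerances are illustrative only.

```python
import warnings
import torch

# Single system: A is treated as upper-triangular by default (upper=True).
# The added identity shift keeps the example well conditioned.
A = torch.randn(3, 3).triu() + 5 * torch.eye(3)
b = torch.randn(3, 4)
X, M = torch.triangular_solve(b, A)                     # M is a clone of A
print(torch.allclose(torch.mm(A, X), b, atol=1e-5))     # X solves AX = b

# Batched systems: leading dimensions are batch dimensions.
A_batch = torch.randn(5, 3, 3).tril() + 5 * torch.eye(3)
b_batch = torch.randn(5, 3, 2)
X_batch, _ = torch.triangular_solve(b_batch, A_batch, upper=False)
print(torch.allclose(torch.matmul(A_batch, X_batch), b_batch, atol=1e-5))

# The old name still resolves, but emits the deprecation warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.trtrs(b, A)
print(str(caught[0].message))   # "torch.trtrs is deprecated in favour of torch.triangular_solve ..."
```

Note that the argument order `(b, A)` and the keyword names match the old `trtrs` signature, so existing call sites only need the name change.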