Rename trtrs to triangular_solve (#18213)
authorvishwakftw <cs15btech11043@iith.ac.in>
Thu, 21 Mar 2019 21:18:38 +0000 (14:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 21 Mar 2019 21:27:21 +0000 (14:27 -0700)
Summary:
Changelog:
- Renames `trtrs` to `triangular_solve` to remain consistent with `cholesky_solve` and `solve`.
- Rename all tests, fix callsites
- Create a tentative alias for `triangular_solve` under the name `trtrs`, and add a deprecation warning to not promote usage.
- Move `isnan` to _torch_docs.py
- Remove unnecessary imports
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18213

Differential Revision: D14566902

Pulled By: ezyang

fbshipit-source-id: 544f57c29477df391bacd5de700bed1add456d3f

22 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
test/test_autograd.py
test/test_cuda.py
test/test_torch.py
tools/autograd/derivatives.yaml
tools/autograd/gen_python_functions.py
tools/autograd/templates/Functions.cpp
torch/_tensor_docs.py
torch/_torch_docs.py
torch/distributions/kl.py
torch/distributions/lowrank_multivariate_normal.py
torch/distributions/multivariate_normal.py
torch/functional.py
torch/tensor.py

index 8ae46dd..24b64b9 100644 (file)
@@ -687,7 +687,7 @@ class CAFFE2_API Tensor {
   Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
   Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
   std::tuple<Tensor,Tensor> gels(const Tensor & A) const;
-  std::tuple<Tensor,Tensor> trtrs(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const;
+  std::tuple<Tensor,Tensor> triangular_solve(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const;
   std::tuple<Tensor,Tensor> symeig(bool eigenvectors=false, bool upper=true) const;
   std::tuple<Tensor,Tensor> eig(bool eigenvectors=false) const;
   std::tuple<Tensor,Tensor,Tensor> svd(bool some=true, bool compute_uv=true) const;
index 7d5beab..e153f06 100644 (file)
@@ -1132,8 +1132,8 @@ inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Sc
 inline std::tuple<Tensor,Tensor> Tensor::gels(const Tensor & A) const {
     return type().gels(*this, A);
 }
-inline std::tuple<Tensor,Tensor> Tensor::trtrs(const Tensor & A, bool upper, bool transpose, bool unitriangular) const {
-    return type().trtrs(*this, A, upper, transpose, unitriangular);
+inline std::tuple<Tensor,Tensor> Tensor::triangular_solve(const Tensor & A, bool upper, bool transpose, bool unitriangular) const {
+    return type().triangular_solve(*this, A, upper, transpose, unitriangular);
 }
 inline std::tuple<Tensor,Tensor> Tensor::symeig(bool eigenvectors, bool upper) const {
     return type().symeig(*this, eigenvectors, upper);
index 37a784a..5b7a5d7 100644 (file)
@@ -565,7 +565,7 @@ struct CAFFE2_API Type {
   virtual Tensor addcmul(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
   virtual Tensor addcdiv(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
   virtual std::tuple<Tensor,Tensor> gels(const Tensor & self, const Tensor & A) const = 0;
-  virtual std::tuple<Tensor,Tensor> trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const = 0;
+  virtual std::tuple<Tensor,Tensor> triangular_solve(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const = 0;
   virtual std::tuple<Tensor,Tensor> symeig(const Tensor & self, bool eigenvectors, bool upper) const = 0;
   virtual std::tuple<Tensor,Tensor> eig(const Tensor & self, bool eigenvectors) const = 0;
   virtual std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool compute_uv) const = 0;
index ba6b64d..b71ffe8 100644 (file)
@@ -671,10 +671,10 @@ _(aten, to_dense) \
 _(aten, topk) \
 _(aten, trace) \
 _(aten, transpose) \
+_(aten, triangular_solve) \
 _(aten, tril) \
 _(aten, triplet_margin_loss) \
 _(aten, triu) \
-_(aten, trtrs) \
 _(aten, trunc) \
 _(aten, type_as) \
 _(aten, unbind) \
index 562e154..46f6d47 100644 (file)
@@ -71,8 +71,8 @@ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) {
 }
 
 template<class scalar_t>
-void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
-  AT_ERROR("trtrs only takes float or double Tensors");
+void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("triangular_solve only takes float or double Tensors");
 } 
 
 #ifdef USE_LAPACK
@@ -116,11 +116,11 @@ template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *
   spotrf_(&uplo, &n, a, &lda, info);
 }
 
-template<> void lapackTrtrs<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
+template<> void lapackTriangularSolve<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
   dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
 }
 
-template<> void lapackTrtrs<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
+template<> void lapackTriangularSolve<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
   strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
 }
 #endif
@@ -333,9 +333,6 @@ Tensor cholesky_solve(const Tensor& self, const Tensor& A, bool upper) {
 }
 
 Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
-  AT_CHECK(self.dim() == 2 && A.dim() == 2,
-           "torch.cholesky_solve() with the `out` keyword does not support batching. "
-           "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
   Tensor result_tmp;
   result_tmp = at::_cholesky_solve_helper(self, A, upper);
   result.resize_as_(result_tmp).copy_(result_tmp);
@@ -641,12 +638,12 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   return result;
 }
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trtrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template<typename scalar_t>
-static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular, std::vector<int64_t>& infos) {
+static void apply_triangular_solve(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
 #ifndef USE_LAPACK
-  AT_ERROR("trtrs: LAPACK library not found in compilation");
+  AT_ERROR("triangular_solve: LAPACK library not found in compilation");
 #else
   char uplo = upper ? 'U' : 'L';
   char trans = transpose ? 'T' : 'N';
@@ -659,8 +656,7 @@ static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool u
 
   int info;
   if (b.dim() == 2) {
-    lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
-    infos[0] = info;
+    lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
   } else {
     auto A_mat_stride = matrixStride(A);
     auto b_mat_stride = matrixStride(b);
@@ -668,49 +664,38 @@ static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool u
     for (int64_t i = 0; i < batch_size; i++) {
       scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
       scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
-      lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
-      infos[i] = info;
-      if (info != 0) {
-        return;
-      }
+      lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
     }
   }
 #endif
 }
 
-std::tuple<Tensor, Tensor> _trtrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+std::tuple<Tensor, Tensor> _triangular_solve_helper_cpu(const Tensor& self, const Tensor& A,
+                                                        bool upper, bool transpose, bool unitriangular) {
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
-  std::vector<int64_t> infos(batchCount(self), 0);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trtrs_cpu", [&]{
-    apply_trtrs<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular, infos);
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cpu", [&]{
+    apply_triangular_solve<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
   });
-  if (self.dim() > 2) {
-    batchCheckErrors(infos, "trtrs_cpu");
-  } else {
-    singleCheckErrors(infos[0], "trtrs_cpu");
-  }
   return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
 }
 
 // Supports arbitrary batch dimensions for self and A
-std::tuple<Tensor, Tensor> trtrs(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+std::tuple<Tensor, Tensor> triangular_solve(const Tensor& self, const Tensor& A,
+                                            bool upper, bool transpose, bool unitriangular) {
   AT_CHECK(self.dim() >= 2,
            "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   AT_CHECK(A.dim() >= 2,
            "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
   Tensor self_broadcasted, A_broadcasted;
   std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
-  return at::_trtrs_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular);
+  return at::_triangular_solve_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular);
 }
 
-std::tuple<Tensor&, Tensor&> trtrs_out(Tensor& result, Tensor& clone_A,
-                                       const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
-  AT_CHECK(self.dim() == 2 && A.dim() == 2,
-           "torch.trtrs() with the `out` keyword does not support batching. "
-           "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
+std::tuple<Tensor&, Tensor&> triangular_solve_out(Tensor& result, Tensor& clone_A, const Tensor& self, const Tensor& A,
+                                                  bool upper, bool transpose, bool unitriangular) {
   Tensor result_tmp, clone_A_tmp;
-  std::tie(result_tmp, clone_A_tmp) = at::_trtrs_helper(self, A, upper, transpose, unitriangular);
+  std::tie(result_tmp, clone_A_tmp) = at::_triangular_solve_helper(self, A, upper, transpose, unitriangular);
   result.resize_as_(result_tmp).copy_(result_tmp);
   clone_A.resize_as_(clone_A_tmp).copy_(clone_A_tmp);
   return std::tuple<Tensor&, Tensor&>(result, clone_A);
index 812f1ea..f401995 100644 (file)
@@ -86,18 +86,18 @@ void magmaCholeskyBatched(
 }
 
 template<class scalar_t>
-void magmaTrsm(
+void magmaTriangularSolve(
     magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb) {
-  AT_ERROR("trtrs only takes float or double Tensors");
+  AT_ERROR("triangular_solve only takes float or double Tensors");
 }
 
 template<class scalar_t>
-void magmaTrsmBatched(
+void magmaTriangularSolveBatched(
     magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
-  AT_ERROR("trtrs only takes float or double Tensors");
+  AT_ERROR("triangular_solve only takes float or double Tensors");
 }
 
 template<>
@@ -233,21 +233,21 @@ void magmaCholeskyBatched<float>(
 }
 
 template<>
-void magmaTrsm<double>(
+void magmaTriangularSolve<double>(
     magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) {
   magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
 }
 
 template<>
-void magmaTrsm<float>(
+void magmaTriangularSolve<float>(
     magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) {
   magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
 }
 
 template<>
-void magmaTrsmBatched<double>(
+void magmaTriangularSolveBatched<double>(
     magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
@@ -255,7 +255,7 @@ void magmaTrsmBatched<double>(
 }
 
 template<>
-void magmaTrsmBatched<float>(
+void magmaTriangularSolveBatched<float>(
     magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
@@ -684,10 +684,10 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
   return triu_tril_cuda_template<true>(result, self_c, k, "triu");
 }
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trsm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
-static void apply_trsm(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
+static void apply_triangular_solve(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
 #ifndef USE_MAGMA
 AT_ERROR("cholesky_solve: MAGMA library not found in "
     "compilation. Please rebuild with MAGMA.");
@@ -702,7 +702,7 @@ AT_ERROR("cholesky_solve: MAGMA library not found in "
   magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
 
   if (b.dim() == 2) {
-    magmaTrsm<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
+    magmaTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
   } else {
     auto A_mat_stride = matrixStride(A);
     auto b_mat_stride = matrixStride(b);
@@ -721,18 +721,19 @@ AT_ERROR("cholesky_solve: MAGMA library not found in "
     }
 
     MAGMAQueue magma_queue(b.get_device());
-    magmaTrsmBatched<scalar_t>(
+    magmaTriangularSolveBatched<scalar_t>(
         uplo, trans, diag, n, nrhs, A_array, n,
         b_array, n, batch_size, magma_queue);
   }
 #endif
 }
 
-std::tuple<Tensor, Tensor> _trsm_helper_cuda(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+std::tuple<Tensor, Tensor> _triangular_solve_helper_cuda(const Tensor& self, const Tensor& A,
+                                                         bool upper, bool transpose, bool unitriangular) {
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trsm_cuda", [&]{
-    apply_trsm<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cuda", [&]{
+    apply_triangular_solve<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
   });
   return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
 }
index 6a7a658..ca8a5ad 100644 (file)
   matches_jit_signature: True
   variants: method, function
 
-- func: trtrs(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!), Tensor(b!))
+- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!), Tensor(b!))
   matches_jit_signature: True
 
-- func: trtrs(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor, Tensor)
+- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: method, function
 
-- func: _trtrs_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor)
+- func: _triangular_solve_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
-    CPU: _trtrs_helper_cpu
-    CUDA: _trsm_helper_cuda
+    CPU: _triangular_solve_helper_cpu
+    CUDA: _triangular_solve_helper_cuda
 
 - func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
   matches_jit_signature: True
index a7a3128..a98b278 100644 (file)
@@ -401,6 +401,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: trace
    .. automethod:: transpose
    .. automethod:: transpose_
+   .. automethod:: triangular_solve
    .. automethod:: tril
    .. automethod:: tril_
    .. automethod:: triu
index 2bbeec9..782f58e 100644 (file)
@@ -324,6 +324,7 @@ BLAS and LAPACK Operations
 .. autofunction:: solve
 .. autofunction:: svd
 .. autofunction:: symeig
+.. autofunction:: triangular_solve
 .. autofunction:: trtrs
 
 Utilities
index badeee4..aad83df 100644 (file)
@@ -2183,14 +2183,14 @@ class TestAutograd(TestCase):
             run_test(upper, dims)
 
     @skipIfNoLapack
-    def test_trtrs(self):
+    def test_triangular_solve(self):
         def _test_with_size(A_dims, B_dims):
             A = torch.rand(*A_dims).requires_grad_()
             b = torch.rand(*B_dims).requires_grad_()
 
             for upper, transpose, unitriangular in product((True, False), repeat=3):
                 def func(A, b):
-                    return torch.trtrs(b, A, upper, transpose, unitriangular)
+                    return torch.triangular_solve(b, A, upper, transpose, unitriangular)
 
                 gradcheck(func, [A, b])
                 gradgradcheck(func, [A, b])
index 59c50ce..a34e51f 100644 (file)
@@ -2564,16 +2564,16 @@ class TestCuda(TestCase):
         _TestTorchMixin._test_geqrf(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_trtrs(self):
-        _TestTorchMixin._test_trtrs(self, lambda t: t.cuda())
+    def test_triangular_solve(self):
+        _TestTorchMixin._test_triangular_solve(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_trtrs_batched(self):
-        _TestTorchMixin._test_trtrs_batched(self, lambda t: t.cuda())
+    def test_triangular_solve_batched(self):
+        _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_trtrs_batched_dims(self):
-        _TestTorchMixin._test_trtrs_batched_dims(self, lambda t: t.cuda())
+    def test_triangular_solve_batched_dims(self):
+        _TestTorchMixin._test_triangular_solve_batched_dims(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_get_set_rng_state_all(self):
index adcebb1..20563a6 100644 (file)
@@ -4866,7 +4866,7 @@ class _TestTorchMixin(object):
         self._test_geqrf(self, lambda t: t)
 
     @staticmethod
-    def _test_trtrs(self, cast):
+    def _test_triangular_solve(self, cast):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                           (-6.05, -3.30, 5.36, -4.44, 1.08),
                           (-0.45, 2.58, -2.70, 0.27, 9.04),
@@ -4883,54 +4883,54 @@ class _TestTorchMixin(object):
         L = torch.tril(a)
 
         # solve Ux = b
-        x = torch.trtrs(b, U)[0]
+        x = torch.triangular_solve(b, U)[0]
         self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
-        x = torch.trtrs(b, U, True, False, False)[0]
+        x = torch.triangular_solve(b, U, True, False, False)[0]
         self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
 
         # solve Lx = b
-        x = torch.trtrs(b, L, False)[0]
+        x = torch.triangular_solve(b, L, False)[0]
         self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
-        x = torch.trtrs(b, L, False, False, False)[0]
+        x = torch.triangular_solve(b, L, False, False, False)[0]
         self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
 
         # solve U'x = b
-        x = torch.trtrs(b, U, True, True)[0]
+        x = torch.triangular_solve(b, U, True, True)[0]
         self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
-        x = torch.trtrs(b, U, True, True, False)[0]
+        x = torch.triangular_solve(b, U, True, True, False)[0]
         self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
 
         # solve U'x = b by manual transposition
-        y = torch.trtrs(b, U.t(), False, False)[0]
+        y = torch.triangular_solve(b, U.t(), False, False)[0]
         self.assertLessEqual(x.dist(y), 1e-12)
 
         # solve L'x = b
-        x = torch.trtrs(b, L, False, True)[0]
+        x = torch.triangular_solve(b, L, False, True)[0]
         self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
-        x = torch.trtrs(b, L, False, True, False)[0]
+        x = torch.triangular_solve(b, L, False, True, False)[0]
         self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
 
         # solve L'x = b by manual transposition
-        y = torch.trtrs(b, L.t(), True, False)[0]
+        y = torch.triangular_solve(b, L.t(), True, False)[0]
         self.assertLessEqual(x.dist(y), 1e-12)
 
         # test reuse
-        res1 = torch.trtrs(b, a)[0]
+        res1 = torch.triangular_solve(b, a)[0]
         ta = cast(torch.Tensor())
         tb = cast(torch.Tensor())
-        torch.trtrs(b, a, out=(tb, ta))
+        torch.triangular_solve(b, a, out=(tb, ta))
         self.assertEqual(res1, tb, 0)
         tb.zero_()
-        torch.trtrs(b, a, out=(tb, ta))
+        torch.triangular_solve(b, a, out=(tb, ta))
         self.assertEqual(res1, tb, 0)
 
     @skipIfNoLapack
-    def test_trtrs(self):
-        self._test_trtrs(self, lambda t: t)
+    def test_triangular_solve(self):
+        self._test_triangular_solve(self, lambda t: t)
 
     @staticmethod
-    def _test_trtrs_batched(self, cast):
-        def trtrs_test_helper(A_dims, b_dims, cast, upper, unitriangular):
+    def _test_triangular_solve_batched(self, cast):
+        def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular):
             A = cast(torch.randn(*A_dims))
             A = A.triu() if upper else A.tril()
             if unitriangular:
@@ -4939,40 +4939,40 @@ class _TestTorchMixin(object):
             return A, b
 
         for upper, transpose, unitriangular in product([True, False], repeat=3):
-            # test against trtrs: one batch with all possible arguments
-            A, b = trtrs_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
-            x_exp = torch.trtrs(b.squeeze(0), A.squeeze(0),
-                                upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
-            x = torch.trtrs(b, A,
-                            upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+            # test against triangular_solve: one batch with all possible arguments
+            A, b = triangular_solve_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
+            x_exp = torch.triangular_solve(b.squeeze(0), A.squeeze(0),
+                                           upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
+            x = torch.triangular_solve(b, A,
+                                       upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
             self.assertEqual(x, x_exp.unsqueeze(0))
 
-            # test against trtrs in a loop: four batches with all possible arguments
-            A, b = trtrs_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
+            # test against triangular_solve in a loop: four batches with all possible arguments
+            A, b = triangular_solve_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
             x_exp_list = []
             for i in range(4):
-                x_exp = torch.trtrs(b[i], A[i],
-                                    upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+                x_exp = torch.triangular_solve(b[i], A[i],
+                                               upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
                 x_exp_list.append(x_exp)
             x_exp = torch.stack(x_exp_list)
 
-            x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+            x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
             self.assertEqual(x, x_exp)
 
             # basic correctness test
-            A, b = trtrs_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
-            x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+            A, b = triangular_solve_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
+            x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
             if transpose:
                 self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12)
             else:
                 self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12)
 
     @skipIfNoLapack
-    def test_trtrs_batched(self):
-        _TestTorchMixin._test_trtrs_batched(self, lambda t: t)
+    def test_triangular_solve_batched(self):
+        _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t)
 
     @staticmethod
-    def _test_trtrs_batched_dims(self, cast):
+    def _test_triangular_solve_batched_dims(self, cast):
         if not TEST_SCIPY:
             return
 
@@ -5000,7 +5000,7 @@ class _TestTorchMixin(object):
             x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(),
                                                          upper, transpose, unitriangular))
             A, b = cast(A), cast(b)
-            x = torch.trtrs(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
+            x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
 
             self.assertEqual(x, cast(x_exp))
 
@@ -5012,8 +5012,8 @@ class _TestTorchMixin(object):
             run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular)  # broadcasting A & b
 
     @skipIfNoLapack
-    def test_trtrs_batched_dims(self):
-        self._test_trtrs_batched_dims(self, lambda t: t)
+    def test_triangular_solve_batched_dims(self):
+        self._test_triangular_solve_batched_dims(self, lambda t: t)
 
     @skipIfNoLapack
     def test_gels(self):
index d99edff..b7acf82 100644 (file)
 - name: transpose_(Tensor self, int64_t dim0, int64_t dim1)
   self: grad.transpose(dim0, dim1)
 
+- name: triangular_solve(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
+  self, A: triangular_solve_backward(grads[0], grads[1], self, A, result0, upper, transpose, unitriangular, grad_input_mask)
+
 - name: tril(Tensor self, int64_t diagonal)
   self: grad.tril(diagonal)
 
 - name: triu(Tensor self, int64_t diagonal)
   self: grad.triu(diagonal)
 
-- name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
-  self, A: trtrs_backward(grads[0], grads[1], self, A, result0, upper, transpose, unitriangular, grad_input_mask)
-
 - name: trunc(Tensor self)
   self: zeros_like(grad)
 
index 43ab781..80f7495 100644 (file)
@@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*', '_thnn_.*',
     'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
-    '_cholesky.*', '_btrifact.*',
+    '_cholesky.*', '_btrifact.*', '_triangular_solve.*',
     'slice', 'randint(_out)?',
     'item', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
index e71d5fe..c927624 100644 (file)
@@ -1739,14 +1739,14 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
 // Reference:
 // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
 // Sec. 2.3.1 Matrix inverse product
-std::tuple<Tensor, Tensor> trtrs_backward(
+std::tuple<Tensor, Tensor> triangular_solve_backward(
     const Tensor & grad_x, const Tensor & grad_m,
     const Tensor & b, const Tensor & a, const Tensor & x,
     const bool upper, const bool transpose, const bool unitriangular,
     std::array<bool, 2> output_mask) {
   Tensor grad_b, grad_a;
   if (grad_x.defined()) {
-    grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular));
+    grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular));
     if (output_mask[1]) {
       grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
       if (upper) {
index 0d26f23..d58d83f 100644 (file)
@@ -2601,6 +2601,13 @@ transpose_(dim0, dim1) -> Tensor
 In-place version of :meth:`~Tensor.transpose`
 """)
 
+add_docstr_all('triangular_solve',
+               r"""
+triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
+
+See :func:`torch.triangular_solve`
+""")
+
 add_docstr_all('tril',
                r"""
 tril(k=0) -> Tensor
@@ -2629,13 +2636,6 @@ triu_(k=0) -> Tensor
 In-place version of :meth:`~Tensor.triu`
 """)
 
-add_docstr_all('trtrs',
-               r"""
-trtrs(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
-
-See :func:`torch.trtrs`
-""")
-
 add_docstr_all('trunc',
                r"""
 trunc() -> Tensor
index a088d18..0fd3d5e 100644 (file)
@@ -2224,6 +2224,22 @@ Example::
     tensor(1.9073e-06)
 """)
 
+add_docstr(torch.isnan,
+           r"""
+Returns a new tensor with boolean elements representing if each element is `NaN` or not.
+
+Arguments:
+    tensor (Tensor): A tensor to check
+
+Returns:
+    Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
+
+Example::
+
+    >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
+    tensor([ 0,  1,  0], dtype=torch.uint8)
+""")
+
 add_docstr(torch.is_floating_point,
            r"""
 is_floating_point(tensor) -> (bool)
@@ -5073,6 +5089,59 @@ Example::
             [ 0.5809,  0.4942]])
 """)
 
+add_docstr(torch.triangular_solve,
+           r"""
+triangular_solve(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
+
+Solves a system of equations with a triangular coefficient matrix :math:`A`
+and multiple right-hand sides :attr:`b`.
+
+In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
+with the default keyword arguments.
+
+`torch.triangular_solve(b, A)` can take in 2D inputs `b, A` or inputs that are
+batches of 2D matrices. If the inputs are batches, then returns
+batched outputs `X`
+
+.. note::
+
+    The :attr:`out` keyword only supports 2D matrix inputs, that is,
+    `b, A` must be 2D matrices.
+
+Args:
+    A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
+                where :math:`*` is zero or more batch dimensions
+    b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where
+                :math:`*` is zero of more batch dimensions
+    upper (bool, optional): whether to solve the upper-triangular system
+        of equations (default) or the lower-triangular system of equations. Default: ``True``.
+    transpose (bool, optional): whether :math:`A` should be transposed before
+        being sent into the solver. Default: ``False``.
+    unitriangular (bool, optional): whether :math:`A` is unit triangular.
+        If True, the diagonal elements of :math:`A` are assumed to be
+        1 and not referenced from :math:`A`. Default: ``False``.
+
+Returns:
+    A tuple :math:`(X, M)` where :math:`M` is a clone of :math:`A` and :math:`X`
+    is the solution to :math:`AX = b` (or whatever variant of the system of
+    equations, depending on the keyword arguments.)
+
+Examples::
+
+    >>> A = torch.randn(2, 2).triu()
+    >>> A
+    tensor([[ 1.1527, -1.0753],
+            [ 0.0000,  0.7986]])
+    >>> b = torch.randn(2, 3)
+    >>> b
+    tensor([[-0.0210,  2.3513, -1.5492],
+            [ 1.5429,  0.7403, -1.0243]])
+    >>> torch.triangular_solve(b, A)
+    (tensor([[ 1.7840,  2.9045, -2.5405],
+            [ 1.9319,  0.9269, -1.2826]]), tensor([[ 1.1527, -1.0753],
+            [ 0.0000,  0.7986]]))
+""")
+
 add_docstr(torch.tril,
            r"""
 tril(input, diagonal=0, out=None) -> Tensor
@@ -5291,59 +5360,6 @@ Example::
             [1, 2, 2]])
 """.format(**factory_common_args))
 
-add_docstr(torch.trtrs,
-           r"""
-trtrs(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
-
-Solves a system of equations with a triangular coefficient matrix :math:`A`
-and multiple right-hand sides :attr:`b`.
-
-In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
-with the default keyword arguments.
-
-`torch.trtrs(b, A)` can take in 2D inputs `b, A` or inputs that are
-batches of 2D matrices. If the inputs are batches, then returns
-batched outputs `X`
-
-.. note::
-
-    The :attr:`out` keyword only supports 2D matrix inputs, that is,
-    `b, A` must be 2D matrices.
-
-Args:
-    A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
-                where :math:`*` is zero or more batch dimensions
-    b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where
-                :math:`*` is zero of more batch dimensions
-    upper (bool, optional): whether to solve the upper-triangular system
-        of equations (default) or the lower-triangular system of equations. Default: ``True``.
-    transpose (bool, optional): whether :math:`A` should be transposed before
-        being sent into the solver. Default: ``False``.
-    unitriangular (bool, optional): whether :math:`A` is unit triangular.
-        If True, the diagonal elements of :math:`A` are assumed to be
-        1 and not referenced from :math:`A`. Default: ``False``.
-
-Returns:
-    A tuple :math:`(X, M)` where :math:`M` is a clone of :math:`A` and :math:`X`
-    is the solution to :math:`AX = b` (or whatever variant of the system of
-    equations, depending on the keyword arguments.)
-
-Examples::
-
-    >>> A = torch.randn(2, 2).triu()
-    >>> A
-    tensor([[ 1.1527, -1.0753],
-            [ 0.0000,  0.7986]])
-    >>> b = torch.randn(2, 3)
-    >>> b
-    tensor([[-0.0210,  2.3513, -1.5492],
-            [ 1.5429,  0.7403, -1.0243]])
-    >>> torch.trtrs(b, A)
-    (tensor([[ 1.7840,  2.9045, -2.5405],
-            [ 1.9319,  0.9269, -1.2826]]), tensor([[ 1.1527, -1.0753],
-            [ 0.0000,  0.7986]]))
-""")
-
 add_docstr(torch.trunc,
            r"""
 trunc(input, out=None) -> Tensor
index b6a33dc..9b77d33 100644 (file)
@@ -312,7 +312,7 @@ def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
     #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
     qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                  q._unbroadcasted_cov_diag.unsqueeze(-2))
-    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
+    A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0]
     term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
     term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
                               q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
@@ -339,7 +339,7 @@ def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
     #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
     qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                  q._unbroadcasted_cov_diag.unsqueeze(-2))
-    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
+    A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0]
     term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
                               q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
     term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
@@ -367,8 +367,8 @@ def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
                                                       (n, p.cov_factor.size(-1)))
     p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt())
                   .expand(combined_batch_shape + (n, n)))
-    term21 = _batch_trace_XXT(torch.trtrs(p_cov_factor, q_scale_tril, upper=False)[0])
-    term22 = _batch_trace_XXT(torch.trtrs(p_cov_diag, q_scale_tril, upper=False)[0])
+    term21 = _batch_trace_XXT(torch.triangular_solve(p_cov_factor, q_scale_tril, upper=False)[0])
+    term22 = _batch_trace_XXT(torch.triangular_solve(p_cov_diag, q_scale_tril, upper=False)[0])
     term2 = term21 + term22
     return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
 
@@ -387,7 +387,7 @@ def _kl_multivariatenormal_multivariatenormal(p, q):
     n = p.event_shape[0]
     q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
     p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
-    term2 = _batch_trace_XXT(torch.trtrs(p_scale_tril, q_scale_tril, upper=False)[0])
+    term2 = _batch_trace_XXT(torch.triangular_solve(p_scale_tril, q_scale_tril, upper=False)[0])
     term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
     return half_term1 + 0.5 * (term2 + term3 - n)
 
index 6ba31eb..56822a0 100644 (file)
@@ -162,7 +162,7 @@ class LowRankMultivariateNormal(Distribution):
         # where :math:`C` is the capacitance matrix.
         Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
                    / self._unbroadcasted_cov_diag.unsqueeze(-2))
-        A = torch.trtrs(Wt_Dinv, self._capacitance_tril, upper=False)[0]
+        A = torch.triangular_solve(Wt_Dinv, self._capacitance_tril, upper=False)[0]
         precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal())
                             - torch.matmul(A.transpose(-1, -2), A))
         return precision_matrix.expand(self._batch_shape + self._event_shape +
index 6efd8c9..0a38dc3 100644 (file)
@@ -32,7 +32,7 @@ def _batch_mahalanobis(bL, bx):
     bx_batch_shape = bx.shape[:-1]
 
     # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
-    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply _batch_trtrs_lower
+    # we are going to make bx have shape (..., 1, j,  i, 1, n) to apply batched tri.solve
     bx_batch_dims = len(bx_batch_shape)
     bL_batch_dims = bL.dim() - 2
     outer_batch_dims = bx_batch_dims - bL_batch_dims
@@ -54,7 +54,7 @@ def _batch_mahalanobis(bL, bx):
     flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
     flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
     flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
-    M_swap = torch.trtrs(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c
+    M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c
     M = M_swap.t()  # shape = c x b
 
     # Now we revert the above reshape and permute operators.
index 8a38e0a..580227c 100644 (file)
@@ -1,13 +1,7 @@
 import torch
 import torch.nn.functional as F
 from torch._six import inf
-from torch._C import _add_docstr
-from operator import mul
-from functools import reduce
-from collections import Iterable
-from torch._utils import annotate
 from itertools import product
-import math
 import warnings
 
 __all__ = [
@@ -17,7 +11,6 @@ __all__ = [
     'broadcast_tensors',
     'isfinite',
     'isinf',
-    'isnan',
     'norm',
     'meshgrid',
     'potrf',
@@ -27,6 +20,7 @@ __all__ = [
     'split',
     'stft',
     'tensordot',
+    'trtrs',
     'unique',
     'cartesian_prod',
 ]
@@ -380,22 +374,6 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None,
     return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
 
 
-isnan = _add_docstr(torch.isnan, r"""
-Returns a new tensor with boolean elements representing if each element is `NaN` or not.
-
-Arguments:
-    tensor (Tensor): A tensor to check
-
-Returns:
-    Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
-
-Example::
-
-    >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
-    tensor([ 0,  1,  0], dtype=torch.uint8)
-""")
-
-
 def unique(input, sorted=True, return_inverse=False, dim=None):
     r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
 
@@ -764,3 +742,19 @@ def gesv(b, A, out=None):
     warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
                   "next release. Please use torch.solve instead.", stacklevel=2)
     return torch.solve(b, A, out=out)
+
+
+def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
+    r"""Solves a system of equations with a triangular coefficient matrix :math:`A`
+    and multiple right-hand sides :attr:`b`.
+
+    In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
+    with the default keyword arguments.
+
+    .. warning::
+        :func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
+        removed in the next release. Please use :func:`torch.triangular_solve` instead.
+    """
+    warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
+                  "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
+    return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
index 8838a71..bf239b3 100644 (file)
@@ -279,6 +279,13 @@ class Tensor(torch._C._TensorBase):
                       "next release. Please use torch.solve instead.", stacklevel=2)
         return super(Tensor, self).solve(A)
 
+    def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
+        r"""See :func:`torch.triangular_solve`"""
+        warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
+                      "removed in the next release. Please use torch.triangular_solve.", stacklevel=2)
+        return super(Tensor, self).triangular_solve(A, upper=upper,
+                                                    transpose=transpose, unitriangular=unitriangular)
+
     def stft(self, n_fft, hop_length=None, win_length=None, window=None,
              center=True, pad_mode='reflect', normalized=False, onesided=True):
         r"""See :func:`torch.stft`