Rename potrs to cholesky_solve (#15334)
authorvishwakftw <cs15btech11043@iith.ac.in>
Wed, 19 Dec 2018 20:11:49 +0000 (12:11 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 20:31:24 +0000 (12:31 -0800)
Summary:
Changelog:
- Renames `potrs` to `cholesky_solve` to remain consistent with Tensorflow and Scipy (not really, they call their function chol_solve)
- Default argument for upper in cholesky_solve is False. This will allow a seamless interface between `cholesky` and `cholesky_solve`, since the `upper` argument in both function are the same.
- Rename all tests
- Create a tentative alias for `cholesky_solve` under the name `potrs`, and add deprecated warning to not promote usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15334

Differential Revision: D13507724

Pulled By: soumith

fbshipit-source-id: b826996541e49d2e2bcd061b72a38c39450c76d0

17 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/LinearAlgebraUtils.h
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
test/test_cuda.py
test/test_torch.py
tools/autograd/derivatives.yaml
torch/_tensor_docs.py
torch/_torch_docs.py
torch/functional.py
torch/tensor.py

index 2a2422c..c8d03e0 100644 (file)
@@ -634,7 +634,7 @@ public:
   std::tuple<Tensor,Tensor> eig(bool eigenvectors=false) const;
   std::tuple<Tensor,Tensor,Tensor> svd(bool some=true, bool compute_uv=true) const;
   Tensor cholesky(bool upper=false) const;
-  Tensor potrs(const Tensor & input2, bool upper=true) const;
+  Tensor cholesky_solve(const Tensor & input2, bool upper=false) const;
   Tensor potri(bool upper=true) const;
   std::tuple<Tensor,Tensor> pstrf(bool upper=true, Scalar tol=-1) const;
   std::tuple<Tensor,Tensor> qr() const;
index a107111..17b74bc 100644 (file)
@@ -1123,8 +1123,8 @@ inline std::tuple<Tensor,Tensor,Tensor> Tensor::svd(bool some, bool compute_uv)
 inline Tensor Tensor::cholesky(bool upper) const {
     return type().cholesky(*this, upper);
 }
-inline Tensor Tensor::potrs(const Tensor & input2, bool upper) const {
-    return type().potrs(*this, input2, upper);
+inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const {
+    return type().cholesky_solve(*this, input2, upper);
 }
 inline Tensor Tensor::potri(bool upper) const {
     return type().potri(*this, upper);
index ea57210..55667e0 100644 (file)
@@ -541,7 +541,7 @@ struct CAFFE2_API Type {
   virtual std::tuple<Tensor,Tensor> eig(const Tensor & self, bool eigenvectors) const = 0;
   virtual std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool compute_uv) const = 0;
   virtual Tensor cholesky(const Tensor & self, bool upper) const = 0;
-  virtual Tensor potrs(const Tensor & self, const Tensor & input2, bool upper) const = 0;
+  virtual Tensor cholesky_solve(const Tensor & self, const Tensor & input2, bool upper) const = 0;
   virtual Tensor potri(const Tensor & self, bool upper) const = 0;
   virtual std::tuple<Tensor,Tensor> pstrf(const Tensor & self, bool upper, Scalar tol) const = 0;
   virtual std::tuple<Tensor,Tensor> qr(const Tensor & self) const = 0;
index 4500bfa..72f07ce 100644 (file)
@@ -43,6 +43,7 @@ _(aten, _cast_Short) \
 _(aten, _cat) \
 _(aten, _ceil) \
 _(aten, _cholesky_helper) \
+_(aten, _cholesky_solve_helper) \
 _(aten, _convolution) \
 _(aten, _convolution_double_backward) \
 _(aten, _convolution_nogroup) \
@@ -102,7 +103,6 @@ _(aten, _pack_padded_sequence_backward) \
 _(aten, _pad_packed_sequence) \
 _(aten, _pdist_backward) \
 _(aten, _pdist_forward) \
-_(aten, _potrs_helper) \
 _(aten, _prod) \
 _(aten, _prodall) \
 _(aten, _range) \
@@ -242,6 +242,7 @@ _(aten, ceil) \
 _(aten, celu) \
 _(aten, chain_matmul) \
 _(aten, cholesky) \
+_(aten, cholesky_solve) \
 _(aten, chunk) \
 _(aten, clamp) \
 _(aten, clamp_max) \
@@ -523,7 +524,6 @@ _(aten, pixel_shuffle) \
 _(aten, poisson) \
 _(aten, polygamma) \
 _(aten, potri) \
-_(aten, potrs) \
 _(aten, pow) \
 _(aten, prelu) \
 _(aten, prelu_backward) \
index f935fc1..42b800e 100644 (file)
@@ -55,8 +55,8 @@ void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwo
 }
 
 template<class scalar_t>
-void lapackPotrs(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
-  AT_ERROR("potrs only takes float or double Tensors");
+void lapackCholeskySolve(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("cholesky_solve only takes float or double Tensors");
 }
 
 template<class scalar_t>
@@ -89,11 +89,11 @@ template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, i
   sgetrf_(&m, &n, a, &lda, ipiv, info);
 }
 
-template<> void lapackPotrs<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
+template<> void lapackCholeskySolve<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
   dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
 }
 
-template<> void lapackPotrs<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
+template<> void lapackCholeskySolve<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
   spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
 }
 
@@ -245,12 +245,12 @@ Tensor& inverse_out(Tensor &result, const Tensor &self) {
   return result;
 }
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ potrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template<typename scalar_t>
-static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>& infos) {
+static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>& infos) {
 #ifndef USE_LAPACK
-  AT_ERROR("potrs: LAPACK library not found in compilation");
+  AT_ERROR("cholesky_solve: LAPACK library not found in compilation");
 #else
   char uplo = upper ? 'U' : 'L';
 
@@ -267,7 +267,7 @@ static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>&
     int info;
     scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
     scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
-    lapackPotrs<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+    lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
     infos[i] = info;
     if (info != 0) {
       return;
@@ -276,31 +276,31 @@ static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>&
 #endif
 }
 
-Tensor _potrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
+Tensor _cholesky_solve_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
   std::vector<int64_t> infos(batchCount(self), 0);
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "potrs", [&]{
-    apply_potrs<scalar_t>(self_working_copy, A_working_copy, upper, infos);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{
+    apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, infos);
   });
-  batchCheckErrors(infos, "potrs");
+  batchCheckErrors(infos, "cholesky_solve");
   return self_working_copy;
 }
 
 // Supports arbitrary batch dimensions for self and A
-Tensor potrs(const Tensor& self, const Tensor& A, bool upper) {
+Tensor cholesky_solve(const Tensor& self, const Tensor& A, bool upper) {
   if (self.dim() <= 2 && A.dim() <= 2) {
     return at::legacy::th::_th_potrs_single(self, A, upper);
   }
 
   Tensor self_broadcasted, A_broadcasted;
   std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
-  return at::_potrs_helper(self_broadcasted, A_broadcasted, upper);
+  return at::_cholesky_solve_helper(self_broadcasted, A_broadcasted, upper);
 }
 
-Tensor& potrs_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
+Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
   AT_CHECK(self.dim() == 2 && A.dim() == 2,
-           "torch.potrs() with the `out` keyword does not support batching. "
+           "torch.cholesky_solve() with the `out` keyword does not support batching. "
            "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
   return at::legacy::th::_th_potrs_single_out(result, self, A, upper);
 }
index bf45bf7..9cbc870 100644 (file)
@@ -53,7 +53,7 @@ static inline double _get_epsilon(const ScalarType& sc_type) {
   }
 }
 
-// Validates input shapes for linear solve methods (gesv, potrs)
+// Validates input shapes for linear solve methods (gesv, cholesky_solve)
 static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
   AT_CHECK(A.size(-1) == A.size(-2),
            "A must be batches of square matrices, "
index 4b1c1a3..951c121 100644 (file)
@@ -44,10 +44,10 @@ void magmaGetriBatched(
 }
 
 template<class scalar_t>
-void magmaPotrsBatched(
+void magmaCholeskySolveBatched(
     magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
     scalar_t** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
-  AT_ERROR("potrs only takes float or double Tensors");
+  AT_ERROR("cholesky_solve only takes float or double Tensors");
 }
 
 template<class scalar_t>
@@ -106,14 +106,14 @@ void magmaGetriBatched<float>(
 }
 
 template<>
-void magmaPotrsBatched<double>(
+void magmaCholeskySolveBatched<double>(
     magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
     double** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
     info = magma_dpotrs_batched(uplo, n, nrhs, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
 }
 
 template<>
-void magmaPotrsBatched<float>(
+void magmaCholeskySolveBatched<float>(
     magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
     float** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
     info = magma_spotrs_batched(uplo, n, nrhs, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
@@ -261,12 +261,12 @@ Tensor _inverse_helper_cuda(const Tensor& self) {
   return self_inv_working_copy;
 }
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ potrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
-static void apply_potrs(Tensor& b, Tensor& A, bool upper, int64_t& info) {
+static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info) {
 #ifndef USE_MAGMA
-AT_ERROR("potrs: MAGMA library not found in "
+AT_ERROR("cholesky_solve: MAGMA library not found in "
     "compilation. Please rebuild with MAGMA.");
 #else
   magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
@@ -281,13 +281,9 @@ AT_ERROR("potrs: MAGMA library not found in "
   magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
 
   magma_int_t info_tmp;
-  magma_int_t* ipiv_data;
-  magma_int_t** ipiv_array;
   scalar_t** A_array;
   scalar_t** b_array;
 
-  ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n, b);
-  ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size, b);
   ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
   ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
 
@@ -295,11 +291,10 @@ AT_ERROR("potrs: MAGMA library not found in "
   for (int64_t i = 0; i < batch_size; i++) {
     A_array[i] = &A_data[i * A_mat_stride];
     b_array[i] = &b_data[i * b_mat_stride];
-    ipiv_array[i] = &ipiv_data[i * n];
   }
 
   MAGMAQueue magma_queue(b.get_device());
-  magmaPotrsBatched<scalar_t>(
+  magmaCholeskySolveBatched<scalar_t>(
       uplo, n, nrhs, A_array, n, b_array, n,
       info_tmp, batch_size, magma_queue);
 
@@ -307,14 +302,14 @@ AT_ERROR("potrs: MAGMA library not found in "
 #endif
 }
 
-Tensor _potrs_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
+Tensor _cholesky_solve_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
   int64_t info = 0;
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
-  AT_DISPATCH_FLOATING_TYPES(self.type(), "potrs", [&]{
-    apply_potrs<scalar_t>(self_working_copy, A_working_copy, upper, info);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{
+    apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, info);
   });
-  AT_CHECK(info == 0, "MAGMA potrs : invalid argument: ", -info);
+  AT_CHECK(info == 0, "MAGMA cholesky_solve : invalid argument: ", -info);
   return self_working_copy;
 }
 
index 337c450..6f527f4 100644 (file)
     CPU: _cholesky_helper_cpu
     CUDA: _cholesky_helper_cuda
 
-- func: potrs_out(Tensor result, Tensor self, Tensor input2, bool upper=true) -> Tensor
+- func: cholesky_solve_out(Tensor result, Tensor self, Tensor input2, bool upper=false) -> Tensor
 
-- func: potrs(Tensor self, Tensor input2, bool upper=true) -> Tensor
+- func: cholesky_solve(Tensor self, Tensor input2, bool upper=false) -> Tensor
   variants: method, function
 
-- func: _potrs_helper(Tensor self, Tensor A, bool upper) -> Tensor
+- func: _cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor
   variants: function
   dispatch:
-    CPU: _potrs_helper_cpu
-    CUDA: _potrs_helper_cuda
+    CPU: _cholesky_solve_helper_cpu
+    CUDA: _cholesky_solve_helper_cuda
 
 - func: potri_out(Tensor result, Tensor self, bool upper=true) -> Tensor
 
index 57c042a..977cf2b 100644 (file)
@@ -184,6 +184,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: ceil_
    .. automethod:: char
    .. automethod:: cholesky
+   .. automethod:: cholesky_solve
    .. automethod:: chunk
    .. automethod:: clamp
    .. automethod:: clamp_
index 1b42eb9..7447d37 100644 (file)
@@ -291,6 +291,7 @@ BLAS and LAPACK Operations
 .. autofunction:: btriunpack
 .. autofunction:: chain_matmul
 .. autofunction:: cholesky
+.. autofunction:: cholesky_solve
 .. autofunction:: dot
 .. autofunction:: eig
 .. autofunction:: gels
index 4d860cb..9ddb95a 100644 (file)
@@ -1603,16 +1603,16 @@ class TestCuda(TestCase):
         _TestTorchMixin._test_gesv_batched_dims(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_potrs(self):
-        _TestTorchMixin._test_potrs(self, lambda t: t.cuda())
+    def test_cholesky_solve(self):
+        _TestTorchMixin._test_cholesky_solve(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_potrs_batched(self):
-        _TestTorchMixin._test_potrs_batched(self, lambda t: t.cuda())
+    def test_cholesky_solve_batched(self):
+        _TestTorchMixin._test_cholesky_solve_batched(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_potrs_batched_dims(self):
-        _TestTorchMixin._test_potrs_batched_dims(self, lambda t: t.cuda())
+    def test_cholesky_solve_batched_dims(self):
+        _TestTorchMixin._test_cholesky_solve_batched_dims(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_cholesky(self):
index 37a4130..216ea85 100644 (file)
@@ -5500,7 +5500,7 @@ class _TestTorchMixin(object):
         self._test_cholesky_batched(self, lambda t: t)
 
     @staticmethod
-    def _test_potrs(self, cast):
+    def _test_cholesky_solve(self, cast):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                           (-6.05, -3.30, 5.36, -4.44, 1.08),
                           (-0.45, 2.58, -2.70, 0.27, 9.04),
@@ -5516,49 +5516,54 @@ class _TestTorchMixin(object):
 
         # upper Triangular Test
         U = torch.cholesky(a, True)
-        x = torch.potrs(b, U, True)
+        x = torch.cholesky_solve(b, U, True)
         self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
 
         # lower Triangular Test
         L = torch.cholesky(a, False)
-        x = torch.potrs(b, L, False)
+        x = torch.cholesky_solve(b, L, False)
         self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12)
 
+        # default arg Test
+        L_def = torch.cholesky(a)
+        x_def = torch.cholesky_solve(b, L_def)
+        self.assertLessEqual(b.dist(torch.mm(a, x_def)), 1e-12)
+
     @skipIfNoLapack
-    def test_potrs(self):
-        self._test_potrs(self, lambda t: t)
+    def test_cholesky_solve(self):
+        self._test_cholesky_solve(self, lambda t: t)
 
     @staticmethod
-    def _test_potrs_batched(self, cast):
+    def _test_cholesky_solve_batched(self, cast):
         from common_utils import random_symmetric_pd_matrix
 
-        def potrs_test_helper(A_dims, b_dims, cast, upper):
+        def cholesky_solve_test_helper(A_dims, b_dims, cast, upper):
             A = cast(random_symmetric_pd_matrix(*A_dims))
             L = torch.cholesky(A, upper)
             b = cast(torch.randn(*b_dims))
             return A, L, b
 
         for upper in [True, False]:
-            # test against potrs: one batch with both choices of upper
-            A, L, b = potrs_test_helper((5, 1), (1, 5, 10), cast, upper)
-            x_exp = torch.potrs(b.squeeze(0), L.squeeze(0), upper=upper)
-            x = torch.potrs(b, L, upper=upper)
+            # test against cholesky_solve: one batch with both choices of upper
+            A, L, b = cholesky_solve_test_helper((5, 1), (1, 5, 10), cast, upper)
+            x_exp = torch.cholesky_solve(b.squeeze(0), L.squeeze(0), upper=upper)
+            x = torch.cholesky_solve(b, L, upper=upper)
             self.assertEqual(x, x_exp.unsqueeze(0))
 
-            # test against potrs in a loop: four batches with both choices of upper
-            A, L, b = potrs_test_helper((5, 4), (4, 5, 10), cast, upper)
+            # test against cholesky_solve in a loop: four batches with both choices of upper
+            A, L, b = cholesky_solve_test_helper((5, 4), (4, 5, 10), cast, upper)
             x_exp_list = list()
             for i in range(4):
-                x_exp = torch.potrs(b[i], L[i], upper=upper)
+                x_exp = torch.cholesky_solve(b[i], L[i], upper=upper)
                 x_exp_list.append(x_exp)
             x_exp = torch.stack(x_exp_list)
 
-            x = torch.potrs(b, L, upper=upper)
+            x = torch.cholesky_solve(b, L, upper=upper)
             self.assertEqual(x, x_exp)
 
             # basic correctness test
-            A, L, b = potrs_test_helper((5, 3), (3, 5, 10), cast, upper)
-            x = torch.potrs(b, L, upper)
+            A, L, b = cholesky_solve_test_helper((5, 3), (3, 5, 10), cast, upper)
+            x = torch.cholesky_solve(b, L, upper)
             self.assertLessEqual(b.dist(torch.matmul(A, x)), 1e-12)
 
             # Test non-contiguous inputs.
@@ -5573,15 +5578,15 @@ class _TestTorchMixin(object):
             b = cast(b).permute(2, 1, 0)
             assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
             L = torch.cholesky(A, upper)
-            x = torch.potrs(b, L, upper=upper)
+            x = torch.cholesky_solve(b, L, upper=upper)
             self.assertEqual(x, cast(x_exp))
 
     @skipIfNoLapack
-    def test_potrs_batched(self):
-        self._test_potrs_batched(self, lambda t: t)
+    def test_cholesky_solve_batched(self):
+        self._test_cholesky_solve_batched(self, lambda t: t)
 
     @staticmethod
-    def _test_potrs_batched_dims(self, cast):
+    def _test_cholesky_solve_batched_dims(self, cast):
         if not TEST_NUMPY:
             return
 
@@ -5594,7 +5599,7 @@ class _TestTorchMixin(object):
             x_exp = torch.Tensor(solve(A.numpy(), b.numpy()))
             A, b = cast(A), cast(b)
             L = torch.cholesky(A, upper)
-            x = torch.potrs(b, L, upper=upper)
+            x = torch.cholesky_solve(b, L, upper=upper)
             self.assertEqual(x, cast(x_exp))
 
         for upper in [True, False]:
@@ -5605,8 +5610,8 @@ class _TestTorchMixin(object):
             run_test((4, 1, 3, 1), (2, 1, 3, 4, 5), cast, upper)  # broadcasting A & b
 
     @skipIfNoLapack
-    def test_potrs_batched_dims(self):
-        self._test_potrs_batched_dims(self, lambda t: t)
+    def test_cholesky_solve_batched_dims(self):
+        self._test_cholesky_solve_batched_dims(self, lambda t: t)
 
     @skipIfNoLapack
     def test_potri(self):
index 99fe5f4..1cb5287 100644 (file)
 - name: cholesky(Tensor self, bool upper)
   self: cholesky_backward(grad, upper, result)
 
+- name: cholesky_solve(Tensor self, Tensor input2, bool upper)
+  self: not_implemented("cholesky_solve")
+  input2: not_implemented("cholesky_solve")
+
 # For clamp, gradient is not defined at the boundaries. But empirically it's helpful
 # to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
 - name: clamp(Tensor self, Scalar? min, Scalar? max)
 - name: potri(Tensor self, bool upper)
   self: not_implemented("potri")
 
-- name: potrs(Tensor self, Tensor input2, bool upper)
-  self: not_implemented("potrs")
-  input2: not_implemented("potrs")
-
 - name: pow(Tensor self, Scalar exponent)
   self: pow_backward(grad, self, exponent)
 
index a7a344e..08e9547 100644 (file)
@@ -539,6 +539,13 @@ cholesky(upper=False) -> Tensor
 See :func:`torch.cholesky`
 """)
 
+add_docstr_all('cholesky_solve',
+               r"""
+cholesky_solve(input2, upper=False) -> Tensor
+
+See :func:`torch.cholesky_solve`
+""")
+
 add_docstr_all('clamp',
                r"""
 clamp(min, max) -> Tensor
@@ -1698,13 +1705,6 @@ potri(upper=True) -> Tensor
 See :func:`torch.potri`
 """)
 
-add_docstr_all('potrs',
-               r"""
-potrs(input2, upper=True) -> Tensor
-
-See :func:`torch.potrs`
-""")
-
 add_docstr_all('pow',
                r"""
 pow(exponent) -> Tensor
index 90cf0d2..fa05483 100644 (file)
@@ -883,6 +883,67 @@ Example::
     tensor(2.3842e-07)
 """)
 
+add_docstr(torch.cholesky_solve, r"""
+cholesky_solve(b, u, upper=False, out=None) -> Tensor
+
+Solves a linear system of equations with a positive semidefinite
+matrix to be inverted given its Cholesky factor matrix :attr:`u`.
+
+If :attr:`upper` is ``False``, :attr:`u` is and lower triangular and `c` is
+returned such that:
+
+.. math::
+    c = (u u^T)^{-1} b
+
+If :attr:`upper` is ``True`` or not provided, :attr:`u` is upper triangular
+and `c` is returned such that:
+
+.. math::
+    c = (u^T u)^{-1} b
+
+`torch.cholesky_solve(b, u)` can take in 2D inputs `b, u` or inputs that are
+batches of 2D matrices. If the inputs are batches, then returns
+batched outputs `c`
+
+.. note::
+
+    The :attr:`out` keyword only supports 2D matrix inputs, that is,
+    `b, u` must be 2D matrices.
+
+Args:
+    b (Tensor): input matrix of size :math:`(*, m, k)`,
+                where :math:`*` is zero or more batch dimensions
+    u (Tensor): input matrix of size :math:`(*, m, m)`,
+                where :math:`*` is zero of more batch dimensions composed of
+                upper or lower triangular Cholesky factor
+    upper (bool, optional): whether to consider the Cholesky factor as a
+                            lower or upper triangular matrix. Default: ``False``.
+    out (Tensor, optional): the output tensor for `c`
+
+Example::
+
+    >>> a = torch.randn(3, 3)
+    >>> a = torch.mm(a, a.t()) # make symmetric positive definite
+    >>> u = torch.cholesky(a)
+    >>> a
+    tensor([[ 0.7747, -1.9549,  1.3086],
+            [-1.9549,  6.7546, -5.4114],
+            [ 1.3086, -5.4114,  4.8733]])
+    >>> b = torch.randn(3, 2)
+    >>> b
+    tensor([[-0.6355,  0.9891],
+            [ 0.1974,  1.4706],
+            [-0.4115, -0.6225]])
+    >>> torch.cholesky_solve(b, u)
+    tensor([[ -8.1625,  19.6097],
+            [ -5.8398,  14.2387],
+            [ -4.3771,  10.4173]])
+    >>> torch.mm(a.inverse(), b)
+    tensor([[ -8.1626,  19.6097],
+            [ -5.8398,  14.2387],
+            [ -4.3771,  10.4173]])
+""")
+
 add_docstr(torch.clamp,
            r"""
 clamp(input, min, max, out=None) -> Tensor
@@ -3415,66 +3476,6 @@ Example::
             [-0.0889,  0.2122,  0.1412]])
 """)
 
-add_docstr(torch.potrs, r"""
-potrs(b, u, upper=True, out=None) -> Tensor
-
-Solves a linear system of equations with a positive semidefinite
-matrix to be inverted given its Cholesky factor matrix :attr:`u`.
-
-If :attr:`upper` is ``True`` or not provided, :attr:`u` is upper triangular
-and `c` is returned such that:
-
-.. math::
-    c = (u^T u)^{-1} b
-
-If :attr:`upper` is ``False``, :attr:`u` is and lower triangular and `c` is
-returned such that:
-
-.. math::
-    c = (u u^T)^{-1} b
-
-`torch.potrs(b, u)` can take in 2D inputs `b, u` or inputs that are
-batches of 2D matrices. If the inputs are batches, then returns
-batched outputs `c`
-
-.. note::
-
-    The :attr:`out` keyword only supports 2D matrix inputs, that is,
-    `b, u` must be 2D matrices.
-
-Args:
-    b (Tensor): input matrix of size :math:`(*, m, k)`,
-                where :math:`*` is zero or more batch dimensions
-    u (Tensor): input matrix of size :math:`(*, m, m)`,
-                where :math:`*` is zero of more batch dimensions composed of
-                upper or lower triangular Cholesky factor
-    upper (bool, optional): whether to return a upper (default) or lower triangular matrix
-    out (Tensor, optional): the output tensor for `c`
-
-Example::
-
-    >>> a = torch.randn(3, 3)
-    >>> a = torch.mm(a, a.t()) # make symmetric positive definite
-    >>> u = torch.cholesky(a, upper=True)
-    >>> a
-    tensor([[ 0.7747, -1.9549,  1.3086],
-            [-1.9549,  6.7546, -5.4114],
-            [ 1.3086, -5.4114,  4.8733]])
-    >>> b = torch.randn(3, 2)
-    >>> b
-    tensor([[-0.6355,  0.9891],
-            [ 0.1974,  1.4706],
-            [-0.4115, -0.6225]])
-    >>> torch.potrs(b,u)
-    tensor([[ -8.1625,  19.6097],
-            [ -5.8398,  14.2387],
-            [ -4.3771,  10.4173]])
-    >>> torch.mm(a.inverse(),b)
-    tensor([[ -8.1626,  19.6097],
-            [ -5.8398,  14.2387],
-            [ -4.3771,  10.4173]])
-""")
-
 add_docstr(torch.pow,
            r"""
 .. function:: pow(input, exponent, out=None) -> Tensor
index 52d7db5..3c1a86d 100644 (file)
@@ -20,6 +20,7 @@ __all__ = [
     'norm',
     'meshgrid',
     'potrf',
+    'potrs',
     'split',
     'stft',
     'tensordot',
@@ -732,3 +733,20 @@ def potrf(a, upper=True, out=None):
                   "release. Please use torch.cholesky instead and note that the :attr:`upper` argument in"
                   " torch.cholesky defaults to ``False``.", stacklevel=2)
     return torch.cholesky(a, upper=upper, out=out)
+
+
+def potrs(b, u, upper=True, out=None):
+    r"""Solves a linear system of equations with a positive semidefinite
+    matrix to be inverted given its Cholesky factor matrix :attr:`u`.
+
+    For more information, regarding :func:`torch.potrs`, please check :func:`torch.cholesky_solve`.
+
+    .. warning::
+        torch.potrs is deprecated in favour of torch.cholesky_solve and will be removed in the next
+        release. Please use torch.cholesky_solve instead and note that the :attr:`upper` argument in
+        torch.cholesky_solve defaults to ``False``.
+    """
+    warnings.warn("torch.potrs is deprecated in favour of torch.cholesky_solve and will be removed "
+                  "in the next release. Please use torch.cholesky instead and note that the "
+                  ":attr:`upper` argument in torch.cholesky_solve defaults to ``False``.", stacklevel=2)
+    return torch.cholesky_solve(b, u, upper=upper, out=out)
index 22bba1a..4d792bb 100644 (file)
@@ -258,6 +258,14 @@ class Tensor(torch._C._TensorBase):
                       ":attr:`upper` argument in torch.cholesky defaults to ``False``.", stacklevel=2)
         return super(Tensor, self).cholesky(upper=upper)
 
+    def potrs(self, u, upper=True):
+        r"""See :func:`torch.cholesky_solve`"""
+        warnings.warn("torch.potrs is deprecated in favour of torch.cholesky_solve and "
+                      "will be removed in the next release. Please use torch.cholesky_solve instead "
+                      "and note that the :attr:`upper` argument in torch.cholesky_solve defaults "
+                      "to ``False``.", stacklevel=2)
+        return super(Tensor, self).cholesky_solve(u, upper=upper)
+
     def stft(self, n_fft, hop_length=None, win_length=None, window=None,
              center=True, pad_mode='reflect', normalized=False, onesided=True):
         r"""See :func:`torch.stft`