Rename gesv to solve (#18060)
authorVishwak Srinivasan <cs15btech11043@iith.ac.in>
Mon, 18 Mar 2019 23:01:02 +0000 (16:01 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 18 Mar 2019 23:04:24 +0000 (16:04 -0700)
Summary:
Changelog:

- Renames `gesv` to `solve` to remain consistent with `cholesky_solve`.
- Rename all tests, fix callsites
- Create a tentative alias for `solve` under the name `gesv`, and add a deprecated warning to not promote usage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18060

Differential Revision: D14503117

Pulled By: zou3519

fbshipit-source-id: 99c16d94e5970a19d7584b5915f051c030d49ff5

21 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/LinearAlgebraUtils.h
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
test/common_methods_invocations.py
test/test_cuda.py
test/test_torch.py
tools/autograd/derivatives.yaml
tools/autograd/gen_python_functions.py
tools/autograd/templates/Functions.cpp
torch/_tensor_docs.py
torch/_torch_docs.py
torch/csrc/jit/passes/shape_analysis.cpp
torch/functional.py
torch/tensor.py

index 3266c46..458df59 100644 (file)
@@ -406,7 +406,6 @@ class CAFFE2_API Tensor {
   Tensor floor() const;
   Tensor & floor_();
   Tensor ger(const Tensor & vec2) const;
-  std::tuple<Tensor,Tensor> gesv(const Tensor & A) const;
   Tensor fft(int64_t signal_ndim, bool normalized=false) const;
   Tensor ifft(int64_t signal_ndim, bool normalized=false) const;
   Tensor rfft(int64_t signal_ndim, bool normalized=false, bool onesided=true) const;
@@ -696,6 +695,7 @@ class CAFFE2_API Tensor {
   std::tuple<Tensor,Tensor,Tensor> svd(bool some=true, bool compute_uv=true) const;
   Tensor cholesky(bool upper=false) const;
   Tensor cholesky_solve(const Tensor & input2, bool upper=false) const;
+  std::tuple<Tensor,Tensor> solve(const Tensor & A) const;
   Tensor potri(bool upper=true) const;
   std::tuple<Tensor,Tensor> pstrf(bool upper=true, Scalar tol=-1) const;
   std::tuple<Tensor,Tensor> qr() const;
index 9d6298b..ec96df8 100644 (file)
@@ -289,9 +289,6 @@ inline Tensor & Tensor::floor_() {
 inline Tensor Tensor::ger(const Tensor & vec2) const {
     return type().ger(*this, vec2);
 }
-inline std::tuple<Tensor,Tensor> Tensor::gesv(const Tensor & A) const {
-    return type().gesv(*this, A);
-}
 inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const {
     return type().fft(*this, signal_ndim, normalized);
 }
@@ -1159,6 +1156,9 @@ inline Tensor Tensor::cholesky(bool upper) const {
 inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const {
     return type().cholesky_solve(*this, input2, upper);
 }
+inline std::tuple<Tensor,Tensor> Tensor::solve(const Tensor & A) const {
+    return type().solve(*this, A);
+}
 inline Tensor Tensor::potri(bool upper) const {
     return type().potri(*this, upper);
 }
index e58d8e9..929b918 100644 (file)
@@ -284,7 +284,6 @@ struct CAFFE2_API Type {
   virtual Tensor floor(const Tensor & self) const = 0;
   virtual Tensor & floor_(Tensor & self) const = 0;
   virtual Tensor ger(const Tensor & self, const Tensor & vec2) const = 0;
-  virtual std::tuple<Tensor,Tensor> gesv(const Tensor & self, const Tensor & A) const = 0;
   virtual Tensor fft(const Tensor & self, int64_t signal_ndim, bool normalized) const = 0;
   virtual Tensor ifft(const Tensor & self, int64_t signal_ndim, bool normalized) const = 0;
   virtual Tensor rfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided) const = 0;
@@ -574,6 +573,7 @@ struct CAFFE2_API Type {
   virtual std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool compute_uv) const = 0;
   virtual Tensor cholesky(const Tensor & self, bool upper) const = 0;
   virtual Tensor cholesky_solve(const Tensor & self, const Tensor & input2, bool upper) const = 0;
+  virtual std::tuple<Tensor,Tensor> solve(const Tensor & self, const Tensor & A) const = 0;
   virtual Tensor potri(const Tensor & self, bool upper) const = 0;
   virtual std::tuple<Tensor,Tensor> pstrf(const Tensor & self, bool upper, Scalar tol) const = 0;
   virtual std::tuple<Tensor,Tensor> qr(const Tensor & self) const = 0;
index 29c90c2..ba6b64d 100644 (file)
@@ -40,8 +40,6 @@ _(aten, _cast_Long) \
 _(aten, _cast_Short) \
 _(aten, _cat) \
 _(aten, _ceil) \
-_(aten, _cholesky_helper) \
-_(aten, _cholesky_solve_helper) \
 _(aten, _convolution) \
 _(aten, _convolution_double_backward) \
 _(aten, _convolution_nogroup) \
@@ -80,10 +78,8 @@ _(aten, _fill) \
 _(aten, _floor) \
 _(aten, _fused_dropout) \
 _(aten, _ger) \
-_(aten, _gesv_helper) \
 _(aten, _indexCopy) \
 _(aten, _indices) \
-_(aten, _inverse_helper) \
 _(aten, _linspace) \
 _(aten, _local_scalar) \
 _(aten, _local_scalar_dense) \
@@ -343,7 +339,6 @@ _(aten, gels) \
 _(aten, geometric) \
 _(aten, geqrf) \
 _(aten, ger) \
-_(aten, gesv) \
 _(aten, get_device) \
 _(aten, glu) \
 _(aten, glu_backward) \
@@ -610,6 +605,7 @@ _(aten, softplus_forward) \
 _(aten, softshrink) \
 _(aten, softshrink_backward) \
 _(aten, softshrink_forward) \
+_(aten, solve) \
 _(aten, sort) \
 _(aten, sparse_coo_tensor) \
 _(aten, sparse_mask) \
index 31667be..ea8b1d4 100644 (file)
@@ -40,8 +40,8 @@ namespace native {
 // Define the per-batch functions to be used in the main implementation of the batched
 // linear algebra operations
 template<class scalar_t>
-void lapackGesv(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
-  AT_ERROR("gesv only takes float or double Tensors");
+void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("solve only takes float or double Tensors");
 }
 
 template<class scalar_t>
@@ -65,11 +65,11 @@ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) {
 }
 
 #ifdef USE_LAPACK
-template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
+template<> void lapackSolve<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
   dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
-template<> void lapackGesv<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
+template<> void lapackSolve<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
   sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
@@ -109,12 +109,12 @@ template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
 // in the main helper functions for the linear algebra operations
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ gesv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template<typename scalar_t>
-static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
+static void apply_solve(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 #ifndef USE_LAPACK
-  AT_ERROR("gesv: LAPACK library not found in compilation");
+  AT_ERROR("solve: LAPACK library not found in compilation");
 #else
   auto A_data = A.data<scalar_t>();
   auto b_data = b.data<scalar_t>();
@@ -125,7 +125,7 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 
   int info;
   if (b.dim() == 2) {
-    lapackGesv<scalar_t>(n, nrhs, A_data, n, ipiv.data<int>(), b_data, n, &info);
+    lapackSolve<scalar_t>(n, nrhs, A_data, n, ipiv.data<int>(), b_data, n, &info);
     infos[0] = info;
   } else {
     auto A_mat_stride = matrixStride(A);
@@ -135,7 +135,7 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
     for (int64_t i = 0; i < batch_size; i++) {
       scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
       scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
-      lapackGesv<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
+      lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
       infos[i] = info;
       if (info != 0) {
         return;
@@ -145,38 +145,35 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 #endif
 }
 
-std::tuple<Tensor, Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
+std::tuple<Tensor, Tensor> _solve_helper_cpu(const Tensor& self, const Tensor& A) {
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
   std::vector<int64_t> infos(batchCount(self), 0);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cpu", [&]{
-    apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "solve_cpu", [&]{
+    apply_solve<scalar_t>(self_working_copy, A_working_copy, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "gesv_cpu");
+    batchCheckErrors(infos, "solve_cpu");
   } else {
-    singleCheckErrors(infos[0], "gesv_cpu");
+    singleCheckErrors(infos[0], "solve_cpu");
   }
   return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
 }
 
 // Supports arbitrary batch dimensions for self and A
-std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
+std::tuple<Tensor,Tensor> solve(const Tensor& self, const Tensor& A) {
   AT_CHECK(self.dim() >= 2,
            "B should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   AT_CHECK(A.dim() >= 2,
            "A should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
   Tensor self_broadcasted, A_broadcasted;
   std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
-  return at::_gesv_helper(self_broadcasted, A_broadcasted);
+  return at::_solve_helper(self_broadcasted, A_broadcasted);
 }
 
-std::tuple<Tensor&,Tensor&> gesv_out(Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
-  AT_CHECK(self.dim() == 2 && A.dim() == 2,
-           "torch.gesv() with the `out` keyword does not support batching. "
-           "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
+std::tuple<Tensor&,Tensor&> solve_out(Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
   Tensor solution_tmp, lu_tmp;
-  std::tie(solution_tmp, lu_tmp) = at::_gesv_helper(self, A);
+  std::tie(solution_tmp, lu_tmp) = at::_solve_helper(self, A);
   solution.resize_as_(solution_tmp).copy_(solution_tmp);
   lu.resize_as_(lu_tmp).copy_(lu_tmp);
   return std::tuple<Tensor&, Tensor&>(solution, lu);
index 4758966..0d95096 100644 (file)
@@ -75,7 +75,7 @@ static inline double _get_epsilon(const ScalarType& sc_type) {
   }
 }
 
-// Validates input shapes for linear solve methods (gesv, cholesky_solve)
+// Validates input shapes for linear solve methods (solve, cholesky_solve)
 static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
   AT_CHECK(A.size(-1) == A.size(-2),
            "A must be batches of square matrices, "
index 195df2d..dcadc87 100644 (file)
@@ -20,18 +20,18 @@ namespace native {
 
 #ifdef USE_MAGMA
 template<class scalar_t>
-void magmaGesv(
+void magmaSolve(
     magma_int_t n, magma_int_t nrhs, scalar_t* dA, magma_int_t ldda,
     magma_int_t* ipiv, scalar_t* dB, magma_int_t lddb, magma_int_t* info) {
-  AT_ERROR("gesv only takes float or double Tensors");
+  AT_ERROR("solve only takes float or double Tensors");
 }
 
 template<class scalar_t>
-void magmaGesvBatched(
+void magmaSolveBatched(
     magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
     magma_int_t** dipiv_array, scalar_t** dB_array, magma_int_t lddb,
     magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
-  AT_ERROR("gesv only takes float or double Tensors");
+  AT_ERROR("solve only takes float or double Tensors");
 }
 
 template<class scalar_t>
@@ -86,7 +86,7 @@ void magmaCholeskyBatched(
 }
 
 template<>
-void magmaGesvBatched<double>(
+void magmaSolveBatched<double>(
     magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
     magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
     magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
@@ -94,7 +94,7 @@ void magmaGesvBatched<double>(
 }
 
 template<>
-void magmaGesvBatched<float>(
+void magmaSolveBatched<float>(
     magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
     magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
     magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
@@ -102,14 +102,14 @@ void magmaGesvBatched<float>(
 }
 
 template<>
-void magmaGesv<double>(
+void magmaSolve<double>(
     magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda,
     magma_int_t* ipiv, double* dB, magma_int_t lddb, magma_int_t* info) {
   magma_dgesv_gpu(n, nrhs, dA, ldda, ipiv, dB, lddb, info);
 }
 
 template<>
-void magmaGesv<float>(
+void magmaSolve<float>(
     magma_int_t n, magma_int_t nrhs, float* dA, magma_int_t ldda,
     magma_int_t* ipiv, float* dB, magma_int_t lddb, magma_int_t* info) {
   magma_sgesv_gpu(n, nrhs, dA, ldda, ipiv, dB, lddb, info);
@@ -222,12 +222,12 @@ void magmaCholeskyBatched<float>(
   auto storage_##name = pin_memory<type>(size, dummy_tensor); \
   name = static_cast<type*>(storage_##name.data());
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ gesv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
-static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
+static void apply_solve(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 #ifndef USE_MAGMA
-AT_ERROR("gesv: MAGMA library not found in "
+AT_ERROR("solve: MAGMA library not found in "
     "compilation. Please rebuild with MAGMA.");
 #else
   auto A_data = A.data<scalar_t>();
@@ -238,7 +238,7 @@ AT_ERROR("gesv: MAGMA library not found in "
   if (b.dim() == 2) {
     auto ipiv = at::empty({n}, at::kInt);
     magma_int_t info = 0;
-    magmaGesv<scalar_t>(n, nrhs, A_data, n, ipiv.data<magma_int_t>(),
+    magmaSolve<scalar_t>(n, nrhs, A_data, n, ipiv.data<magma_int_t>(),
                         b_data, n, &info);
     infos[0] = info;
   } else {
@@ -266,7 +266,7 @@ AT_ERROR("gesv: MAGMA library not found in "
     }
 
     MAGMAQueue magma_queue(b.get_device());
-    magmaGesvBatched<scalar_t>(
+    magmaSolveBatched<scalar_t>(
         n, nrhs, A_array, n, ipiv_array, b_array, n,
         info_array, batch_size, magma_queue);
 
@@ -277,17 +277,17 @@ AT_ERROR("gesv: MAGMA library not found in "
 #endif
 }
 
-std::tuple<Tensor, Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A) {
+std::tuple<Tensor, Tensor> _solve_helper_cuda(const Tensor& self, const Tensor& A) {
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
   std::vector<int64_t> infos(batchCount(self), 0);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cuda", [&]{
-    apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "solve_cuda", [&]{
+    apply_solve<scalar_t>(self_working_copy, A_working_copy, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "gesv_cuda");
+    batchCheckErrors(infos, "solve_cuda");
   } else {
-    singleCheckErrors(infos[0], "gesv_cuda");
+    singleCheckErrors(infos[0], "solve_cuda");
   }
   return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
 }
index 0338060..cde92d8 100644 (file)
 - func: ger(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: gesv(Tensor self, Tensor A) -> (Tensor, Tensor)
-  matches_jit_signature: True
-  variants: function, method
-
-- func: gesv(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!), Tensor(b!))
-  matches_jit_signature: True
-
-# gesv handles broadcasting of arbitrary batch dims while _gesv_helper does not.
-- func: _gesv_helper(Tensor self, Tensor A) -> (Tensor, Tensor)
-  matches_jit_signature: True
-  variants: function
-  dispatch:
-    CPU: _gesv_helper_cpu
-    CUDA: _gesv_helper_cuda
-
 - func: group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor
   matches_jit_signature: True
 
     CPU: _cholesky_solve_helper_cpu
     CUDA: _cholesky_solve_helper_cuda
 
+- func: solve(Tensor self, Tensor A) -> (Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function, method
+
+- func: solve(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!), Tensor(b!))
+  matches_jit_signature: True
+
+- func: _solve_helper(Tensor self, Tensor A) -> (Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: _solve_helper_cpu
+    CUDA: _solve_helper_cuda
+
 - func: potri(Tensor self, bool upper=True, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
index cb0aae0..a7a3128 100644 (file)
@@ -369,6 +369,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: sinh_
    .. automethod:: size
    .. automethod:: slogdet
+   .. automethod:: solve
    .. automethod:: sort
    .. automethod:: split
    .. automethod:: sparse_mask
index 79ac07f..2bbeec9 100644 (file)
@@ -321,6 +321,7 @@ BLAS and LAPACK Operations
 .. autofunction:: potrs
 .. autofunction:: pstrf
 .. autofunction:: qr
+.. autofunction:: solve
 .. autofunction:: svd
 .. autofunction:: symeig
 .. autofunction:: trtrs
index 8463afa..307df3f 100644 (file)
@@ -649,15 +649,15 @@ def method_tests():
          'tall_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
         ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
          'large', NO_ARGS, [skipIfNoLapack]),
-        ('gesv', (S, S), (random_fullrank_matrix_distinct_singular_value(
+        ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
             S, silent=True),), '', NO_ARGS, [skipIfNoLapack]),
-        ('gesv', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),
+        ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),
          'batched', NO_ARGS, [skipIfNoLapack]),
-        ('gesv', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),),
+        ('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),),
          'batched_dims', NO_ARGS, [skipIfNoLapack]),
-        ('gesv', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),),
+        ('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),),
          'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
-        ('gesv', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),),
+        ('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),),
          'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
         ('fill_', (S, S, S), (1,), 'number'),
         ('fill_', (), (1,), 'number_scalar'),
index 98e1fe2..b89de2a 100644 (file)
@@ -2141,16 +2141,16 @@ class TestCuda(TestCase):
         _TestTorchMixin._test_det_logdet_slogdet(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_gesv(self):
-        _TestTorchMixin._test_gesv(self, lambda t: t.cuda())
+    def test_solve(self):
+        _TestTorchMixin._test_solve(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_gesv_batched(self):
-        _TestTorchMixin._test_gesv_batched(self, lambda t: t.cuda())
+    def test_solve_batched(self):
+        _TestTorchMixin._test_solve_batched(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_gesv_batched_dims(self):
-        _TestTorchMixin._test_gesv_batched_dims(self, lambda t: t.cuda())
+    def test_solve_batched_dims(self):
+        _TestTorchMixin._test_solve_batched_dims(self, lambda t: t.cuda())
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_cholesky_solve(self):
index 63b4c07..07da497 100644 (file)
@@ -4579,7 +4579,7 @@ class _TestTorchMixin(object):
         self.assertEqual(torch.cuda.HalfTensor(10).is_signed(), True)
 
     @staticmethod
-    def _test_gesv(self, cast):
+    def _test_solve(self, cast):
         a = cast(torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                                (-6.05, -3.30, 5.36, -4.44, 1.08),
                                (-0.45, 2.58, -2.70, 0.27, 9.04),
@@ -4589,63 +4589,63 @@ class _TestTorchMixin(object):
                                (-1.56, 4.00, -8.67, 1.75, 2.86),
                                (9.81, -4.09, -4.57, -8.61, 8.99)))).t()
 
-        res1 = torch.gesv(b, a)[0]
+        res1 = torch.solve(b, a)[0]
         self.assertLessEqual(b.dist(torch.mm(a, res1)), 1e-12)
 
         ta = cast(torch.Tensor())
         tb = cast(torch.Tensor())
-        res2 = torch.gesv(b, a, out=(tb, ta))[0]
-        res3 = torch.gesv(b, a, out=(b, a))[0]
+        res2 = torch.solve(b, a, out=(tb, ta))[0]
+        res3 = torch.solve(b, a, out=(b, a))[0]
         self.assertEqual(res1, tb)
         self.assertEqual(res1, b)
         self.assertEqual(res1, res2)
         self.assertEqual(res1, res3)
 
         # test reuse
-        res1 = torch.gesv(b, a)[0]
+        res1 = torch.solve(b, a)[0]
         ta = cast(torch.Tensor())
         tb = cast(torch.Tensor())
-        torch.gesv(b, a, out=(tb, ta))[0]
+        torch.solve(b, a, out=(tb, ta))[0]
         self.assertEqual(res1, tb)
-        torch.gesv(b, a, out=(tb, ta))[0]
+        torch.solve(b, a, out=(tb, ta))[0]
         self.assertEqual(res1, tb)
 
     @skipIfNoLapack
-    def test_gesv(self):
-        self._test_gesv(self, lambda t: t)
+    def test_solve(self):
+        self._test_solve(self, lambda t: t)
 
     @staticmethod
-    def _test_gesv_batched(self, cast):
+    def _test_solve_batched(self, cast):
         from common_utils import random_fullrank_matrix_distinct_singular_value
-        # test against gesv: one batch
+        # test against solve: one batch
         A = cast(random_fullrank_matrix_distinct_singular_value(5, 1))
         b = cast(torch.randn(1, 5, 10))
-        x_exp, LU_exp = torch.gesv(b.squeeze(0), A.squeeze(0))
-        x, LU = torch.gesv(b, A)
+        x_exp, LU_exp = torch.solve(b.squeeze(0), A.squeeze(0))
+        x, LU = torch.solve(b, A)
         self.assertEqual(x, x_exp.unsqueeze(0))
         self.assertEqual(LU, LU_exp.unsqueeze(0))
 
-        # test against gesv in a loop: four batches
+        # test against solve in a loop: four batches
         A = cast(random_fullrank_matrix_distinct_singular_value(5, 4))
         b = cast(torch.randn(4, 5, 10))
 
         x_exp_list = []
         LU_exp_list = []
         for i in range(4):
-            x_exp, LU_exp = torch.gesv(b[i], A[i])
+            x_exp, LU_exp = torch.solve(b[i], A[i])
             x_exp_list.append(x_exp)
             LU_exp_list.append(LU_exp)
         x_exp = torch.stack(x_exp_list)
         LU_exp = torch.stack(LU_exp_list)
 
-        x, LU = torch.gesv(b, A)
+        x, LU = torch.solve(b, A)
         self.assertEqual(x, x_exp)
         self.assertEqual(LU, LU_exp)
 
         # basic correctness test
         A = cast(random_fullrank_matrix_distinct_singular_value(5, 3))
         b = cast(torch.randn(3, 5, 10))
-        x, LU = torch.gesv(b, A)
+        x, LU = torch.solve(b, A)
         self.assertEqual(torch.matmul(A, x), b)
 
         # Test non-contiguous inputs.
@@ -4655,16 +4655,16 @@ class _TestTorchMixin(object):
         from numpy.linalg import solve
         A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2)
         b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0)
-        x, _ = torch.gesv(b, A)
+        x, _ = torch.solve(b, A)
         x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
         self.assertEqual(x.data, cast(x_exp))
 
     @skipIfNoLapack
-    def test_gesv_batched(self):
-        self._test_gesv_batched(self, lambda t: t)
+    def test_solve_batched(self):
+        self._test_solve_batched(self, lambda t: t)
 
     @staticmethod
-    def _test_gesv_batched_dims(self, cast):
+    def _test_solve_batched_dims(self, cast):
         if not TEST_NUMPY:
             return
 
@@ -4673,7 +4673,7 @@ class _TestTorchMixin(object):
         # test against numpy.linalg.solve
         A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
         b = cast(torch.randn(2, 1, 3, 4, 6))
-        x, _ = torch.gesv(b, A)
+        x, _ = torch.solve(b, A)
         x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
         self.assertEqual(x.data, cast(x_exp))
 
@@ -4682,34 +4682,34 @@ class _TestTorchMixin(object):
         b = cast(torch.randn(2, 1, 3, 6, 4)).transpose(-2, -1)
         assert not A.is_contiguous()
         assert not b.is_contiguous()
-        x, _ = torch.gesv(b, A)
+        x, _ = torch.solve(b, A)
         x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
         self.assertEqual(x.data, cast(x_exp))
 
         # broadcasting b
         A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3))
         b = cast(torch.randn(4, 6))
-        x, _ = torch.gesv(b, A)
+        x, _ = torch.solve(b, A)
         x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
         self.assertEqual(x.data, cast(x_exp))
 
         # broadcasting A
         A = cast(random_fullrank_matrix_distinct_singular_value(4))
         b = cast(torch.randn(2, 1, 3, 4, 2))
-        x, _ = torch.gesv(b, A)
+        x, _ = torch.solve(b, A)
         x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
         self.assertEqual(x.data, cast(x_exp))
 
         # broadcasting both A & b
         A = cast(random_fullrank_matrix_distinct_singular_value(4, 1, 3, 1))
         b = cast(torch.randn(2, 1, 3, 4, 5))
-        x, _ = torch.gesv(b, A)
+        x, _ = torch.solve(b, A)
         x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy()))
         self.assertEqual(x.data, cast(x_exp))
 
     @skipIfNoLapack
-    def test_gesv_batched_dims(self):
-        self._test_gesv_batched_dims(self, lambda t: t)
+    def test_solve_batched_dims(self):
+        self._test_solve_batched_dims(self, lambda t: t)
 
     @skipIfNoLapack
     def test_qr(self):
index 8d8d6d8..8429c72 100644 (file)
   self: grad.mv(vec2)
   vec2: grad.t().mv(self)
 
-- name: gesv(Tensor self, Tensor A)
-  self: gesv_backward_self(grad, self, A)
-  A: gesv_backward_A(grad, self, A, result0)
-
 - name: indices(Tensor self)
   output_differentiability: [False]
 
   self: slogdet_backward(grad, self, result0, result1)
   output_differentiability: [false, true]
 
+- name: solve(Tensor self, Tensor A)
+  self: solve_backward_self(grad, self, A)
+  A: solve_backward_A(grad, self, A, result0)
+
 - name: sort(Tensor self, int64_t dim, bool descending)
   self: index_select_backward(grad, dim, result1, self.sizes(), true)
 
index 0066183..d95e6a0 100644 (file)
@@ -26,8 +26,8 @@ SKIP_PYTHON_BINDINGS = [
     '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*', '_thnn_.*',
-    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*',
-    '_potrs.*', '_cholesky.*', '_btrifact.*',
+    'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
+    '_cholesky.*', '_btrifact.*',
     'slice', 'randint(_out)?',
     'item', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
index 3db5552..8ed6eeb 100644 (file)
@@ -402,12 +402,12 @@ Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim, Sc
   return cumprod_backward(grad.to(input.scalar_type()), input, dim);
 }
 
-Tensor gesv_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
-  return std::get<0>(at::gesv(grad, A.transpose(-2, -1)));
+Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
+  return std::get<0>(at::solve(grad, A.transpose(-2, -1)));
 }
 
-Tensor gesv_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
-  Tensor grad_self = gesv_backward_self(grad, self, A);
+Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
+  Tensor grad_self = solve_backward_self(grad, self, A);
   if (self.ndimension() == 2 && A.ndimension() == 2) {
     return -at::mm(grad_self, solution.transpose(-2, -1));
   }
@@ -701,8 +701,8 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
 
   auto P = phi(at::matmul(L, Lbar));
   Tensor S;
-  std::tie(S, std::ignore) = at::gesv(P + P.transpose(-1, -2), L);
-  std::tie(S, std::ignore) = at::gesv(S.transpose(-1, -2), L);
+  std::tie(S, std::ignore) = at::solve(P + P.transpose(-1, -2), L);
+  std::tie(S, std::ignore) = at::solve(S.transpose(-1, -2), L);
   S = phi(S);
   if (upper) {
     S = S.transpose(-1, -2);
index c77c8f6..f769a52 100644 (file)
@@ -1019,11 +1019,11 @@ ger(vec2) -> Tensor
 See :func:`torch.ger`
 """)
 
-add_docstr_all('gesv',
+add_docstr_all('solve',
                r"""
-gesv(A) -> Tensor, Tensor
+solve(A) -> Tensor, Tensor
 
-See :func:`torch.gesv`
+See :func:`torch.solve`
 """)
 
 add_docstr_all('indices',
index 30311d1..b6d8f61 100644 (file)
@@ -2018,9 +2018,9 @@ Example::
             [  4.,   8.,  12.]])
 """)
 
-add_docstr(torch.gesv,
+add_docstr(torch.solve,
            r"""
-torch.gesv(B, A, out=None) -> (Tensor, Tensor)
+torch.solve(B, A, out=None) -> (Tensor, Tensor)
 
 This function returns the solution to the system of linear
 equations represented by :math:`AX = B` and the LU factorization of
@@ -2028,17 +2028,12 @@ A, in order as a tuple `X, LU`.
 
 `LU` contains `L` and `U` factors for LU factorization of `A`.
 
-`torch.gesv(B, A)` can take in 2D inputs `B, A` or inputs that are
+`torch.solve(B, A)` can take in 2D inputs `B, A` or inputs that are
 batches of 2D matrices. If the inputs are batches, then returns
 batched outputs `X, LU`.
 
 .. note::
 
-    The :attr:`out` keyword only supports 2D matrix inputs, that is,
-    `B, A` must be 2D matrices.
-
-.. note::
-
     Irrespective of the original strides, the returned matrices
     `X` and `LU` will be transposed, i.e. with strides like
     `B.contiguous().transpose(-1, -2).strides()` and
@@ -2061,7 +2056,7 @@ Example::
     >>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
                           [-1.56,  4.00, -8.67,  1.75,  2.86],
                           [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
-    >>> X, LU = torch.gesv(B, A)
+    >>> X, LU = torch.solve(B, A)
     >>> torch.dist(B, torch.mm(A, X))
     tensor(1.00000e-06 *
            7.0977)
@@ -2069,7 +2064,7 @@ Example::
     >>> # Batched solver example
     >>> A = torch.randn(2, 3, 1, 4, 4)
     >>> B = torch.randn(2, 3, 1, 4, 6)
-    >>> X, LU = torch.gesv(B, A)
+    >>> X, LU = torch.solve(B, A)
     >>> torch.dist(B, A.matmul(X))
     tensor(1.00000e-06 *
        3.6386)
index 893490e..fb4931f 100644 (file)
@@ -236,7 +236,7 @@ class ShapePropagator {
   }
 
   OperatorSet cannot_propagate_shape_by_running_it = {
-      "aten::gesv(Tensor self, Tensor A) -> (Tensor, Tensor)",
+      "aten::solve(Tensor self, Tensor A) -> (Tensor, Tensor)",
       "aten::inverse(Tensor self) -> Tensor",
   };
 
index 4540599..796d639 100644 (file)
@@ -23,8 +23,9 @@ __all__ = [
     'norm',
     'meshgrid',
     'potrf',
-    'potrs',
     'pstrf',
+    'potrs',
+    'gesv',
     'split',
     'stft',
     'tensordot',
@@ -739,7 +740,7 @@ def potrf(a, upper=True, out=None):
     r"""Computes the Cholesky decomposition of a symmetric positive-definite
     matrix :math:`A`.
 
-    For more information, regarding :func:`torch.potrf`, please check :func:`torch.cholesky`.
+    For more information regarding :func:`torch.potrf`, please check :func:`torch.cholesky`.
 
     .. warning::
         :func:`torch.potrf` is deprecated in favour of :func:`torch.cholesky` and will be removed
@@ -801,7 +802,7 @@ def potrs(b, u, upper=True, out=None):
     r"""Solves a linear system of equations with a positive semidefinite
     matrix to be inverted given its Cholesky factor matrix :attr:`u`.
 
-    For more information, regarding :func:`torch.potrs`, please check :func:`torch.cholesky_solve`.
+    For more information regarding :func:`torch.potrs`, please check :func:`torch.cholesky_solve`.
 
     .. warning::
         :func:`torch.potrs` is deprecated in favour of :func:`torch.cholesky_solve` and will be
@@ -812,3 +813,18 @@ def potrs(b, u, upper=True, out=None):
                   "in the next release. Please use torch.cholesky instead and note that the "
                   ":attr:`upper` argument in torch.cholesky_solve defaults to ``False``.", stacklevel=2)
     return torch.cholesky_solve(b, u, upper=upper, out=out)
+
+
+def gesv(b, A, out=None):
+    r"""This function returns the solution to the system of linear equations represented
+    by :math:`AX = B` and the LU factorization of A, in order as a tuple `X, LU`.
+
+    For more information regarding :func:`torch.gesv`, please check :func:`torch.solve`.
+
+    .. warning::
+        :func:`torch.gesv` is deprecated in favour of :func:`torch.solve` and will be removed in the
+        next release. Please use :func:`torch.solve` instead.
+    """
+    warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
+                  "next release. Please use torch.solve instead.", stacklevel=2)
+    return torch.solve(b, A, out=out)
index 2e69483..d78f2cf 100644 (file)
@@ -281,6 +281,12 @@ class Tensor(torch._C._TensorBase):
                       "to ``False``.", stacklevel=2)
         return super(Tensor, self).cholesky_solve(u, upper=upper)
 
+    def gesv(self, A):
+        r"""See :func:`torch.solve`"""
+        warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
+                      "next release. Please use torch.solve instead.", stacklevel=2)
+        return super(Tensor, self).solve(A)
+
     def stft(self, n_fft, hop_length=None, win_length=None, window=None,
              center=True, pad_mode='reflect', normalized=False, onesided=True):
         r"""See :func:`torch.stft`