From 421b508d55caa39a983be29e2e6ce79b91f9405e Mon Sep 17 00:00:00 2001
From: Vishwak Srinivasan
Date: Mon, 18 Mar 2019 16:01:02 -0700
Subject: [PATCH] Rename gesv to solve (#18060)

Summary:
Changelog:
- Renames `gesv` to `solve` to remain consistent with `cholesky_solve`.
- Renames all tests and fixes call sites.
- Creates a tentative alias for `solve` under the name `gesv`, and adds a deprecation warning to discourage its use.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18060

Differential Revision: D14503117

Pulled By: zou3519

fbshipit-source-id: 99c16d94e5970a19d7584b5915f051c030d49ff5
---
 aten/src/ATen/core/Tensor.h                     |  2 +-
 aten/src/ATen/core/TensorMethods.h              |  6 +--
 aten/src/ATen/core/Type.h                       |  2 +-
 aten/src/ATen/core/aten_interned_strings.h      |  6 +--
 aten/src/ATen/native/BatchLinearAlgebra.cpp     | 39 ++++++++---------
 aten/src/ATen/native/LinearAlgebraUtils.h       |  2 +-
 aten/src/ATen/native/cuda/BatchLinearAlgebra.cu | 36 ++++++++--------
 aten/src/ATen/native/native_functions.yaml      | 29 +++++++------
 docs/source/tensors.rst                         |  1 +
 docs/source/torch.rst                           |  1 +
 test/common_methods_invocations.py              | 10 ++---
 test/test_cuda.py                               | 12 +++---
 test/test_torch.py                              | 56 ++++++++++++-------------
 tools/autograd/derivatives.yaml                 |  8 ++--
 tools/autograd/gen_python_functions.py          |  4 +-
 tools/autograd/templates/Functions.cpp          | 12 +++---
 torch/_tensor_docs.py                           |  6 +--
 torch/_torch_docs.py                            | 15 +++----
 torch/csrc/jit/passes/shape_analysis.cpp        |  2 +-
 torch/functional.py                             | 22 ++++++++--
 torch/tensor.py                                 |  6 +++
 21 files changed, 144 insertions(+), 133 deletions(-)

diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index 3266c46..458df59 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -406,7 +406,6 @@ class CAFFE2_API Tensor {
   Tensor floor() const;
   Tensor & floor_();
   Tensor ger(const Tensor & vec2) const;
-  std::tuple<Tensor,Tensor> gesv(const Tensor & A) const;
   Tensor fft(int64_t signal_ndim, bool normalized=false) const;
   Tensor ifft(int64_t signal_ndim, bool normalized=false) const;
   Tensor rfft(int64_t signal_ndim, bool normalized=false, bool onesided=true) const;
@@ -696,6 +695,7 @@ class CAFFE2_API Tensor {
   std::tuple<Tensor,Tensor,Tensor> svd(bool some=true, bool compute_uv=true) const;
   Tensor cholesky(bool upper=false) const;
   Tensor cholesky_solve(const Tensor & input2, bool upper=false) const;
+  std::tuple<Tensor,Tensor> solve(const Tensor & A) const;
   Tensor potri(bool upper=true) const;
   std::tuple<Tensor,Tensor> pstrf(bool upper=true, Scalar tol=-1) const;
   std::tuple<Tensor,Tensor> qr() const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 9d6298b..ec96df8 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -289,9 +289,6 @@ inline Tensor & Tensor::floor_() {
 inline Tensor Tensor::ger(const Tensor & vec2) const {
     return type().ger(*this, vec2);
 }
-inline std::tuple<Tensor,Tensor> Tensor::gesv(const Tensor & A) const {
-    return type().gesv(*this, A);
-}
 inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const {
     return type().fft(*this, signal_ndim, normalized);
 }
@@ -1159,6 +1156,9 @@ inline Tensor Tensor::cholesky(bool upper) const {
 inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const {
     return type().cholesky_solve(*this, input2, upper);
 }
+inline std::tuple<Tensor,Tensor> Tensor::solve(const Tensor & A) const {
+    return type().solve(*this, A);
+}
 inline Tensor Tensor::potri(bool upper) const {
     return type().potri(*this, upper);
 }
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index e58d8e9..929b918 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -284,7 +284,6 @@ struct CAFFE2_API Type {
   virtual Tensor floor(const Tensor & self) const = 0;
   virtual Tensor & floor_(Tensor & self) const = 0;
   virtual Tensor ger(const Tensor & self, const Tensor & vec2) const = 0;
-  virtual std::tuple<Tensor,Tensor> gesv(const Tensor & self, const Tensor & A) const = 0;
   virtual Tensor fft(const Tensor & self, int64_t signal_ndim, bool normalized) const = 0;
   virtual Tensor ifft(const Tensor & self, int64_t signal_ndim, bool normalized) const = 0;
   virtual Tensor rfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided) const = 0;
@@ -574,6 +573,7 @@ struct CAFFE2_API Type {
   virtual std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool compute_uv) const = 0;
   virtual Tensor cholesky(const Tensor & self, bool upper) const = 0;
   virtual Tensor cholesky_solve(const Tensor & self, const Tensor & input2, bool upper) const = 0;
+  virtual std::tuple<Tensor,Tensor> solve(const Tensor & self, const Tensor & A) const = 0;
   virtual Tensor potri(const Tensor & self, bool upper) const = 0;
   virtual std::tuple<Tensor,Tensor> pstrf(const Tensor & self, bool upper, Scalar tol) const = 0;
   virtual std::tuple<Tensor,Tensor> qr(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 29c90c2..ba6b64d 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -40,8 +40,6 @@ _(aten, _cast_Long) \
 _(aten, _cast_Short) \
 _(aten, _cat) \
 _(aten, _ceil) \
-_(aten, _cholesky_helper) \
-_(aten, _cholesky_solve_helper) \
 _(aten, _convolution) \
 _(aten, _convolution_double_backward) \
 _(aten, _convolution_nogroup) \
@@ -80,10 +78,8 @@ _(aten, _fill) \
 _(aten, _floor) \
 _(aten, _fused_dropout) \
 _(aten, _ger) \
-_(aten, _gesv_helper) \
 _(aten, _indexCopy) \
 _(aten, _indices) \
-_(aten, _inverse_helper) \
 _(aten, _linspace) \
 _(aten, _local_scalar) \
 _(aten, _local_scalar_dense) \
@@ -343,7 +339,6 @@ _(aten, gels) \
 _(aten, geometric) \
 _(aten, geqrf) \
 _(aten, ger) \
-_(aten, gesv) \
 _(aten, get_device) \
 _(aten, glu) \
 _(aten, glu_backward) \
@@ -610,6 +605,7 @@ _(aten, softplus_forward) \
 _(aten, softshrink) \
 _(aten, softshrink_backward) \
 _(aten, softshrink_forward) \
+_(aten, solve) \
 _(aten, sort) \
 _(aten, sparse_coo_tensor) \
 _(aten, sparse_mask) \
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 31667be..ea8b1d4 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -40,8 +40,8 @@ namespace native {
 // Define the per-batch functions to be used in the main implementation of the batched
 // linear algebra operations
 template<class scalar_t>
-void lapackGesv(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
-  AT_ERROR("gesv only takes float or double Tensors");
+void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("solve only takes float or double Tensors");
 }
 
 template<class scalar_t>
@@ -65,11 +65,11 @@ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) {
 }
 
 #ifdef USE_LAPACK
-template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
+template<> void lapackSolve<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
   dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
-template<> void lapackGesv<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
+template<> void lapackSolve<float>(int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
   sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
@@ -109,12 +109,12 @@ template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *info) {
 
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
 // in the main helper functions for the linear algebra operations
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ gesv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
-static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
+static void apply_solve(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 #ifndef USE_LAPACK
-  AT_ERROR("gesv: LAPACK library not found in compilation");
+  AT_ERROR("solve: LAPACK library not found in compilation");
 #else
   auto A_data = A.data<scalar_t>();
   auto b_data = b.data<scalar_t>();
@@ -125,7 +125,7 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 
   int info;
   if (b.dim() == 2) {
-    lapackGesv<scalar_t>(n, nrhs, A_data, n, ipiv.data<int>(), b_data, n, &info);
+    lapackSolve<scalar_t>(n, nrhs, A_data, n, ipiv.data<int>(), b_data, n, &info);
     infos[0] = info;
   } else {
     auto A_mat_stride = matrixStride(A);
@@ -135,7 +135,7 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
     for (int64_t i = 0; i < batch_size; i++) {
       scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
       scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
-      lapackGesv<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
+      lapackSolve<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
       infos[i] = info;
       if (info != 0) {
         return;
@@ -145,38 +145,35 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 #endif
 }
 
-std::tuple<Tensor,Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
+std::tuple<Tensor,Tensor> _solve_helper_cpu(const Tensor& self, const Tensor& A) {
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
   std::vector<int64_t> infos(batchCount(self), 0);
-  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cpu", [&]{
-    apply_gesv<scalar_t>(self_working_copy, A_working_copy, infos);
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "solve_cpu", [&]{
+    apply_solve<scalar_t>(self_working_copy, A_working_copy, infos);
   });
   if (self.dim() > 2) {
-    batchCheckErrors(infos, "gesv_cpu");
+    batchCheckErrors(infos, "solve_cpu");
   } else {
-    singleCheckErrors(infos[0], "gesv_cpu");
+    singleCheckErrors(infos[0], "solve_cpu");
   }
   return std::tuple<Tensor,Tensor>(self_working_copy, A_working_copy);
 }
 
 // Supports arbitrary batch dimensions for self and A
-std::tuple<Tensor,Tensor> gesv(const Tensor& self, const Tensor& A) {
+std::tuple<Tensor,Tensor> solve(const Tensor& self, const Tensor& A) {
   AT_CHECK(self.dim() >= 2, "B should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   AT_CHECK(A.dim() >= 2, "A should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
   Tensor self_broadcasted, A_broadcasted;
   std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
-  return at::_gesv_helper(self_broadcasted, A_broadcasted);
+  return at::_solve_helper(self_broadcasted, A_broadcasted);
 }
 
-std::tuple<Tensor&,Tensor&> gesv_out(Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) {
-  AT_CHECK(self.dim() == 2 && A.dim() == 2,
-           "torch.gesv() with the `out` keyword does not support batching. 
" - "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2."); +std::tuple solve_out(Tensor& solution, Tensor& lu, const Tensor& self, const Tensor& A) { Tensor solution_tmp, lu_tmp; - std::tie(solution_tmp, lu_tmp) = at::_gesv_helper(self, A); + std::tie(solution_tmp, lu_tmp) = at::_solve_helper(self, A); solution.resize_as_(solution_tmp).copy_(solution_tmp); lu.resize_as_(lu_tmp).copy_(lu_tmp); return std::tuple(solution, lu); diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 4758966..0d95096 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -75,7 +75,7 @@ static inline double _get_epsilon(const ScalarType& sc_type) { } } -// Validates input shapes for linear solve methods (gesv, cholesky_solve) +// Validates input shapes for linear solve methods (solve, cholesky_solve) static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) { AT_CHECK(A.size(-1) == A.size(-2), "A must be batches of square matrices, " diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 195df2d..dcadc87 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -20,18 +20,18 @@ namespace native { #ifdef USE_MAGMA template -void magmaGesv( +void magmaSolve( magma_int_t n, magma_int_t nrhs, scalar_t* dA, magma_int_t ldda, magma_int_t* ipiv, scalar_t* dB, magma_int_t lddb, magma_int_t* info) { - AT_ERROR("gesv only takes float or double Tensors"); + AT_ERROR("solve only takes float or double Tensors"); } template -void magmaGesvBatched( +void magmaSolveBatched( magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda, magma_int_t** dipiv_array, scalar_t** dB_array, magma_int_t lddb, magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { - AT_ERROR("gesv only takes float or double Tensors"); + AT_ERROR("solve only takes float or double Tensors"); } template @@ -86,7 +86,7 @@ void magmaCholeskyBatched( } template<> -void magmaGesvBatched( +void magmaSolveBatched( magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda, magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb, magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { @@ -94,7 +94,7 @@ void magmaGesvBatched( } template<> -void magmaGesvBatched( +void magmaSolveBatched( magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda, magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb, magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { @@ -102,14 +102,14 @@ void magmaGesvBatched( } template<> -void magmaGesv( +void magmaSolve( magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda, magma_int_t* ipiv, double* dB, magma_int_t lddb, magma_int_t* info) { magma_dgesv_gpu(n, nrhs, dA, ldda, ipiv, dB, lddb, info); } template<> -void magmaGesv( +void magmaSolve( magma_int_t n, magma_int_t nrhs, float* dA, magma_int_t ldda, magma_int_t* ipiv, float* dB, magma_int_t lddb, magma_int_t* info) { magma_sgesv_gpu(n, nrhs, dA, ldda, ipiv, dB, lddb, info); @@ -222,12 +222,12 @@ void magmaCholeskyBatched( auto storage_##name = pin_memory(size, dummy_tensor); \ name = static_cast(storage_##name.data()); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ gesv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template 
-static void apply_gesv(Tensor& b, Tensor& A, std::vector& infos) { +static void apply_solve(Tensor& b, Tensor& A, std::vector& infos) { #ifndef USE_MAGMA -AT_ERROR("gesv: MAGMA library not found in " +AT_ERROR("solve: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else auto A_data = A.data(); @@ -238,7 +238,7 @@ AT_ERROR("gesv: MAGMA library not found in " if (b.dim() == 2) { auto ipiv = at::empty({n}, at::kInt); magma_int_t info = 0; - magmaGesv(n, nrhs, A_data, n, ipiv.data(), + magmaSolve(n, nrhs, A_data, n, ipiv.data(), b_data, n, &info); infos[0] = info; } else { @@ -266,7 +266,7 @@ AT_ERROR("gesv: MAGMA library not found in " } MAGMAQueue magma_queue(b.get_device()); - magmaGesvBatched( + magmaSolveBatched( n, nrhs, A_array, n, ipiv_array, b_array, n, info_array, batch_size, magma_queue); @@ -277,17 +277,17 @@ AT_ERROR("gesv: MAGMA library not found in " #endif } -std::tuple _gesv_helper_cuda(const Tensor& self, const Tensor& A) { +std::tuple _solve_helper_cuda(const Tensor& self, const Tensor& A) { auto self_working_copy = cloneBatchedColumnMajor(self); auto A_working_copy = cloneBatchedColumnMajor(A); std::vector infos(batchCount(self), 0); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "gesv_cuda", [&]{ - apply_gesv(self_working_copy, A_working_copy, infos); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "solve_cuda", [&]{ + apply_solve(self_working_copy, A_working_copy, infos); }); if (self.dim() > 2) { - batchCheckErrors(infos, "gesv_cuda"); + batchCheckErrors(infos, "solve_cuda"); } else { - singleCheckErrors(infos[0], "gesv_cuda"); + singleCheckErrors(infos[0], "solve_cuda"); } return std::tuple(self_working_copy, A_working_copy); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0338060..cde92d8 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1088,21 +1088,6 @@ - func: ger(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True -- func: gesv(Tensor self, Tensor A) -> (Tensor, Tensor) - matches_jit_signature: True - variants: function, method - -- func: gesv(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!), Tensor(b!)) - matches_jit_signature: True - -# gesv handles broadcasting of arbitrary batch dims while _gesv_helper does not. -- func: _gesv_helper(Tensor self, Tensor A) -> (Tensor, Tensor) - matches_jit_signature: True - variants: function - dispatch: - CPU: _gesv_helper_cpu - CUDA: _gesv_helper_cuda - - func: group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor matches_jit_signature: True @@ -3770,6 +3755,20 @@ CPU: _cholesky_solve_helper_cpu CUDA: _cholesky_solve_helper_cuda +- func: solve(Tensor self, Tensor A) -> (Tensor, Tensor) + matches_jit_signature: True + variants: function, method + +- func: solve(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!), Tensor(b!)) + matches_jit_signature: True + +- func: _solve_helper(Tensor self, Tensor A) -> (Tensor, Tensor) + matches_jit_signature: True + variants: function + dispatch: + CPU: _solve_helper_cpu + CUDA: _solve_helper_cuda + - func: potri(Tensor self, bool upper=True, *, Tensor(a!) out) -> Tensor(a!) 
matches_jit_signature: True diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index cb0aae0..a7a3128 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -369,6 +369,7 @@ view of a storage and defines numeric operations on it. .. automethod:: sinh_ .. automethod:: size .. automethod:: slogdet + .. automethod:: solve .. automethod:: sort .. automethod:: split .. automethod:: sparse_mask diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 79ac07f..2bbeec9 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -321,6 +321,7 @@ BLAS and LAPACK Operations .. autofunction:: potrs .. autofunction:: pstrf .. autofunction:: qr +.. autofunction:: solve .. autofunction:: svd .. autofunction:: symeig .. autofunction:: trtrs diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index 8463afa..307df3f 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -649,15 +649,15 @@ def method_tests(): 'tall_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])), ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS, 'large', NO_ARGS, [skipIfNoLapack]), - ('gesv', (S, S), (random_fullrank_matrix_distinct_singular_value( + ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value( S, silent=True),), '', NO_ARGS, [skipIfNoLapack]), - ('gesv', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),), + ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),), 'batched', NO_ARGS, [skipIfNoLapack]), - ('gesv', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),), + ('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),), 'batched_dims', NO_ARGS, [skipIfNoLapack]), - ('gesv', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),), + ('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),), 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]), - ('gesv', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),), + ('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),), 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]), ('fill_', (S, S, S), (1,), 'number'), ('fill_', (), (1,), 'number_scalar'), diff --git a/test/test_cuda.py b/test/test_cuda.py index 98e1fe2..b89de2a 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2141,16 +2141,16 @@ class TestCuda(TestCase): _TestTorchMixin._test_det_logdet_slogdet(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_gesv(self): - _TestTorchMixin._test_gesv(self, lambda t: t.cuda()) + def test_solve(self): + _TestTorchMixin._test_solve(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_gesv_batched(self): - _TestTorchMixin._test_gesv_batched(self, lambda t: t.cuda()) + def test_solve_batched(self): + _TestTorchMixin._test_solve_batched(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_gesv_batched_dims(self): - _TestTorchMixin._test_gesv_batched_dims(self, lambda t: t.cuda()) + def test_solve_batched_dims(self): + _TestTorchMixin._test_solve_batched_dims(self, lambda t: t.cuda()) @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") def test_cholesky_solve(self): diff --git a/test/test_torch.py 
b/test/test_torch.py index 63b4c07..07da497 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4579,7 +4579,7 @@ class _TestTorchMixin(object): self.assertEqual(torch.cuda.HalfTensor(10).is_signed(), True) @staticmethod - def _test_gesv(self, cast): + def _test_solve(self, cast): a = cast(torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23), (-6.05, -3.30, 5.36, -4.44, 1.08), (-0.45, 2.58, -2.70, 0.27, 9.04), @@ -4589,63 +4589,63 @@ class _TestTorchMixin(object): (-1.56, 4.00, -8.67, 1.75, 2.86), (9.81, -4.09, -4.57, -8.61, 8.99)))).t() - res1 = torch.gesv(b, a)[0] + res1 = torch.solve(b, a)[0] self.assertLessEqual(b.dist(torch.mm(a, res1)), 1e-12) ta = cast(torch.Tensor()) tb = cast(torch.Tensor()) - res2 = torch.gesv(b, a, out=(tb, ta))[0] - res3 = torch.gesv(b, a, out=(b, a))[0] + res2 = torch.solve(b, a, out=(tb, ta))[0] + res3 = torch.solve(b, a, out=(b, a))[0] self.assertEqual(res1, tb) self.assertEqual(res1, b) self.assertEqual(res1, res2) self.assertEqual(res1, res3) # test reuse - res1 = torch.gesv(b, a)[0] + res1 = torch.solve(b, a)[0] ta = cast(torch.Tensor()) tb = cast(torch.Tensor()) - torch.gesv(b, a, out=(tb, ta))[0] + torch.solve(b, a, out=(tb, ta))[0] self.assertEqual(res1, tb) - torch.gesv(b, a, out=(tb, ta))[0] + torch.solve(b, a, out=(tb, ta))[0] self.assertEqual(res1, tb) @skipIfNoLapack - def test_gesv(self): - self._test_gesv(self, lambda t: t) + def test_solve(self): + self._test_solve(self, lambda t: t) @staticmethod - def _test_gesv_batched(self, cast): + def _test_solve_batched(self, cast): from common_utils import random_fullrank_matrix_distinct_singular_value - # test against gesv: one batch + # test against solve: one batch A = cast(random_fullrank_matrix_distinct_singular_value(5, 1)) b = cast(torch.randn(1, 5, 10)) - x_exp, LU_exp = torch.gesv(b.squeeze(0), A.squeeze(0)) - x, LU = torch.gesv(b, A) + x_exp, LU_exp = torch.solve(b.squeeze(0), A.squeeze(0)) + x, LU = torch.solve(b, A) self.assertEqual(x, x_exp.unsqueeze(0)) self.assertEqual(LU, LU_exp.unsqueeze(0)) - # test against gesv in a loop: four batches + # test against solve in a loop: four batches A = cast(random_fullrank_matrix_distinct_singular_value(5, 4)) b = cast(torch.randn(4, 5, 10)) x_exp_list = [] LU_exp_list = [] for i in range(4): - x_exp, LU_exp = torch.gesv(b[i], A[i]) + x_exp, LU_exp = torch.solve(b[i], A[i]) x_exp_list.append(x_exp) LU_exp_list.append(LU_exp) x_exp = torch.stack(x_exp_list) LU_exp = torch.stack(LU_exp_list) - x, LU = torch.gesv(b, A) + x, LU = torch.solve(b, A) self.assertEqual(x, x_exp) self.assertEqual(LU, LU_exp) # basic correctness test A = cast(random_fullrank_matrix_distinct_singular_value(5, 3)) b = cast(torch.randn(3, 5, 10)) - x, LU = torch.gesv(b, A) + x, LU = torch.solve(b, A) self.assertEqual(torch.matmul(A, x), b) # Test non-contiguous inputs. 
@@ -4655,16 +4655,16 @@ class _TestTorchMixin(object): from numpy.linalg import solve A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2) b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0) - x, _ = torch.gesv(b, A) + x, _ = torch.solve(b, A) x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) self.assertEqual(x.data, cast(x_exp)) @skipIfNoLapack - def test_gesv_batched(self): - self._test_gesv_batched(self, lambda t: t) + def test_solve_batched(self): + self._test_solve_batched(self, lambda t: t) @staticmethod - def _test_gesv_batched_dims(self, cast): + def _test_solve_batched_dims(self, cast): if not TEST_NUMPY: return @@ -4673,7 +4673,7 @@ class _TestTorchMixin(object): # test against numpy.linalg.solve A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)) b = cast(torch.randn(2, 1, 3, 4, 6)) - x, _ = torch.gesv(b, A) + x, _ = torch.solve(b, A) x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) self.assertEqual(x.data, cast(x_exp)) @@ -4682,34 +4682,34 @@ class _TestTorchMixin(object): b = cast(torch.randn(2, 1, 3, 6, 4)).transpose(-2, -1) assert not A.is_contiguous() assert not b.is_contiguous() - x, _ = torch.gesv(b, A) + x, _ = torch.solve(b, A) x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) self.assertEqual(x.data, cast(x_exp)) # broadcasting b A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)) b = cast(torch.randn(4, 6)) - x, _ = torch.gesv(b, A) + x, _ = torch.solve(b, A) x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) self.assertEqual(x.data, cast(x_exp)) # broadcasting A A = cast(random_fullrank_matrix_distinct_singular_value(4)) b = cast(torch.randn(2, 1, 3, 4, 2)) - x, _ = torch.gesv(b, A) + x, _ = torch.solve(b, A) x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) self.assertEqual(x.data, cast(x_exp)) # broadcasting both A & b A = cast(random_fullrank_matrix_distinct_singular_value(4, 1, 3, 1)) b = cast(torch.randn(2, 1, 3, 4, 5)) - x, _ = torch.gesv(b, A) + x, _ = torch.solve(b, A) x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) self.assertEqual(x.data, cast(x_exp)) @skipIfNoLapack - def test_gesv_batched_dims(self): - self._test_gesv_batched_dims(self, lambda t: t) + def test_solve_batched_dims(self): + self._test_solve_batched_dims(self, lambda t: t) @skipIfNoLapack def test_qr(self): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 8d8d6d8..8429c72 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -364,10 +364,6 @@ self: grad.mv(vec2) vec2: grad.t().mv(self) -- name: gesv(Tensor self, Tensor A) - self: gesv_backward_self(grad, self, A) - A: gesv_backward_A(grad, self, A, result0) - - name: indices(Tensor self) output_differentiability: [False] @@ -742,6 +738,10 @@ self: slogdet_backward(grad, self, result0, result1) output_differentiability: [false, true] +- name: solve(Tensor self, Tensor A) + self: solve_backward_self(grad, self, A) + A: solve_backward_A(grad, self, A, result0) + - name: sort(Tensor self, int64_t dim, bool descending) self: index_select_backward(grad, dim, result1, self.sizes(), true) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 0066183..d95e6a05 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -26,8 +26,8 @@ SKIP_PYTHON_BINDINGS = [ '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin', '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', 
'_th_.*', '_thnn_.*', - 'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', - '_potrs.*', '_cholesky.*', '_btrifact.*', + 'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*', + '_cholesky.*', '_btrifact.*', 'slice', 'randint(_out)?', 'item', '_local_scalar_dense', 'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to', diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 3db5552..8ed6eeb 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -402,12 +402,12 @@ Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim, Sc return cumprod_backward(grad.to(input.scalar_type()), input, dim); } -Tensor gesv_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) { - return std::get<0>(at::gesv(grad, A.transpose(-2, -1))); +Tensor solve_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) { + return std::get<0>(at::solve(grad, A.transpose(-2, -1))); } -Tensor gesv_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) { - Tensor grad_self = gesv_backward_self(grad, self, A); +Tensor solve_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) { + Tensor grad_self = solve_backward_self(grad, self, A); if (self.ndimension() == 2 && A.ndimension() == 2) { return -at::mm(grad_self, solution.transpose(-2, -1)); } @@ -701,8 +701,8 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { auto P = phi(at::matmul(L, Lbar)); Tensor S; - std::tie(S, std::ignore) = at::gesv(P + P.transpose(-1, -2), L); - std::tie(S, std::ignore) = at::gesv(S.transpose(-1, -2), L); + std::tie(S, std::ignore) = at::solve(P + P.transpose(-1, -2), L); + std::tie(S, std::ignore) = at::solve(S.transpose(-1, -2), L); S = phi(S); if (upper) { S = S.transpose(-1, -2); diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index c77c8f6..f769a52 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1019,11 +1019,11 @@ ger(vec2) -> Tensor See :func:`torch.ger` """) -add_docstr_all('gesv', +add_docstr_all('solve', r""" -gesv(A) -> Tensor, Tensor +solve(A) -> Tensor, Tensor -See :func:`torch.gesv` +See :func:`torch.solve` """) add_docstr_all('indices', diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 30311d1..b6d8f61 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2018,9 +2018,9 @@ Example:: [ 4., 8., 12.]]) """) -add_docstr(torch.gesv, +add_docstr(torch.solve, r""" -torch.gesv(B, A, out=None) -> (Tensor, Tensor) +torch.solve(B, A, out=None) -> (Tensor, Tensor) This function returns the solution to the system of linear equations represented by :math:`AX = B` and the LU factorization of @@ -2028,17 +2028,12 @@ A, in order as a tuple `X, LU`. `LU` contains `L` and `U` factors for LU factorization of `A`. -`torch.gesv(B, A)` can take in 2D inputs `B, A` or inputs that are +`torch.solve(B, A)` can take in 2D inputs `B, A` or inputs that are batches of 2D matrices. If the inputs are batches, then returns batched outputs `X, LU`. .. note:: - The :attr:`out` keyword only supports 2D matrix inputs, that is, - `B, A` must be 2D matrices. - -.. note:: - Irrespective of the original strides, the returned matrices `X` and `LU` will be transposed, i.e. 
    with strides like `B.contiguous().transpose(-1, -2).strides()` and
 
@@ -2061,7 +2056,7 @@ Example::
     >>> B = torch.tensor([[4.02, 6.19, -8.22, -7.57, -3.03],
                           [-1.56, 4.00, -8.67, 1.75, 2.86],
                           [9.81, -4.09, -4.57, -8.61, 8.99]]).t()
-    >>> X, LU = torch.gesv(B, A)
+    >>> X, LU = torch.solve(B, A)
     >>> torch.dist(B, torch.mm(A, X))
     tensor(1.00000e-06 * 7.0977)
 
@@ -2069,7 +2064,7 @@ Example::
     >>> # Batched solver example
     >>> A = torch.randn(2, 3, 1, 4, 4)
     >>> B = torch.randn(2, 3, 1, 4, 6)
-    >>> X, LU = torch.gesv(B, A)
+    >>> X, LU = torch.solve(B, A)
    >>> torch.dist(B, A.matmul(X))
     tensor(1.00000e-06 * 3.6386)
 
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 893490e..fb4931f 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -236,7 +236,7 @@ class ShapePropagator {
     }
 
     OperatorSet cannot_propagate_shape_by_running_it = {
-        "aten::gesv(Tensor self, Tensor A) -> (Tensor, Tensor)",
+        "aten::solve(Tensor self, Tensor A) -> (Tensor, Tensor)",
         "aten::inverse(Tensor self) -> Tensor",
     };
 
diff --git a/torch/functional.py b/torch/functional.py
index 4540599..796d639 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -23,8 +23,9 @@ __all__ = [
     'norm',
     'meshgrid',
     'potrf',
-    'potrs',
     'pstrf',
+    'potrs',
+    'gesv',
     'split',
     'stft',
     'tensordot',
@@ -739,7 +740,7 @@ def potrf(a, upper=True, out=None):
     r"""Computes the Cholesky decomposition of a symmetric positive-definite
     matrix :math:`A`.
 
-    For more information, regarding :func:`torch.potrf`, please check :func:`torch.cholesky`.
+    For more information regarding :func:`torch.potrf`, please check :func:`torch.cholesky`.
 
     .. warning::
         :func:`torch.potrf` is deprecated in favour of :func:`torch.cholesky` and will be removed
@@ -801,7 +802,7 @@ def potrs(b, u, upper=True, out=None):
     r"""Solves a linear system of equations with a positive semidefinite
     matrix to be inverted given its Cholesky factor matrix :attr:`u`.
 
-    For more information, regarding :func:`torch.potrs`, please check :func:`torch.cholesky_solve`.
+    For more information regarding :func:`torch.potrs`, please check :func:`torch.cholesky_solve`.
 
     .. warning::
         :func:`torch.potrs` is deprecated in favour of :func:`torch.cholesky_solve` and will be
@@ -812,3 +813,18 @@ def potrs(b, u, upper=True, out=None):
                   "in the next release. Please use torch.cholesky instead and note that the "
                   ":attr:`upper` argument in torch.cholesky_solve defaults to ``False``.", stacklevel=2)
     return torch.cholesky_solve(b, u, upper=upper, out=out)
+
+
+def gesv(b, A, out=None):
+    r"""This function returns the solution to the system of linear equations represented
+    by :math:`AX = B` and the LU factorization of A, in order as a tuple `X, LU`.
+
+    For more information regarding :func:`torch.gesv`, please check :func:`torch.solve`.
+
+    .. warning::
+        :func:`torch.gesv` is deprecated in favour of :func:`torch.solve` and will be removed in the
+        next release. Please use :func:`torch.solve` instead.
+    """
+    warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
+                  "next release. Please use torch.solve instead.", stacklevel=2)
+    return torch.solve(b, A, out=out)
diff --git a/torch/tensor.py b/torch/tensor.py
index 2e69483..d78f2cf 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -281,6 +281,12 @@ class Tensor(torch._C._TensorBase):
                       "to ``False``.", stacklevel=2)
         return super(Tensor, self).cholesky_solve(u, upper=upper)
 
+    def gesv(self, A):
+        r"""See :func:`torch.solve`"""
+        warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
+                      "next release. Please use torch.solve instead.", stacklevel=2)
+        return super(Tensor, self).solve(A)
+
     def stft(self, n_fft, hop_length=None, win_length=None, window=None,
              center=True, pad_mode='reflect', normalized=False, onesided=True):
         r"""See :func:`torch.stft`
-- 
2.7.4
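
A minimal usage sketch of the renamed API after this patch (not part of the diff itself; the tensor values are arbitrary, and the residual check assumes the random matrix is well conditioned):

```python
import warnings
import torch

A = torch.randn(3, 3)          # coefficient matrix
B = torch.randn(3, 2)          # right-hand sides

# New name: returns the solution X of A X = B together with the LU factorization of A.
X, LU = torch.solve(B, A)
print(torch.dist(B, A.mm(X)))  # residual should be close to zero

# The tensor method introduced in Tensor.h works the same way.
X2, LU2 = B.solve(A)

# The old name is kept as a deprecated alias and now emits a warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    X3, LU3 = torch.gesv(B, A)
print(caught[-1].message)      # "torch.gesv is deprecated in favour of torch.solve ..."
```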