From: Vishwak Srinivasan
Date: Fri, 29 Mar 2019 07:27:48 +0000 (-0700)
Subject: Rename `btrifact*` to `lu` (#18435)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~563
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d859031ebf5a4e45ad950c5cd53b1dc1df5c4136;p=platform%2Fupstream%2Fpytorch.git

Rename `btrifact*` to `lu` (#18435)

Summary:
Changelog:
- Renames `btrifact` and `btrifact_with_info` to `lu` to remain consistent with other factorization methods (`qr` and `svd`).
- There is now a single function and method named `lu`, which performs LU decomposition. It takes a `get_infos` kwarg which, when set to True, includes an infos tensor in the returned tuple.
- Rename all tests and fix callsites.
- Create tentative aliases for `lu` under the names `btrifact` and `btrifact_with_info`, and add a deprecation warning to discourage their use.
- Add a single-matrix version of `lu` so that users don't have to unsqueeze and squeeze a single square matrix (see the changes to determinant computation in `LinearAlgebra.cpp`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18435

Differential Revision: D14680352

Pulled By: soumith

fbshipit-source-id: af58dfc11fa53d9e8e0318c720beaf5502978cd8
---
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index 24b64b9..ea3f2b5 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -700,8 +700,6 @@ class CAFFE2_API Tensor { std::tuple geqrf() const; Tensor orgqr(const Tensor & input2) const; Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const; - std::tuple btrifact(bool pivot=true) const; - std::tuple btrifact_with_info(bool pivot=true) const; Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const; Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const; Tensor lgamma() const; diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index e153f06..2a05ce7 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -1171,12 +1171,6 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const { inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const { return type().ormqr(*this, input2, input3, left, transpose); } -inline std::tuple Tensor::btrifact(bool pivot) const { - return type().btrifact(*this, pivot); -} -inline std::tuple Tensor::btrifact_with_info(bool pivot) const { - return type().btrifact_with_info(*this, pivot); -} inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const { return type().btrisolve(*this, LU_data, LU_pivots); } diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index 5b7a5d7..dcdd533 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -578,8 +578,6 @@ struct CAFFE2_API Type { virtual std::tuple geqrf(const Tensor & self) const = 0; virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0; virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0; - virtual std::tuple btrifact(const Tensor & self, bool pivot) const = 0; - virtual std::tuple btrifact_with_info(const Tensor & self, bool pivot) const = 0; virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0; virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement,
Generator * generator) const = 0; virtual Tensor lgamma(const Tensor & self) const = 0; diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index b71ffe8..ca3fe63 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -88,6 +88,7 @@ _(aten, _log10) \ _(aten, _log1p) \ _(aten, _log2) \ _(aten, _logspace) \ +_(aten, _lu_with_info) \ _(aten, _masked_scale) \ _(aten, _mm) \ _(aten, _mv) \ @@ -224,8 +225,6 @@ _(aten, bincount) \ _(aten, blackman_window) \ _(aten, bmm) \ _(aten, broadcast_tensors) \ -_(aten, btrifact) \ -_(aten, btrifact_with_info) \ _(aten, btrisolve) \ _(aten, cartesian_prod) \ _(aten, cat) \ diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 46f6d47..3507279 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -51,8 +51,8 @@ void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, } template -void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) { - AT_ERROR("getrf only takes float or double Tensors"); +void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) { + AT_ERROR("lu only takes float or double Tensors"); } template @@ -92,11 +92,11 @@ template<> void lapackGetri(int n, float *a, int lda, int *ipiv, float *w sgetri_(&n, a, &lda, ipiv, work, &lwork, info); } -template<> void lapackGetrf(int m, int n, double *a, int lda, int *ipiv, int *info) { +template<> void lapackLu(int m, int n, double *a, int lda, int *ipiv, int *info) { dgetrf_(&m, &n, a, &lda, ipiv, info); } -template<> void lapackGetrf(int m, int n, float *a, int lda, int *ipiv, int *info) { +template<> void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { sgetrf_(&m, &n, a, &lda, ipiv, info); } @@ -219,7 +219,7 @@ static void apply_inverse(Tensor& self, std::vector& infos) { for (int64_t i = 0; i < batch_size; i++) { int info; scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - lapackGetrf(n, n, self_working_ptr, n, ipiv.data(), &info); + lapackLu(n, n, self_working_ptr, n, ipiv.data(), &info); infos[i] = info; if (info != 0) { return; @@ -406,41 +406,44 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) { return result; } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) { +static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos) { #ifndef USE_LAPACK - AT_ERROR("btrifact: LAPACK library not found in compilation"); + AT_ERROR("lu: LAPACK library not found in compilation"); #else auto self_data = self.data(); - auto self_matrix_stride = matrixStride(self); - auto batch_size = batchCount(self); - auto pivots_data = pivots.data(); - auto pivots_matrix_stride = pivots.size(-1); auto infos_data = infos.data(); auto n = self.size(-1); - for (int64_t i = 0; i < batch_size; i++) { - scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride]; - int* infos_working_ptr = &infos_data[i]; - lapackGetrf(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr); + if (self.dim() == 2) { + lapackLu(n, n, self_data, n, pivots_data, infos_data); + } else { + auto self_matrix_stride = matrixStride(self); + auto 
batch_size = batchCount(self); + auto pivots_matrix_stride = pivots.size(-1); + for (int64_t i = 0; i < batch_size; i++) { + scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; + int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride]; + int* infos_working_ptr = &infos_data[i]; + lapackLu(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr); + } } #endif } -std::tuple _btrifact_helper_cpu(const Tensor& self, bool pivot) { - AT_CHECK(pivot, "btrifact without pivoting is not implemented on the CPU"); - AT_CHECK(self.dim() > 2, - "expected tensor with more than 2 dimensions, got size: ", self.sizes(), +std::tuple _lu_with_info_cpu(const Tensor& self, bool pivot, bool check_errors) { + AT_CHECK(pivot, "lu without pivoting is not implemented on the CPU"); + AT_CHECK(self.dim() >= 2, + "expected tensor with 2 or more dimensions, got size: ", self.sizes(), " instead"); squareCheckInputs(self); auto req_size = self.sizes().vec(); req_size.pop_back(); - auto pivots_tensor = at::zeros(req_size, self.options().dtype(kInt)); + auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt)); req_size.pop_back(); auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt)); @@ -449,55 +452,20 @@ std::tuple _btrifact_helper_cpu(const Tensor& self, bool self_working_copy = at::empty_like(self); } else { self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cpu", [&]{ - apply_btrifact(self_working_copy, pivots_tensor, infos_tensor); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cpu", [&]{ + apply_lu(self_working_copy, pivots_tensor, infos_tensor); }); } + if (check_errors) { + if (self.dim() == 2) { + singleCheckErrors(infos_tensor.item(), "lu"); + } else { + batchCheckErrors(infos_tensor, "lu"); + } + } return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor); } -std::tuple btrifact(const Tensor& self, bool pivot) { - Tensor LU_fact, pivots, infos; - std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot); - batchCheckErrors(infos, "btrifact"); - return std::make_tuple(LU_fact, pivots); -} - -std::tuple btrifact_out( - Tensor& A_LU, - Tensor& pivots, - const Tensor& self, - bool pivot) { - Tensor infos, A_LU_tmp, pivots_tmp; - std::tie(A_LU_tmp, pivots_tmp, infos) = at::_btrifact_helper(self, pivot); - batchCheckErrors(infos, "btrifact"); - A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp); - pivots.resize_as_(pivots_tmp).copy_(pivots_tmp); - return std::tuple(A_LU, pivots); -} - -std::tuple btrifact_with_info( - const Tensor& self, - bool pivot) { - Tensor LU_fact, pivots, infos; - std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot); - return std::make_tuple(LU_fact, pivots, infos); -} - -std::tuple btrifact_with_info_out( - Tensor& A_LU, - Tensor& pivots, - Tensor& info, - const Tensor& self, - bool pivot) { - Tensor info_tmp, A_LU_tmp, pivots_tmp; - std::tie(A_LU_tmp, pivots_tmp, info_tmp) = at::_btrifact_helper(self, pivot); - A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp); - pivots.resize_as_(pivots_tmp).copy_(pivots_tmp); - info.resize_as_(info_tmp).copy_(info_tmp); - return std::tuple(A_LU, pivots, info); -} - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 5042acc..5d3c157 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -21,10 +21,8 @@ namespace native { // where info 
helps us identify singular matrices. static inline std::tuple _lu_det_P_diag_U_info(const Tensor& self) { Tensor p, lu, info; - std::tie(lu, p, info) = self.unsqueeze(0).btrifact_with_info(); - p.squeeze_(0); - lu.squeeze_(0); - int int_info = info.squeeze_().item(); + std::tie(lu, p, info) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false); + int int_info = info.item(); AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info); auto n = self.size(0); auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0); diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index fc23d3c..3e470f5 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -109,7 +109,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) { " but each b matrix is ", self.size(-2), " by ", self.size(-1)); } -// Validates input shapes for operations on batches of square matrices (inverse, cholesky) +// Validates input shapes for operations on batches of square matrices (inverse, cholesky, lu) static inline void squareCheckInputs(const Tensor& self) { AT_CHECK(self.size(-1) == self.size(-2), "A must be batches of square matrices, " diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index f401995..1742981 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -35,18 +35,32 @@ void magmaSolveBatched( } template -void magmaGetrfBatched( +void magmaLu( + magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda, + magma_int_t* ipiv, magma_int_t* info) { + AT_ERROR("lu only takes float or double Tensors"); +} + +template +void magmaLuBatched( magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda, magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { - AT_ERROR("getrf only takes float or double Tensors"); + AT_ERROR("lu only takes float or double Tensors"); +} + +template +void magmaLuNoPiv( + magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda, + magma_int_t* info) { + AT_ERROR("lu only takes float or double Tensors"); } template -void magmaGetrfNoPivBatched( +void magmaLuNoPivBatched( magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda, magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { - AT_ERROR("getrf only takes float or double Tensors"); + AT_ERROR("lu only takes float or double Tensors"); } template @@ -131,7 +145,21 @@ void magmaSolveBatched( } template<> -void magmaGetrfBatched( +void magmaLu( + magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, + magma_int_t* ipiv, magma_int_t* info) { + magma_dgetrf_gpu(m, n, dA, ldda, ipiv, info); +} + +template<> +void magmaLu( + magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda, + magma_int_t* ipiv, magma_int_t* info) { + magma_sgetrf_gpu(m, n, dA, ldda, ipiv, info); +} + +template<> +void magmaLuBatched( magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda, magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { @@ -139,7 +167,7 @@ void magmaGetrfBatched( } template<> -void magmaGetrfBatched( +void magmaLuBatched( magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda, magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& 
magma_queue) { @@ -147,14 +175,28 @@ void magmaGetrfBatched( } template<> -void magmaGetrfNoPivBatched( +void magmaLuNoPiv( + magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda, + magma_int_t* info) { + magma_dgetrf_nopiv_gpu(m, n, dA, ldda, info); +} + +template<> +void magmaLuNoPiv( + magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda, + magma_int_t* info) { + magma_sgetrf_nopiv_gpu(m, n, dA, ldda, info); +} + +template<> +void magmaLuNoPivBatched( magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda, magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { magma_dgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue()); } template<> -void magmaGetrfNoPivBatched( +void magmaLuNoPivBatched( magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda, magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { magma_sgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue()); @@ -373,7 +415,7 @@ AT_ERROR("inverse: MAGMA library not found in " } MAGMAQueue magma_queue(self.get_device()); - magmaGetrfBatched( + magmaLuBatched( n, n, self_array, n, ipiv_array, info_array, batch_size, magma_queue); @@ -527,75 +569,96 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) { } } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) { +static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) { #ifndef USE_MAGMA -AT_ERROR("btrifact: MAGMA library not found in " +AT_ERROR("lu: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else auto self_data = self.data(); - auto self_matrix_stride = matrixStride(self); - magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount"); magma_int_t n = magma_int_cast(self.size(-1), "n"); - scalar_t** self_array; - ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self); - - // Set up the created arrays - for (int64_t i = 0; i < batch_size; i++) { - self_array[i] = &self_data[i * self_matrix_stride]; - } + if (self.dim() == 2) { + // If `pivots` is defined, then we have to compute them. + // We will use the normal getrf function to compute the LU factorization + // and the pivots + // We create temporary tensors on the CPU, because tensors on the GPU + // cause segfault when passed to magmaLu and magmaLuNoPiv. The data is later + // copied to the appropriate tensors. + Tensor info_tmp = at::zeros({}, at::kInt); + if (get_pivots) { + Tensor piv_tmp = at::empty({n}, at::kInt); + magmaLu( + n, n, self_data, n, piv_tmp.data(), info_tmp.data()); + pivots.copy_(piv_tmp); + } else { + magmaLuNoPiv(n, n, self_data, n, info_tmp.data()); + } + infos.copy_(info_tmp); + } else { + auto self_matrix_stride = matrixStride(self); + magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount"); - MAGMAQueue magma_queue(self.get_device()); + scalar_t** self_array; + ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self); - // If `pivots` is defined, then we have to compute them. 
- // We will use the normal getrf function to compute the LU factorization - // and the pivots - if (pivots.defined()) { - auto pivots_data = pivots.data(); - auto pivots_matrix_stride = pivots.size(-1); - magma_int_t** pivots_array; - ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots); + // Set up the created arrays for (int64_t i = 0; i < batch_size; i++) { - pivots_array[i] = &pivots_data[i * pivots_matrix_stride]; + self_array[i] = &self_data[i * self_matrix_stride]; } - magmaGetrfBatched( - n, n, self_array, n, pivots_array, - infos.data(), batch_size, magma_queue); - } else { - magmaGetrfNoPivBatched( - n, n, self_array, n, infos.data(), - batch_size, magma_queue); + MAGMAQueue magma_queue(self.get_device()); + + // Same comment as in the case of single matrix above. + if (get_pivots) { + auto pivots_data = pivots.data(); + auto pivots_matrix_stride = pivots.size(-1); + magma_int_t** pivots_array; + ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots); + for (int64_t i = 0; i < batch_size; i++) { + pivots_array[i] = &pivots_data[i * pivots_matrix_stride]; + } + magmaLuBatched( + n, n, self_array, n, pivots_array, + infos.data(), batch_size, magma_queue); + } else { + magmaLuNoPivBatched( + n, n, self_array, n, infos.data(), + batch_size, magma_queue); + } } #endif } -std::tuple _btrifact_helper_cuda(const Tensor& self, bool pivot) { - AT_CHECK(self.dim() > 2, - "expected tensor with more than 2 dimensions, got size: ", self.sizes(), +std::tuple _lu_with_info_cuda(const Tensor& self, bool pivot, bool check_errors) { + AT_CHECK(self.dim() >= 2, + "expected tensor with 2 or more dimensions, got size: ", self.sizes(), " instead"); squareCheckInputs(self); auto req_size = self.sizes().vec(); req_size.pop_back(); - Tensor pivots_tensor; - if (pivot) { - pivots_tensor = at::zeros(req_size, self.options().dtype(kInt)); - } + Tensor pivots_tensor = at::zeros(req_size, self.options().dtype(at::kInt)); req_size.pop_back(); - auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt)); + auto infos_tensor = at::zeros(req_size, self.options().dtype(at::kInt)); Tensor self_working_copy; if (self.numel() == 0) { self_working_copy = at::empty_like(self); } else { self_working_copy = cloneBatchedColumnMajor(self); - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cuda", [&]{ - apply_btrifact(self_working_copy, pivots_tensor, infos_tensor); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cuda", [&]{ + apply_lu(self_working_copy, pivots_tensor, infos_tensor, pivot); }); } + if (check_errors) { + if (self.dim() == 2) { + singleCheckErrors(infos_tensor.item(), "lu"); + } else { + batchCheckErrors(infos_tensor, "lu"); + } + } return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 152a2ae..e4200ab 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3818,26 +3818,12 @@ matches_jit_signature: True variants: method, function -- func: btrifact(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots) -> (Tensor(a!), Tensor(b!)) - matches_jit_signature: True - -- func: btrifact(Tensor self, *, bool pivot=True) -> (Tensor, Tensor) - matches_jit_signature: True - variants: method, function - -- func: btrifact_with_info(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots, Tensor(c!) 
info) -> (Tensor(a!), Tensor(b!), Tensor(c!)) - matches_jit_signature: True - -- func: btrifact_with_info(Tensor self, *, bool pivot=True) -> (Tensor, Tensor, Tensor) - matches_jit_signature: True - variants: method, function - -- func: _btrifact_helper(Tensor self, bool pivot) -> (Tensor, Tensor, Tensor) +- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor) matches_jit_signature: True variants: function dispatch: - CPU: _btrifact_helper_cpu - CUDA: _btrifact_helper_cuda + CPU: _lu_with_info_cpu + CUDA: _lu_with_info_cuda - func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 2bba064..b9ad547 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -306,6 +306,7 @@ view of a storage and defines numeric operations on it. .. automethod:: long .. automethod:: lt .. automethod:: lt_ + .. automethod:: lu .. automethod:: map_ .. automethod:: masked_scatter_ .. automethod:: masked_scatter diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 1013128..30a09dc 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -315,6 +315,7 @@ BLAS and LAPACK Operations .. autofunction:: det .. autofunction:: logdet .. autofunction:: slogdet +.. autofunction:: lu .. autofunction:: matmul .. autofunction:: matrix_power .. autofunction:: matrix_rank diff --git a/test/test_cuda.py b/test/test_cuda.py index 72a3157..e919cf9 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2361,8 +2361,8 @@ class TestCuda(TestCase): @skipIfRocm @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_btrifact(self): - _TestTorchMixin._test_btrifact(self, lambda t: t.cuda()) + def test_lu(self): + _TestTorchMixin._test_lu(self, lambda t: t.cuda()) @skipIfRocm @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") diff --git a/test/test_torch.py b/test/test_torch.py index 93da20c..bd0ac1a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1746,44 +1746,43 @@ class _TestTorchMixin(object): _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64)) @staticmethod - def _test_btrifact(self, cast): + def _test_lu(self, cast): from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank def run_test(matrix_size, batches, cast): a = cast(fullrank(matrix_size, *batches)) - a_LU_info, pivots_info, info_ = a.btrifact_with_info() + a_LU_info, pivots_info, info_ = a.lu(get_infos=True) self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size))) self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,))) self.assertEqual(info_.size(), torch.Size(batches)) self.assertEqual(info_.abs().sum(), 0) - a_LU, pivots = a.btrifact() + a_LU, pivots = a.lu() self.assertEqual(a_LU, a_LU_info) self.assertEqual(pivots_info, pivots) if a.is_cuda: - a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False) - self.assertIsNone(nopiv) + a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True) + self.assertEqual(nopiv, cast(torch.zeros(a.shape[:-1], dtype=torch.int32))) self.assertEqual(info_, info_nopiv) P, L, U = torch.btriunpack(a_LU, pivots) self.assertEqual(P.matmul(L.matmul(U)), a) - for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]): + for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]): run_test(ms, batch, cast) # Info should be positive for rank deficient 
matrices - a = cast(fullrank(3, 5)) + a = cast(torch.ones(5, 3, 3)) if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])): - a[0, 1] = 2 * a[0, 0] # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency - self.assertGreater(a.btrifact_with_info()[2][0], 0) + self.assertGreater(a.lu(get_infos=True)[2][0], 0) # Error checking, no pivoting variant on CPU with self.assertRaisesRegex(RuntimeError, - 'btrifact without pivoting is not implemented on the CPU'): - torch.btrifact(torch.empty(1, 2, 2), pivot=False) + 'lu without pivoting is not implemented on the CPU'): + torch.lu(torch.empty(1, 2, 2), pivot=False) @skipIfNoLapack @skipIfRocm - def test_btrifact(self): - self._test_btrifact(self, lambda t: t) + def test_lu(self): + self._test_lu(self, lambda t: t) @staticmethod def _test_btrisolve(self, cast): @@ -1797,7 +1796,7 @@ class _TestTorchMixin(object): (-1.56, 4.00), (9.81, -4.09))) a, b = cast(a), cast(b) - LU_data, pivots, info = a.btrifact_with_info() + LU_data, pivots, info = a.lu(get_infos=True) self.assertEqual(info.abs().sum(), 0) x = torch.btrisolve(b, LU_data, pivots) b_ = torch.bmm(a, x.unsqueeze(2)).squeeze() @@ -1811,12 +1810,11 @@ class _TestTorchMixin(object): def _test_btriunpack(self, cast): def run_test(shape, cast): a = cast(torch.randn(*shape)) - a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1])) - a_lu = a_lu.reshape_as(a) - p = p.reshape(a.shape[:-1]) + a_lu, p = torch.lu(a) p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p) self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a) + run_test((3, 3), cast) run_test((5, 3, 3), cast) run_test((7, 3, 5, 5), cast) run_test((7, 5, 3, 3, 3), cast) @@ -4743,11 +4741,14 @@ class _TestTorchMixin(object): A = torch.randn(3, 3, device=A_device) err_str = "Expected b and A to be on the same device" with self.assertRaisesRegex(RuntimeError, err_str): - torch.gesv(b, A) + torch.solve(b, A) with self.assertRaisesRegex(RuntimeError, err_str): torch.cholesky_solve(b, A) + with self.assertRaisesRegex(RuntimeError, err_str): + torch.triangular_solve(b, A) + @skipIfNoLapack def test_qr(self): @@ -7965,12 +7966,12 @@ class _TestTorchMixin(object): self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) if torch._C.has_lapack: - # btrifact - A_LU, pivots = fn(torch.btrifact, (0, 5, 5)) + # lu + A_LU, pivots = fn(torch.lu, (0, 5, 5)) self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape]) - A_LU, pivots = fn(torch.btrifact, (0, 0, 0)) + A_LU, pivots = fn(torch.lu, (0, 0, 0)) self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape]) - A_LU, pivots = fn(torch.btrifact, (2, 0, 0)) + A_LU, pivots = fn(torch.lu, (2, 0, 0)) self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape]) @skipIfRocm diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 4a59185..425a5b5 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -192,12 +192,6 @@ self: grad.bmm(mat2.transpose(1, 2)) mat2: self.transpose(1, 2).bmm(grad) -- name: btrifact(Tensor self, bool pivot) - self: not_implemented("btrifact") - -- name: btrifact_with_info(Tensor self, bool pivot) - self: not_implemented("btrifact_with_info") - - name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots) self: not_implemented("btrisolve") @@ -470,6 +464,9 @@ self: zeros_like(self) other: zeros_like(other) +- name: _lu_with_info(Tensor self, bool pivot, bool check_errors) + self: not_implemented("lu_with_info") + - name: masked_fill_(Tensor self, 
Tensor mask, Scalar value) self: grad.clone().masked_fill_(mask, 0) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 80f7495..7201f77 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [ '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_.*', '_thnn_.*', 'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*', - '_cholesky.*', '_btrifact.*', '_triangular_solve.*', + '_cholesky.*', '_triangular_solve.*', 'slice', 'randint(_out)?', 'item', '_local_scalar_dense', 'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to', diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in index 0d842f5..141fbe0 100644 --- a/torch/__init__.pyi.in +++ b/torch/__init__.pyi.in @@ -78,6 +78,7 @@ class Tensor: center=True, pad_mode='reflect', normalized=False, onesided=True): ... def split(self, split_size, dim=0): ... def unique(self, sorted=True, return_inverse=False, dim=None): ... + def lu(self, pivot=True, get_infos=False): ... ${function_hints} diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 289c3f3..bc90d94 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -486,20 +486,6 @@ bmm(batch2) -> Tensor See :func:`torch.bmm` """) -add_docstr_all('btrifact', - r""" -btrifact(pivot=True) -> (Tensor, Tensor) - -See :func:`torch.btrifact` -""") - -add_docstr_all('btrifact_with_info', - r""" -btrifact_with_info(pivot=True) -> (Tensor, Tensor, Tensor) - -See :func:`torch.btrifact_with_info` -""") - add_docstr_all('btrisolve', r""" btrisolve(LU_data, LU_pivots) -> Tensor @@ -1019,13 +1005,6 @@ ger(vec2) -> Tensor See :func:`torch.ger` """) -add_docstr_all('solve', - r""" -solve(A) -> Tensor, Tensor - -See :func:`torch.solve` -""") - add_docstr_all('indices', r""" indices() -> Tensor @@ -2228,6 +2207,13 @@ Example:: """) +add_docstr_all('solve', + r""" +solve(A) -> Tensor, Tensor + +See :func:`torch.solve` +""") + add_docstr_all('sort', r""" sort(dim=-1, descending=False) -> (Tensor, LongTensor) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 3249126..1a9262b 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5516,71 +5516,6 @@ Example:: [ 0., 0., 0.]]) """.format(**factory_like_common_args)) -add_docstr(torch.btrifact, - r""" -btrifact(A, pivot=True) -> (Tensor, IntTensor) - -Batch LU factorization. - -Returns a tuple containing the LU factorization and pivots. Pivoting is done if -:attr:`pivot` is set. - -.. note:: - LU factorization with :attr:`pivot` = ``True`` is not available for CPU, and attempting - to do so will throw an error. However, LU factorization with :attr:`pivot` = ``True`` is - available for CUDA. - -Arguments: - A (Tensor): the tensor to factor - pivot (bool, optional): controls whether pivoting is done - -Returns: - A tuple containing factorization and pivots. - -Example:: - - >>> A = torch.randn(2, 3, 3) - >>> A_LU, pivots = torch.btrifact(A) - >>> A_LU - tensor([[[ 1.3506, 2.5558, -0.0816], - [ 0.1684, 1.1551, 0.1940], - [ 0.1193, 0.6189, -0.5497]], - - [[ 0.4526, 1.2526, -0.3285], - [-0.7988, 0.7175, -0.9701], - [ 0.2634, -0.9255, -0.3459]]]) - - >>> pivots - tensor([[ 3, 3, 3], - [ 3, 3, 3]], dtype=torch.int32) -""") - -add_docstr(torch.btrifact_with_info, - r""" -btrifact_with_info(A, pivot=True) -> (Tensor, IntTensor, IntTensor) - -Batch LU factorization with additional error information. 
- -This is a version of :meth:`torch.btrifact` that always creates an info -`IntTensor`, and returns it as the third return value. - -Arguments: - A (Tensor): the tensor to factor - pivot (bool, optional): controls whether pivoting is done - -Returns: - A tuple containing factorization, pivots, and an `IntTensor` where non-zero - values indicate whether factorization for each minibatch sample succeeds. - -Example:: - - >>> A = torch.randn(2, 3, 3) - >>> A_LU, pivots, info = A.btrifact_with_info() - >>> if info.nonzero().size(0) == 0: - >>> print('LU factorization succeeded for all samples!') - LU factorization succeeded for all samples! -""") - add_docstr(torch.btrisolve, r""" btrisolve(b, LU_data, LU_pivots) -> Tensor @@ -5591,14 +5526,14 @@ Returns the LU solve of the linear system :math:`Ax = b`. Arguments: b (Tensor): the RHS tensor - LU_data (Tensor): the pivoted LU factorization of A from :meth:`btrifact`. + LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`. LU_pivots (IntTensor): the pivots of the LU factorization Example:: >>> A = torch.randn(2, 3, 3) >>> b = torch.randn(2, 3) - >>> A_LU = torch.btrifact(A) + >>> A_LU = torch.lu(A) >>> x = torch.btrisolve(b, *A_LU) >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2)) tensor(1.00000e-07 * diff --git a/torch/functional.py b/torch/functional.py index 580227c..2fcb46a 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -6,23 +6,26 @@ import warnings __all__ = [ 'btriunpack', + 'broadcast_tensors', + 'btrifact', + 'btrifact_with_info', + 'cartesian_prod', 'chain_matmul', 'einsum', - 'broadcast_tensors', + 'gesv', 'isfinite', 'isinf', + 'lu', 'norm', 'meshgrid', 'potrf', 'pstrf', 'potrs', - 'gesv', 'split', 'stft', 'tensordot', 'trtrs', 'unique', - 'cartesian_prod', ] @@ -81,7 +84,7 @@ def split(tensor, split_size_or_sections, dim=0): def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): - r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor. + r"""Unpacks the data and pivots from a LU factorization of a tensor. Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``. 
@@ -94,7 +97,7 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): Example:: >>> A = torch.randn(2, 3, 3) - >>> A_LU, pivots = A.btrifact() + >>> A_LU, pivots = A.lu() >>> P, A_L, A_U = torch.btriunpack(A_LU, pivots) >>> >>> # can recover A from factorization @@ -111,13 +114,20 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): L = U = None if unpack_pivots: - P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone() - LU_pivots = LU_pivots - 1 - for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])): + LU_pivots_zero_idx = LU_pivots - 1 + if LU_data.dim() > 2: + P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone() + for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])): + final_order = list(range(sz)) + for k, j in enumerate(LU_pivots_zero_idx[idx]): + final_order[k], final_order[j] = final_order[j], final_order[k] + P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)) + else: + P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype) final_order = list(range(sz)) - for k, j in enumerate(LU_pivots[idx]): + for k, j, in enumerate(LU_pivots_zero_idx): final_order[k], final_order[j] = final_order[j], final_order[k] - P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)) + P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)) else: P = None @@ -751,6 +761,8 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None): In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular with the default keyword arguments. + For more information regarding :func:`torch.trtrs`, please check :func:`torch.triangular_solve`. + .. warning:: :func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be removed in the next release. Please use :func:`torch.triangular_solve` instead. @@ -758,3 +770,112 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None): warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be " "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2) return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out) + + +def btrifact(A, pivot=True, out=None): + r"""Returns a tuple containing the LU factorization and pivots of :attr:`A`. + Pivoting is done if :attr:`pivot` is set. + + For more information regarding :func:`torch.btrifact`, please check :func:`torch.lu`. + + .. warning:: + :func:`torch.btrifact` is deprecated in favour of :func:`torch.lu` and will be + removed in the next release. Please use :func:`torch.lu` instead. + """ + warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be " + "removed in the next release. Please use torch.lu instead.", stacklevel=2) + return lu(A, pivot=pivot, get_infos=False, out=out) + + +def btrifact_with_info(A, pivot=True, out=None): + r"""Performs LU factorization and returns additional status information along with the LU + factorization and pivots. + + For more information regarding :func:`torch.btrifact_with_info`, please check :func:`torch.lu`. + + .. warning:: + :func:`torch.btrifact_with_info` is deprecated in favour of :func:`torch.lu` and will + be removed in the next release. Please use :func:`torch.lu` with the :attr:`get_infos` + argument set to ``True`` instead. 
+ """ + warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu and will be " + "removed in the next release. Please use torch.lu with the get_infos argument " + "set to True instead.", + stacklevel=2) + return lu(A, pivot=pivot, get_infos=True, out=out) + + +def lu(A, pivot=True, get_infos=False, out=None): + r"""Computes the LU factorization of a square matrix or batches of square matrices + :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`. + Pivoting is done if :attr:`pivot` is set to ``True``. + + .. note:: + The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``, + then the returned pivots is a tensor filled with zeros of the appropriate size. + + .. note:: + LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting + to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is + available for CUDA. + + .. note:: + This function does not check if the factorization was successful or not if + :attr:`get_infos` is ``True`` since the status of the factorization is present in the + third element of the return tuple. + + Arguments: + A (Tensor): the tensor to factor of size :math:`(*, m, m)` + pivot (bool, optional): controls whether pivoting is done. Default: ``True`` + get_infos (bool, optional): if set to ``True``, returns an info IntTensor. + Default: ``False`` + out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``, + then the elements in the tuple are Tensor, IntTensor, + and IntTensor. If :attr:`get_infos` is ``False``, then the + elements in the tuple are Tensor, IntTensor. Default: ``None`` + + Returns: + (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing + + - **factorization** (*Tensor*): the factorization of size :math:`(*, m, m)` + + - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)` + + - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of + size :math:`(*)` where non-zero values indicate whether factorization for the matrix or + each minibatch has succeeded or failed + + Example:: + + >>> A = torch.randn(2, 3, 3) + >>> A_LU, pivots = torch.lu(A) + >>> A_LU + tensor([[[ 1.3506, 2.5558, -0.0816], + [ 0.1684, 1.1551, 0.1940], + [ 0.1193, 0.6189, -0.5497]], + + [[ 0.4526, 1.2526, -0.3285], + [-0.7988, 0.7175, -0.9701], + [ 0.2634, -0.9255, -0.3459]]]) + >>> pivots + tensor([[ 3, 3, 3], + [ 3, 3, 3]], dtype=torch.int32) + >>> A_LU, pivots, info = torch.lu(A, get_infos=True) + >>> if info.nonzero().size(0) == 0: + ... print('LU factorization succeeded for all samples!') + LU factorization succeeded for all samples! 
+ """ + # If get_infos is True, then we don't need to check for errors and vice versa + result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos)) + if out is not None: + if not isinstance(out, (tuple, list)): + raise TypeError("argument 'out' must be tuple of Tensors, not {}" + .format(type(out).__name__)) + if len(out) - int(get_infos) != 2: + raise TypeError("expected tuple of {} elements but got {}" + .format(2 + int(get_infos), len(out))) + return (out[i].resize_as_(result[i]).copy_(result[i]) for i in range(len(out))) + if get_infos: + return result # A_LU, pivots, infos + else: + return result[0], result[1] # A_LU, pivots diff --git a/torch/tensor.py b/torch/tensor.py index bf239b3..4ac4c6b 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -282,10 +282,33 @@ class Tensor(torch._C._TensorBase): def trtrs(self, A, upper=True, transpose=False, unitriangular=False): r"""See :func:`torch.triangular_solve`""" warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be " - "removed in the next release. Please use torch.triangular_solve.", stacklevel=2) + "removed in the next release. Please use torch.triangular_solve instead.", + stacklevel=2) return super(Tensor, self).triangular_solve(A, upper=upper, transpose=transpose, unitriangular=unitriangular) + def btrifact(self, pivot=True): + r"""See :func:`torch.lu`""" + warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be removed in " + "the next release. Please use torch.lu instead.", stacklevel=2) + return torch._lu_with_info(self, pivot=pivot, check_errors=True) + + def btrifact_with_info(self, pivot=True): + r"""See :func:`torch.lu`""" + warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the " + "and will be removed in the next release. Please use torch.lu with the " + "get_infos argument set to True instead.", stacklevel=2) + return torch._lu_with_info(self, pivot=pivot, check_errors=False) + + def lu(self, pivot=True, get_infos=False): + r"""See :func:`torch.lu`""" + # If get_infos is True, then we don't need to check for errors and vice versa + LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos)) + if get_infos: + return LU, pivots, infos + else: + return LU, pivots + def stft(self, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True): r"""See :func:`torch.stft`