Rename `btrifact*` to `lu` (#18435)
authorVishwak Srinivasan <cs15btech11043@iith.ac.in>
Fri, 29 Mar 2019 07:27:48 +0000 (00:27 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 07:34:30 +0000 (00:34 -0700)
Summary:
Changelog:

- Renames `btrifact` and `btrifact_with_info` to `lu`to remain consistent with other factorization methods (`qr` and `svd`).
- Now, we will only have one function and methods named `lu`, which performs `lu` decomposition. This function takes a get_infos kwarg, which when set to True includes a infos tensor in the tuple.
- Rename all tests, fix callsites
- Create a tentative alias for `lu` under the name `btrifact` and `btrifact_with_info`, and add a deprecation warning to not promote usage.
- Add the single batch version for `lu` so that users don't have to unsqueeze and squeeze for a single square matrix (see changes in determinant computation in `LinearAlgebra.cpp`)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18435

Differential Revision: D14680352

Pulled By: soumith

fbshipit-source-id: af58dfc11fa53d9e8e0318c720beaf5502978cd8

20 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/LinearAlgebra.cpp
aten/src/ATen/native/LinearAlgebraUtils.h
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
docs/source/tensors.rst
docs/source/torch.rst
test/test_cuda.py
test/test_torch.py
tools/autograd/derivatives.yaml
tools/autograd/gen_python_functions.py
torch/__init__.pyi.in
torch/_tensor_docs.py
torch/_torch_docs.py
torch/functional.py
torch/tensor.py

index 24b64b9..ea3f2b5 100644 (file)
@@ -700,8 +700,6 @@ class CAFFE2_API Tensor {
   std::tuple<Tensor,Tensor> geqrf() const;
   Tensor orgqr(const Tensor & input2) const;
   Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
-  std::tuple<Tensor,Tensor> btrifact(bool pivot=true) const;
-  std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(bool pivot=true) const;
   Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
   Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
   Tensor lgamma() const;
index e153f06..2a05ce7 100644 (file)
@@ -1171,12 +1171,6 @@ inline Tensor Tensor::orgqr(const Tensor & input2) const {
 inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const {
     return type().ormqr(*this, input2, input3, left, transpose);
 }
-inline std::tuple<Tensor,Tensor> Tensor::btrifact(bool pivot) const {
-    return type().btrifact(*this, pivot);
-}
-inline std::tuple<Tensor,Tensor,Tensor> Tensor::btrifact_with_info(bool pivot) const {
-    return type().btrifact_with_info(*this, pivot);
-}
 inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const {
     return type().btrisolve(*this, LU_data, LU_pivots);
 }
index 5b7a5d7..dcdd533 100644 (file)
@@ -578,8 +578,6 @@ struct CAFFE2_API Type {
   virtual std::tuple<Tensor,Tensor> geqrf(const Tensor & self) const = 0;
   virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0;
   virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0;
-  virtual std::tuple<Tensor,Tensor> btrifact(const Tensor & self, bool pivot) const = 0;
-  virtual std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(const Tensor & self, bool pivot) const = 0;
   virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
   virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) const = 0;
   virtual Tensor lgamma(const Tensor & self) const = 0;
index b71ffe8..ca3fe63 100644 (file)
@@ -88,6 +88,7 @@ _(aten, _log10) \
 _(aten, _log1p) \
 _(aten, _log2) \
 _(aten, _logspace) \
+_(aten, _lu_with_info) \
 _(aten, _masked_scale) \
 _(aten, _mm) \
 _(aten, _mv) \
@@ -224,8 +225,6 @@ _(aten, bincount) \
 _(aten, blackman_window) \
 _(aten, bmm) \
 _(aten, broadcast_tensors) \
-_(aten, btrifact) \
-_(aten, btrifact_with_info) \
 _(aten, btrisolve) \
 _(aten, cartesian_prod) \
 _(aten, cat) \
index 46f6d47..3507279 100644 (file)
@@ -51,8 +51,8 @@ void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b,
 }
 
 template<class scalar_t>
-void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
-  AT_ERROR("getrf only takes float or double Tensors");
+void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
+  AT_ERROR("lu only takes float or double Tensors");
 }
 
 template<class scalar_t>
@@ -92,11 +92,11 @@ template<> void lapackGetri<float>(int n, float *a, int lda, int *ipiv, float *w
   sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
 }
 
-template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
+template<> void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
   dgetrf_(&m, &n, a, &lda, ipiv, info);
 }
 
-template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
+template<> void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
   sgetrf_(&m, &n, a, &lda, ipiv, info);
 }
 
@@ -219,7 +219,7 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
   for (int64_t i = 0; i < batch_size; i++) {
     int info;
     scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
-    lapackGetrf<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
+    lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
     infos[i] = info;
     if (info != 0) {
       return;
@@ -406,41 +406,44 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
   return result;
 }
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template<typename scalar_t>
-static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
+static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos) {
 #ifndef USE_LAPACK
-  AT_ERROR("btrifact: LAPACK library not found in compilation");
+  AT_ERROR("lu: LAPACK library not found in compilation");
 #else
   auto self_data = self.data<scalar_t>();
-  auto self_matrix_stride = matrixStride(self);
-  auto batch_size = batchCount(self);
-
   auto pivots_data = pivots.data<int>();
-  auto pivots_matrix_stride = pivots.size(-1);
   auto infos_data = infos.data<int>();
 
   auto n = self.size(-1);
 
-  for (int64_t i = 0; i < batch_size; i++) {
-    scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
-    int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
-    int* infos_working_ptr = &infos_data[i];
-    lapackGetrf<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
+  if (self.dim() == 2) {
+    lapackLu<scalar_t>(n, n, self_data, n, pivots_data, infos_data);
+  } else {
+    auto self_matrix_stride = matrixStride(self);
+    auto batch_size = batchCount(self);
+    auto pivots_matrix_stride = pivots.size(-1);
+    for (int64_t i = 0; i < batch_size; i++) {
+      scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
+      int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
+      int* infos_working_ptr = &infos_data[i];
+      lapackLu<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
+    }
   }
 #endif
 }
 
-std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool pivot) {
-  AT_CHECK(pivot, "btrifact without pivoting is not implemented on the CPU");
-  AT_CHECK(self.dim() > 2,
-           "expected tensor with more than 2 dimensions, got size: ", self.sizes(),
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cpu(const Tensor& self, bool pivot, bool check_errors) {
+  AT_CHECK(pivot, "lu without pivoting is not implemented on the CPU");
+  AT_CHECK(self.dim() >= 2,
+           "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
            " instead");
   squareCheckInputs(self);
   auto req_size = self.sizes().vec();
   req_size.pop_back();
-  auto pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
+  auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
   req_size.pop_back();
   auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
 
@@ -449,55 +452,20 @@ std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool
     self_working_copy = at::empty_like(self);
   } else {
     self_working_copy = cloneBatchedColumnMajor(self);
-    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cpu", [&]{
-      apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
+    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cpu", [&]{
+      apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
     });
   }
+  if (check_errors) {
+    if (self.dim() == 2) {
+      singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
+    } else {
+      batchCheckErrors(infos_tensor, "lu");
+    }
+  }
   return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
 }
 
-std::tuple<Tensor, Tensor> btrifact(const Tensor& self, bool pivot) {
-  Tensor LU_fact, pivots, infos;
-  std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
-  batchCheckErrors(infos, "btrifact");
-  return std::make_tuple(LU_fact, pivots);
-}
-
-std::tuple<Tensor&, Tensor&> btrifact_out(
-    Tensor& A_LU,
-    Tensor& pivots,
-    const Tensor& self,
-    bool pivot) {
-  Tensor infos, A_LU_tmp, pivots_tmp;
-  std::tie(A_LU_tmp, pivots_tmp, infos) = at::_btrifact_helper(self, pivot);
-  batchCheckErrors(infos, "btrifact");
-  A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
-  pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
-  return std::tuple<Tensor&, Tensor&>(A_LU, pivots);
-}
-
-std::tuple<Tensor, Tensor, Tensor> btrifact_with_info(
-    const Tensor& self,
-    bool pivot) {
-  Tensor LU_fact, pivots, infos;
-  std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
-  return std::make_tuple(LU_fact, pivots, infos);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
-    Tensor& A_LU,
-    Tensor& pivots,
-    Tensor& info,
-    const Tensor& self,
-    bool pivot) {
-  Tensor info_tmp, A_LU_tmp, pivots_tmp;
-  std::tie(A_LU_tmp, pivots_tmp, info_tmp) = at::_btrifact_helper(self, pivot);
-  A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
-  pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
-  info.resize_as_(info_tmp).copy_(info_tmp);
-  return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
-}
-
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t, bool upper>
index 5042acc..5d3c157 100644 (file)
@@ -21,10 +21,8 @@ namespace native {
 // where info helps us identify singular matrices.
 static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor& self) {
   Tensor p, lu, info;
-  std::tie(lu, p, info) = self.unsqueeze(0).btrifact_with_info();
-  p.squeeze_(0);
-  lu.squeeze_(0);
-  int int_info = info.squeeze_().item<int32_t>();
+  std::tie(lu, p, info) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
+  int int_info = info.item<int32_t>();
   AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
   auto n = self.size(0);
   auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0);
index fc23d3c..3e470f5 100644 (file)
@@ -109,7 +109,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
            " but each b matrix is ", self.size(-2), " by ", self.size(-1));
 }
 
-// Validates input shapes for operations on batches of square matrices (inverse, cholesky)
+// Validates input shapes for operations on batches of square matrices (inverse, cholesky, lu)
 static inline void squareCheckInputs(const Tensor& self) {
   AT_CHECK(self.size(-1) == self.size(-2),
            "A must be batches of square matrices, "
index f401995..1742981 100644 (file)
@@ -35,18 +35,32 @@ void magmaSolveBatched(
 }
 
 template<class scalar_t>
-void magmaGetrfBatched(
+void magmaLu(
+    magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
+    magma_int_t* ipiv, magma_int_t* info) {
+  AT_ERROR("lu only takes float or double Tensors");
+}
+
+template<class scalar_t>
+void magmaLuBatched(
     magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
     magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
-  AT_ERROR("getrf only takes float or double Tensors");
+  AT_ERROR("lu only takes float or double Tensors");
+}
+
+template<class scalar_t>
+void magmaLuNoPiv(
+    magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
+    magma_int_t* info) {
+  AT_ERROR("lu only takes float or double Tensors");
 }
 
 template<class scalar_t>
-void magmaGetrfNoPivBatched(
+void magmaLuNoPivBatched(
     magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
     magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
-  AT_ERROR("getrf only takes float or double Tensors");
+  AT_ERROR("lu only takes float or double Tensors");
 }
 
 template<class scalar_t>
@@ -131,7 +145,21 @@ void magmaSolveBatched<float>(
 }
 
 template<>
-void magmaGetrfBatched<double>(
+void magmaLu<double>(
+    magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
+    magma_int_t* ipiv, magma_int_t* info) {
+  magma_dgetrf_gpu(m, n, dA, ldda, ipiv, info);
+}
+
+template<>
+void magmaLu<float>(
+    magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
+    magma_int_t* ipiv, magma_int_t* info) {
+  magma_sgetrf_gpu(m, n, dA, ldda, ipiv, info);
+}
+
+template<>
+void magmaLuBatched<double>(
     magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
     magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
@@ -139,7 +167,7 @@ void magmaGetrfBatched<double>(
 }
 
 template<>
-void magmaGetrfBatched<float>(
+void magmaLuBatched<float>(
     magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
     magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
@@ -147,14 +175,28 @@ void magmaGetrfBatched<float>(
 }
 
 template<>
-void magmaGetrfNoPivBatched<double>(
+void magmaLuNoPiv<double>(
+    magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
+    magma_int_t* info) {
+  magma_dgetrf_nopiv_gpu(m, n, dA, ldda, info);
+}
+
+template<>
+void magmaLuNoPiv<float>(
+    magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
+    magma_int_t* info) {
+  magma_sgetrf_nopiv_gpu(m, n, dA, ldda, info);
+}
+
+template<>
+void magmaLuNoPivBatched<double>(
     magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
     magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
   magma_dgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
 }
 
 template<>
-void magmaGetrfNoPivBatched<float>(
+void magmaLuNoPivBatched<float>(
     magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
     magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
   magma_sgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
@@ -373,7 +415,7 @@ AT_ERROR("inverse: MAGMA library not found in "
   }
 
   MAGMAQueue magma_queue(self.get_device());
-  magmaGetrfBatched<scalar_t>(
+  magmaLuBatched<scalar_t>(
     n, n, self_array, n, ipiv_array, info_array,
     batch_size, magma_queue);
 
@@ -527,75 +569,96 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
   }
 }
 
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
-static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
+static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) {
 #ifndef USE_MAGMA
-AT_ERROR("btrifact: MAGMA library not found in "
+AT_ERROR("lu: MAGMA library not found in "
     "compilation. Please rebuild with MAGMA.");
 #else
   auto self_data = self.data<scalar_t>();
-  auto self_matrix_stride = matrixStride(self);
-  magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
   magma_int_t n = magma_int_cast(self.size(-1), "n");
 
-  scalar_t** self_array;
-  ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
-
-  // Set up the created arrays
-  for (int64_t i = 0; i < batch_size; i++) {
-    self_array[i] = &self_data[i * self_matrix_stride];
-  }
+  if (self.dim() == 2) {
+    // If `pivots` is defined, then we have to compute them.
+    // We will use the normal getrf function to compute the LU factorization
+    // and the pivots
+    // We create temporary tensors on the CPU, because tensors on the GPU
+    // cause segfault when passed to magmaLu and magmaLuNoPiv. The data is later
+    // copied to the appropriate tensors.
+    Tensor info_tmp = at::zeros({}, at::kInt);
+    if (get_pivots) {
+      Tensor piv_tmp = at::empty({n}, at::kInt);
+      magmaLu<scalar_t>(
+        n, n, self_data, n, piv_tmp.data<magma_int_t>(), info_tmp.data<magma_int_t>());
+      pivots.copy_(piv_tmp);
+    } else {
+      magmaLuNoPiv<scalar_t>(n, n, self_data, n, info_tmp.data<magma_int_t>());
+    }
+    infos.copy_(info_tmp);
+  } else {
+    auto self_matrix_stride = matrixStride(self);
+    magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
 
-  MAGMAQueue magma_queue(self.get_device());
+    scalar_t** self_array;
+    ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
 
-  // If `pivots` is defined, then we have to compute them.
-  // We will use the normal getrf function to compute the LU factorization
-  // and the pivots
-  if (pivots.defined()) {
-    auto pivots_data = pivots.data<magma_int_t>();
-    auto pivots_matrix_stride = pivots.size(-1);
-    magma_int_t** pivots_array;
-    ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
+    // Set up the created arrays
     for (int64_t i = 0; i < batch_size; i++) {
-      pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
+      self_array[i] = &self_data[i * self_matrix_stride];
     }
 
-    magmaGetrfBatched<scalar_t>(
-      n, n, self_array, n, pivots_array,
-      infos.data<magma_int_t>(), batch_size, magma_queue);
-  } else {
-    magmaGetrfNoPivBatched<scalar_t>(
-      n, n, self_array, n, infos.data<magma_int_t>(),
-      batch_size, magma_queue);
+    MAGMAQueue magma_queue(self.get_device());
+
+    // Same comment as in the case of single matrix above.
+    if (get_pivots) {
+      auto pivots_data = pivots.data<magma_int_t>();
+      auto pivots_matrix_stride = pivots.size(-1);
+      magma_int_t** pivots_array;
+      ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
+      for (int64_t i = 0; i < batch_size; i++) {
+        pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
+      }
+      magmaLuBatched<scalar_t>(
+        n, n, self_array, n, pivots_array,
+        infos.data<magma_int_t>(), batch_size, magma_queue);
+    } else {
+      magmaLuNoPivBatched<scalar_t>(
+        n, n, self_array, n, infos.data<magma_int_t>(),
+        batch_size, magma_queue);
+    }
   }
 #endif
 }
 
-std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, bool pivot) {
-  AT_CHECK(self.dim() > 2,
-           "expected tensor with more than 2 dimensions, got size: ", self.sizes(),
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool pivot, bool check_errors) {
+  AT_CHECK(self.dim() >= 2,
+           "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
            " instead");
   squareCheckInputs(self);
   auto req_size = self.sizes().vec();
   req_size.pop_back();
-  Tensor pivots_tensor;
-  if (pivot) {
-    pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
-  }
+  Tensor pivots_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
   req_size.pop_back();
-  auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
+  auto infos_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
 
   Tensor self_working_copy;
   if (self.numel() == 0) {
     self_working_copy = at::empty_like(self);
   } else {
     self_working_copy = cloneBatchedColumnMajor(self);
-    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cuda", [&]{
-      apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
+    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cuda", [&]{
+      apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor, pivot);
     });
   }
+  if (check_errors) {
+    if (self.dim() == 2) {
+      singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
+    } else {
+      batchCheckErrors(infos_tensor, "lu");
+    }
+  }
   return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
 }
 
index 152a2ae..e4200ab 100644 (file)
   matches_jit_signature: True
   variants: method, function
 
-- func: btrifact(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots) -> (Tensor(a!), Tensor(b!))
-  matches_jit_signature: True
-
-- func: btrifact(Tensor self, *, bool pivot=True) -> (Tensor, Tensor)
-  matches_jit_signature: True
-  variants: method, function
-
-- func: btrifact_with_info(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!), Tensor(b!), Tensor(c!))
-  matches_jit_signature: True
-
-- func: btrifact_with_info(Tensor self, *, bool pivot=True) -> (Tensor, Tensor, Tensor)
-  matches_jit_signature: True
-  variants: method, function
-
-- func: _btrifact_helper(Tensor self, bool pivot) -> (Tensor, Tensor, Tensor)
+- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
   matches_jit_signature: True
   variants: function
   dispatch:
-    CPU: _btrifact_helper_cpu
-    CUDA: _btrifact_helper_cuda
+    CPU: _lu_with_info_cpu
+    CUDA: _lu_with_info_cuda
 
 - func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
index 2bba064..b9ad547 100644 (file)
@@ -306,6 +306,7 @@ view of a storage and defines numeric operations on it.
    .. automethod:: long
    .. automethod:: lt
    .. automethod:: lt_
+   .. automethod:: lu
    .. automethod:: map_
    .. automethod:: masked_scatter_
    .. automethod:: masked_scatter
index 1013128..30a09dc 100644 (file)
@@ -315,6 +315,7 @@ BLAS and LAPACK Operations
 .. autofunction:: det
 .. autofunction:: logdet
 .. autofunction:: slogdet
+.. autofunction:: lu
 .. autofunction:: matmul
 .. autofunction:: matrix_power
 .. autofunction:: matrix_rank
index 72a3157..e919cf9 100644 (file)
@@ -2361,8 +2361,8 @@ class TestCuda(TestCase):
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_btrifact(self):
-        _TestTorchMixin._test_btrifact(self, lambda t: t.cuda())
+    def test_lu(self):
+        _TestTorchMixin._test_lu(self, lambda t: t.cuda())
 
     @skipIfRocm
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
index 93da20c..bd0ac1a 100644 (file)
@@ -1746,44 +1746,43 @@ class _TestTorchMixin(object):
             _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64))
 
     @staticmethod
-    def _test_btrifact(self, cast):
+    def _test_lu(self, cast):
         from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
 
         def run_test(matrix_size, batches, cast):
             a = cast(fullrank(matrix_size, *batches))
-            a_LU_info, pivots_info, info_ = a.btrifact_with_info()
+            a_LU_info, pivots_info, info_ = a.lu(get_infos=True)
             self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
             self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
             self.assertEqual(info_.size(), torch.Size(batches))
             self.assertEqual(info_.abs().sum(), 0)
-            a_LU, pivots = a.btrifact()
+            a_LU, pivots = a.lu()
             self.assertEqual(a_LU, a_LU_info)
             self.assertEqual(pivots_info, pivots)
             if a.is_cuda:
-                a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False)
-                self.assertIsNone(nopiv)
+                a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True)
+                self.assertEqual(nopiv, cast(torch.zeros(a.shape[:-1], dtype=torch.int32)))
                 self.assertEqual(info_, info_nopiv)
             P, L, U = torch.btriunpack(a_LU, pivots)
             self.assertEqual(P.matmul(L.matmul(U)), a)
 
-        for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]):
+        for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]):
             run_test(ms, batch, cast)
 
         # Info should be positive for rank deficient matrices
-        a = cast(fullrank(3, 5))
+        a = cast(torch.ones(5, 3, 3))
         if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])):
-            a[0, 1] = 2 * a[0, 0]  # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency
-            self.assertGreater(a.btrifact_with_info()[2][0], 0)
+            self.assertGreater(a.lu(get_infos=True)[2][0], 0)
 
         # Error checking, no pivoting variant on CPU
         with self.assertRaisesRegex(RuntimeError,
-                                    'btrifact without pivoting is not implemented on the CPU'):
-            torch.btrifact(torch.empty(1, 2, 2), pivot=False)
+                                    'lu without pivoting is not implemented on the CPU'):
+            torch.lu(torch.empty(1, 2, 2), pivot=False)
 
     @skipIfNoLapack
     @skipIfRocm
-    def test_btrifact(self):
-        self._test_btrifact(self, lambda t: t)
+    def test_lu(self):
+        self._test_lu(self, lambda t: t)
 
     @staticmethod
     def _test_btrisolve(self, cast):
@@ -1797,7 +1796,7 @@ class _TestTorchMixin(object):
                                (-1.56, 4.00),
                                (9.81, -4.09)))
         a, b = cast(a), cast(b)
-        LU_data, pivots, info = a.btrifact_with_info()
+        LU_data, pivots, info = a.lu(get_infos=True)
         self.assertEqual(info.abs().sum(), 0)
         x = torch.btrisolve(b, LU_data, pivots)
         b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
@@ -1811,12 +1810,11 @@ class _TestTorchMixin(object):
     def _test_btriunpack(self, cast):
         def run_test(shape, cast):
             a = cast(torch.randn(*shape))
-            a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1]))
-            a_lu = a_lu.reshape_as(a)
-            p = p.reshape(a.shape[:-1])
+            a_lu, p = torch.lu(a)
             p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
             self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
 
+        run_test((3, 3), cast)
         run_test((5, 3, 3), cast)
         run_test((7, 3, 5, 5), cast)
         run_test((7, 5, 3, 3, 3), cast)
@@ -4743,11 +4741,14 @@ class _TestTorchMixin(object):
             A = torch.randn(3, 3, device=A_device)
             err_str = "Expected b and A to be on the same device"
             with self.assertRaisesRegex(RuntimeError, err_str):
-                torch.gesv(b, A)
+                torch.solve(b, A)
 
             with self.assertRaisesRegex(RuntimeError, err_str):
                 torch.cholesky_solve(b, A)
 
+            with self.assertRaisesRegex(RuntimeError, err_str):
+                torch.triangular_solve(b, A)
+
     @skipIfNoLapack
     def test_qr(self):
 
@@ -7965,12 +7966,12 @@ class _TestTorchMixin(object):
             self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
 
             if torch._C.has_lapack:
-                # btrifact
-                A_LU, pivots = fn(torch.btrifact, (0, 5, 5))
+                # lu
+                A_LU, pivots = fn(torch.lu, (0, 5, 5))
                 self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
-                A_LU, pivots = fn(torch.btrifact, (0, 0, 0))
+                A_LU, pivots = fn(torch.lu, (0, 0, 0))
                 self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
-                A_LU, pivots = fn(torch.btrifact, (2, 0, 0))
+                A_LU, pivots = fn(torch.lu, (2, 0, 0))
                 self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
 
     @skipIfRocm
index 4a59185..425a5b5 100644 (file)
   self: grad.bmm(mat2.transpose(1, 2))
   mat2: self.transpose(1, 2).bmm(grad)
 
-- name: btrifact(Tensor self, bool pivot)
-  self: not_implemented("btrifact")
-
-- name: btrifact_with_info(Tensor self, bool pivot)
-  self: not_implemented("btrifact_with_info")
-
 - name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
   self: not_implemented("btrisolve")
 
   self: zeros_like(self)
   other: zeros_like(other)
 
+- name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
+  self: not_implemented("lu_with_info")
+
 - name: masked_fill_(Tensor self, Tensor mask, Scalar value)
   self: grad.clone().masked_fill_(mask, 0)
 
index 80f7495..7201f77 100644 (file)
@@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*', '_thnn_.*',
     'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
-    '_cholesky.*', '_btrifact.*', '_triangular_solve.*',
+    '_cholesky.*', '_triangular_solve.*',
     'slice', 'randint(_out)?',
     'item', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
index 0d842f5..141fbe0 100644 (file)
@@ -78,6 +78,7 @@ class Tensor:
              center=True, pad_mode='reflect', normalized=False, onesided=True): ...
     def split(self, split_size, dim=0): ...
     def unique(self, sorted=True, return_inverse=False, dim=None): ...
+    def lu(self, pivot=True, get_infos=False): ...
 
 ${function_hints}
 
index 289c3f3..bc90d94 100644 (file)
@@ -486,20 +486,6 @@ bmm(batch2) -> Tensor
 See :func:`torch.bmm`
 """)
 
-add_docstr_all('btrifact',
-               r"""
-btrifact(pivot=True) -> (Tensor, Tensor)
-
-See :func:`torch.btrifact`
-""")
-
-add_docstr_all('btrifact_with_info',
-               r"""
-btrifact_with_info(pivot=True) -> (Tensor, Tensor, Tensor)
-
-See :func:`torch.btrifact_with_info`
-""")
-
 add_docstr_all('btrisolve',
                r"""
 btrisolve(LU_data, LU_pivots) -> Tensor
@@ -1019,13 +1005,6 @@ ger(vec2) -> Tensor
 See :func:`torch.ger`
 """)
 
-add_docstr_all('solve',
-               r"""
-solve(A) -> Tensor, Tensor
-
-See :func:`torch.solve`
-""")
-
 add_docstr_all('indices',
                r"""
 indices() -> Tensor
@@ -2228,6 +2207,13 @@ Example::
 
 """)
 
+add_docstr_all('solve',
+               r"""
+solve(A) -> Tensor, Tensor
+
+See :func:`torch.solve`
+""")
+
 add_docstr_all('sort',
                r"""
 sort(dim=-1, descending=False) -> (Tensor, LongTensor)
index 3249126..1a9262b 100644 (file)
@@ -5516,71 +5516,6 @@ Example::
             [ 0.,  0.,  0.]])
 """.format(**factory_like_common_args))
 
-add_docstr(torch.btrifact,
-           r"""
-btrifact(A, pivot=True) -> (Tensor, IntTensor)
-
-Batch LU factorization.
-
-Returns a tuple containing the LU factorization and pivots. Pivoting is done if
-:attr:`pivot` is set.
-
-.. note::
-    LU factorization with :attr:`pivot` = ``True`` is not available for CPU, and attempting
-    to do so will throw an error. However, LU factorization with :attr:`pivot` = ``True`` is
-    available for CUDA.
-
-Arguments:
-    A (Tensor): the tensor to factor
-    pivot (bool, optional): controls whether pivoting is done
-
-Returns:
-    A tuple containing factorization and pivots.
-
-Example::
-
-    >>> A = torch.randn(2, 3, 3)
-    >>> A_LU, pivots = torch.btrifact(A)
-    >>> A_LU
-    tensor([[[ 1.3506,  2.5558, -0.0816],
-             [ 0.1684,  1.1551,  0.1940],
-             [ 0.1193,  0.6189, -0.5497]],
-
-            [[ 0.4526,  1.2526, -0.3285],
-             [-0.7988,  0.7175, -0.9701],
-             [ 0.2634, -0.9255, -0.3459]]])
-
-    >>> pivots
-    tensor([[ 3,  3,  3],
-            [ 3,  3,  3]], dtype=torch.int32)
-""")
-
-add_docstr(torch.btrifact_with_info,
-           r"""
-btrifact_with_info(A, pivot=True) -> (Tensor, IntTensor, IntTensor)
-
-Batch LU factorization with additional error information.
-
-This is a version of :meth:`torch.btrifact` that always creates an info
-`IntTensor`, and returns it as the third return value.
-
-Arguments:
-    A (Tensor): the tensor to factor
-    pivot (bool, optional): controls whether pivoting is done
-
-Returns:
-    A tuple containing factorization, pivots, and an `IntTensor` where non-zero
-    values indicate whether factorization for each minibatch sample succeeds.
-
-Example::
-
-    >>> A = torch.randn(2, 3, 3)
-    >>> A_LU, pivots, info = A.btrifact_with_info()
-    >>> if info.nonzero().size(0) == 0:
-    >>>   print('LU factorization succeeded for all samples!')
-    LU factorization succeeded for all samples!
-""")
-
 add_docstr(torch.btrisolve,
            r"""
 btrisolve(b, LU_data, LU_pivots) -> Tensor
@@ -5591,14 +5526,14 @@ Returns the LU solve of the linear system :math:`Ax = b`.
 
 Arguments:
     b (Tensor): the RHS tensor
-    LU_data (Tensor): the pivoted LU factorization of A from :meth:`btrifact`.
+    LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
     LU_pivots (IntTensor): the pivots of the LU factorization
 
 Example::
 
     >>> A = torch.randn(2, 3, 3)
     >>> b = torch.randn(2, 3)
-    >>> A_LU = torch.btrifact(A)
+    >>> A_LU = torch.lu(A)
     >>> x = torch.btrisolve(b, *A_LU)
     >>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
     tensor(1.00000e-07 *
index 580227c..2fcb46a 100644 (file)
@@ -6,23 +6,26 @@ import warnings
 
 __all__ = [
     'btriunpack',
+    'broadcast_tensors',
+    'btrifact',
+    'btrifact_with_info',
+    'cartesian_prod',
     'chain_matmul',
     'einsum',
-    'broadcast_tensors',
+    'gesv',
     'isfinite',
     'isinf',
+    'lu',
     'norm',
     'meshgrid',
     'potrf',
     'pstrf',
     'potrs',
-    'gesv',
     'split',
     'stft',
     'tensordot',
     'trtrs',
     'unique',
-    'cartesian_prod',
 ]
 
 
@@ -81,7 +84,7 @@ def split(tensor, split_size_or_sections, dim=0):
 
 
 def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
-    r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.
+    r"""Unpacks the data and pivots from a LU factorization of a tensor.
 
     Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
 
@@ -94,7 +97,7 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     Example::
 
         >>> A = torch.randn(2, 3, 3)
-        >>> A_LU, pivots = A.btrifact()
+        >>> A_LU, pivots = A.lu()
         >>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
         >>>
         >>> # can recover A from factorization
@@ -111,13 +114,20 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
         L = U = None
 
     if unpack_pivots:
-        P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
-        LU_pivots = LU_pivots - 1
-        for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
+        LU_pivots_zero_idx = LU_pivots - 1
+        if LU_data.dim() > 2:
+            P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
+            for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
+                final_order = list(range(sz))
+                for k, j in enumerate(LU_pivots_zero_idx[idx]):
+                    final_order[k], final_order[j] = final_order[j], final_order[k]
+                P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
+        else:
+            P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype)
             final_order = list(range(sz))
-            for k, j in enumerate(LU_pivots[idx]):
+            for k, j, in enumerate(LU_pivots_zero_idx):
                 final_order[k], final_order[j] = final_order[j], final_order[k]
-            P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
+            P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
     else:
         P = None
 
@@ -751,6 +761,8 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
     In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
     with the default keyword arguments.
 
+    For more information regarding :func:`torch.trtrs`, please check :func:`torch.triangular_solve`.
+
     .. warning::
         :func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
         removed in the next release. Please use :func:`torch.triangular_solve` instead.
@@ -758,3 +770,112 @@ def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
     warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
                   "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
     return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
+
+
+def btrifact(A, pivot=True, out=None):
+    r"""Returns a tuple containing the LU factorization and pivots of :attr:`A`.
+    Pivoting is done if :attr:`pivot` is set.
+
+    For more information regarding :func:`torch.btrifact`, please check :func:`torch.lu`.
+
+    .. warning::
+        :func:`torch.btrifact` is deprecated in favour of :func:`torch.lu` and will be
+        removed in the next release. Please use :func:`torch.lu` instead.
+    """
+    warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be "
+                  "removed in the next release. Please use torch.lu instead.", stacklevel=2)
+    return lu(A, pivot=pivot, get_infos=False, out=out)
+
+
+def btrifact_with_info(A, pivot=True, out=None):
+    r"""Performs LU factorization and returns additional status information along with the LU
+    factorization and pivots.
+
+    For more information regarding :func:`torch.btrifact_with_info`, please check :func:`torch.lu`.
+
+    .. warning::
+        :func:`torch.btrifact_with_info` is deprecated in favour of :func:`torch.lu` and will
+        be removed in the next release. Please use :func:`torch.lu` with the :attr:`get_infos`
+        argument set to ``True`` instead.
+    """
+    warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu and will be "
+                  "removed in the next release. Please use torch.lu with the get_infos argument "
+                  "set to True instead.",
+                  stacklevel=2)
+    return lu(A, pivot=pivot, get_infos=True, out=out)
+
+
+def lu(A, pivot=True, get_infos=False, out=None):
+    r"""Computes the LU factorization of a square matrix or batches of square matrices
+    :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
+    Pivoting is done if :attr:`pivot` is set to ``True``.
+
+    .. note::
+        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,
+        then the returned pivots is a tensor filled with zeros of the appropriate size.
+
+    .. note::
+        LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting
+        to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is
+        available for CUDA.
+
+    .. note::
+        This function does not check if the factorization was successful or not if
+        :attr:`get_infos` is ``True`` since the status of the factorization is present in the
+        third element of the return tuple.
+
+    Arguments:
+        A (Tensor): the tensor to factor of size :math:`(*, m, m)`
+        pivot (bool, optional): controls whether pivoting is done. Default: ``True``
+        get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
+                                    Default: ``False``
+        out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
+                               then the elements in the tuple are Tensor, IntTensor,
+                               and IntTensor. If :attr:`get_infos` is ``False``, then the
+                               elements in the tuple are Tensor, IntTensor. Default: ``None``
+
+    Returns:
+        (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
+
+            - **factorization** (*Tensor*): the factorization of size :math:`(*, m, m)`
+
+            - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`
+
+            - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
+              size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
+              each minibatch has succeeded or failed
+
+    Example::
+
+        >>> A = torch.randn(2, 3, 3)
+        >>> A_LU, pivots = torch.lu(A)
+        >>> A_LU
+        tensor([[[ 1.3506,  2.5558, -0.0816],
+                 [ 0.1684,  1.1551,  0.1940],
+                 [ 0.1193,  0.6189, -0.5497]],
+
+                [[ 0.4526,  1.2526, -0.3285],
+                 [-0.7988,  0.7175, -0.9701],
+                 [ 0.2634, -0.9255, -0.3459]]])
+        >>> pivots
+        tensor([[ 3,  3,  3],
+                [ 3,  3,  3]], dtype=torch.int32)
+        >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
+        >>> if info.nonzero().size(0) == 0:
+        ...   print('LU factorization succeeded for all samples!')
+        LU factorization succeeded for all samples!
+    """
+    # If get_infos is True, then we don't need to check for errors and vice versa
+    result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
+    if out is not None:
+        if not isinstance(out, (tuple, list)):
+            raise TypeError("argument 'out' must be tuple of Tensors, not {}"
+                            .format(type(out).__name__))
+        if len(out) - int(get_infos) != 2:
+            raise TypeError("expected tuple of {} elements but got {}"
+                            .format(2 + int(get_infos), len(out)))
+        return (out[i].resize_as_(result[i]).copy_(result[i]) for i in range(len(out)))
+    if get_infos:
+        return result  # A_LU, pivots, infos
+    else:
+        return result[0], result[1]  # A_LU, pivots
index bf239b3..4ac4c6b 100644 (file)
@@ -282,10 +282,33 @@ class Tensor(torch._C._TensorBase):
     def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
         r"""See :func:`torch.triangular_solve`"""
         warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
-                      "removed in the next release. Please use torch.triangular_solve.", stacklevel=2)
+                      "removed in the next release. Please use torch.triangular_solve instead.",
+                      stacklevel=2)
         return super(Tensor, self).triangular_solve(A, upper=upper,
                                                     transpose=transpose, unitriangular=unitriangular)
 
+    def btrifact(self, pivot=True):
+        r"""See :func:`torch.lu`"""
+        warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be removed in "
+                      "the next release. Please use torch.lu instead.", stacklevel=2)
+        return torch._lu_with_info(self, pivot=pivot, check_errors=True)
+
+    def btrifact_with_info(self, pivot=True):
+        r"""See :func:`torch.lu`"""
+        warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
+                      "and will be removed in the next release. Please use torch.lu with the "
+                      "get_infos argument set to True instead.", stacklevel=2)
+        return torch._lu_with_info(self, pivot=pivot, check_errors=False)
+
+    def lu(self, pivot=True, get_infos=False):
+        r"""See :func:`torch.lu`"""
+        # If get_infos is True, then we don't need to check for errors and vice versa
+        LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
+        if get_infos:
+            return LU, pivots, infos
+        else:
+            return LU, pivots
+
     def stft(self, n_fft, hop_length=None, win_length=None, window=None,
              center=True, pad_mode='reflect', normalized=False, onesided=True):
         r"""See :func:`torch.stft`