torch.btrifact for tensors with greater than 3 dimensions (#14964)
authorvishwakftw <cs15btech11043@iith.ac.in>
Tue, 12 Mar 2019 08:42:28 +0000 (01:42 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Mar 2019 08:46:07 +0000 (01:46 -0700)
Summary:
Motivation:
- Earlier, `torch.btrifact` could not handle tensors with greater than 3 dimensions. This is because of the check:
>   AT_CHECK(THTensor_(nDimension)(a) == 3, "expected 3D tensor, got size: ", a->sizes());

What is in this PR?:
- Move `btrifact` to ATen
- Remove relation to TH/THC.
- Handle tensors with more than three dimensions
- Tests
- Docs modifications: added a note about the non-pivoting variant.

[blocked due to old magma-cuda binaries]
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14964

Differential Revision: D14405106

Pulled By: soumith

fbshipit-source-id: f051f5d6aaa45f85836a2867176c065733563184

14 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/LinearAlgebra.cpp
aten/src/ATen/native/LinearAlgebraUtils.h
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
aten/src/TH/generic/THTensorLapack.cpp
aten/src/TH/generic/THTensorLapack.h
aten/src/THC/generic/THCTensorMathBlas.cu
aten/src/THC/generic/THCTensorMathBlas.h
test/test_torch.py
tools/autograd/gen_python_functions.py
torch/_torch_docs.py

index 9b07dfc..7063a1d 100644 (file)
       default: N
 ]]
 [[
-  name: _th_btrifact
-  cname: btrifact
-  types:
-    - floating_point
-  backends:
-    - CPU
-    - CUDA
-  variants:
-    - function
-  return: argument 0,1
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - arg: THIntegerTensor* pivots
-      output: True
-    - CONSTANT NULL
-    - arg: bool pivot
-      kwarg_only: True
-      default: "true"
-    - THTensor* self
-]]
-[[
-  name: _th_btrifact_with_info
-  cname: btrifact
-  types:
-    - floating_point
-  backends:
-    - CPU
-    - CUDA
-  variants:
-    - function
-  return: argument 0,1,2
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - arg: THIntegerTensor* pivots
-      output: True
-    - arg: THIntegerTensor* info
-      output: True
-    - arg: bool pivot
-      kwarg_only: True
-      default: "true"
-    - THTensor* self
-]]
-[[
   name: _th_btrisolve
   cname: btrisolve
   types:
index 767368a..31667be 100644 (file)
@@ -391,6 +391,98 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
   return result;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<typename scalar_t>
+static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
+#ifndef USE_LAPACK
+  AT_ERROR("btrifact: LAPACK library not found in compilation");
+#else
+  auto self_data = self.data<scalar_t>();
+  auto self_matrix_stride = matrixStride(self);
+  auto batch_size = batchCount(self);
+
+  auto pivots_data = pivots.data<int>();
+  auto pivots_matrix_stride = pivots.size(-1);
+  auto infos_data = infos.data<int>();
+
+  auto n = self.size(-1);
+
+  for (int64_t i = 0; i < batch_size; i++) {
+    scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
+    int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
+    int* infos_working_ptr = &infos_data[i];
+    lapackGetrf<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
+  }
+#endif
+}
+
+std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool pivot) {
+  AT_CHECK(pivot, "btrifact without pivoting is not implemented on the CPU");
+  AT_CHECK(self.dim() > 2,
+           "expected tensor with more than 2 dimensions, got size: ", self.sizes(),
+           " instead");
+  squareCheckInputs(self);
+  auto req_size = self.sizes().vec();
+  req_size.pop_back();
+  auto pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
+  req_size.pop_back();
+  auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
+
+  Tensor self_working_copy;
+  if (self.numel() == 0) {
+    self_working_copy = at::empty_like(self);
+  } else {
+    self_working_copy = cloneBatchedColumnMajor(self);
+    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cpu", [&]{
+      apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
+    });
+  }
+  return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
+}
+
+std::tuple<Tensor, Tensor> btrifact(const Tensor& self, bool pivot) {
+  Tensor LU_fact, pivots, infos;
+  std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
+  batchCheckErrors(infos, "btrifact");
+  return std::make_tuple(LU_fact, pivots);
+}
+
+std::tuple<Tensor&, Tensor&> btrifact_out(
+    Tensor& A_LU,
+    Tensor& pivots,
+    const Tensor& self,
+    bool pivot) {
+  Tensor infos, A_LU_tmp, pivots_tmp;
+  std::tie(A_LU_tmp, pivots_tmp, infos) = at::_btrifact_helper(self, pivot);
+  batchCheckErrors(infos, "btrifact");
+  A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
+  pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
+  return std::tuple<Tensor&, Tensor&>(A_LU, pivots);
+}
+
+std::tuple<Tensor, Tensor, Tensor> btrifact_with_info(
+    const Tensor& self,
+    bool pivot) {
+  Tensor LU_fact, pivots, infos;
+  std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
+  return std::make_tuple(LU_fact, pivots, infos);
+}
+
+std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
+    Tensor& A_LU,
+    Tensor& pivots,
+    Tensor& info,
+    const Tensor& self,
+    bool pivot) {
+  Tensor info_tmp, A_LU_tmp, pivots_tmp;
+  std::tie(A_LU_tmp, pivots_tmp, info_tmp) = at::_btrifact_helper(self, pivot);
+  A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
+  pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
+  info.resize_as_(info_tmp).copy_(info_tmp);
+  return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
+}
+
 template <typename scalar_t, bool upper>
 static void apply_triu_tril_single(
     scalar_t* result, scalar_t* self, bool inplace,
index 15aeb33..618af2d 100644 (file)
@@ -504,22 +504,6 @@ Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3,
   return at::legacy::th::_th_ormqr(self, input2, input3, left, transpose);
 }
 
-std::tuple<Tensor &,Tensor &> btrifact_out(Tensor & A_LU, Tensor & pivots, const Tensor & self, bool pivot) {
-  return at::legacy::th::_th_btrifact_out(A_LU, pivots, self, pivot);
-}
-
-std::tuple<Tensor,Tensor> btrifact(const Tensor & self, bool pivot) {
-  return at::legacy::th::_th_btrifact(self, pivot);
-}
-
-std::tuple<Tensor &,Tensor &,Tensor &> btrifact_with_info_out(Tensor & A_LU, Tensor & pivots, Tensor & info, const Tensor & self, bool pivot) {
-  return at::legacy::th::_th_btrifact_with_info_out(A_LU, pivots, info, self, pivot);
-}
-
-std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(const Tensor & self, bool pivot) {
-  return at::legacy::th::_th_btrifact_with_info(self, pivot);
-}
-
 Tensor & btrisolve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) {
   return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots);
 }
index c39fcbb..9ec2165 100644 (file)
@@ -67,7 +67,7 @@ Tensor logdet(const Tensor& self) {
   if (det.sign().item<double>() <= 0) {
     return det.log_();  // in order to get proper -inf (det=0) or nan (det<0)
   } else {
-    return diag_U.abs().log().sum();
+    return diag_U.abs_().log_().sum();
   }
 }
 
@@ -81,11 +81,12 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
   int info;
   std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
   if (info > 0) {
-    det = at::zeros({}, self.type());
+    return std::make_tuple(at::zeros({}, self.options()),
+                           at::empty({}, self.options()).fill_(-INFINITY));
   } else {
     det = diag_U.prod().mul_(det_P);
+    return std::make_tuple(det.sign(), diag_U.abs_().log_().sum());
   }
-  return std::make_tuple(det.sign(), diag_U.abs_().log_().sum());
 }
 
 Tensor pinverse(const Tensor& self, double rcond) {
index 08a0822..4758966 100644 (file)
@@ -111,6 +111,23 @@ static inline void batchCheckErrors(std::vector<int64_t>& infos, const char* nam
 }
 
 /*
+ * This is an overloaded case of the previous function for a tensor of infos.
+ */
+static inline void batchCheckErrors(const Tensor& infos, const char* name) {
+  auto batch_size = infos.numel();
+  auto infos_cpu = infos.to(at::kCPU);
+  auto infos_data = infos_cpu.data<int>();
+  for (size_t i = 0; i < batch_size; i++) {
+    auto info = infos_data[i];
+    if (info < 0) {
+      AT_ERROR(name, ": For batch ", i, ": Argument ", -info, " has illegal value");
+    } else if (info > 0) {
+      AT_ERROR(name, ": For batch ", i, ": U(", info, ",", info, ") is zero, singular U.");
+    }
+  }
+}
+
+/*
  * Given a info int, obtained after a single operation, this function check if the computation
  * has been successful (info = 0) or not, and report in case of the latter.
  */
index 291e0b5..195df2d 100644 (file)
@@ -43,6 +43,13 @@ void magmaGetrfBatched(
 }
 
 template<class scalar_t>
+void magmaGetrfNoPivBatched(
+    magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
+    magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  AT_ERROR("getrf only takes float or double Tensors");
+}
+
+template<class scalar_t>
 void magmaGetriBatched(
     magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
     magma_int_t** ipiv_array, scalar_t** dinvA_array, magma_int_t lddia,
@@ -125,6 +132,20 @@ void magmaGetrfBatched<float>(
 }
 
 template<>
+void magmaGetrfNoPivBatched<double>(
+    magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
+    magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  magma_dgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
+}
+
+template<>
+void magmaGetrfNoPivBatched<float>(
+    magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
+    magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  magma_sgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
+}
+
+template<>
 void magmaGetriBatched<double>(
     magma_int_t n, double** dA_array, magma_int_t ldda,
     magma_int_t** ipiv_array, double** dinvA_array, magma_int_t lddia,
@@ -274,7 +295,7 @@ std::tuple<Tensor, Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
-static void apply_inverse(Tensor &self, Tensor &self_inv, std::vector<int64_t>& infos) {
+static void apply_inverse(Tensor& self, Tensor& self_inv, std::vector<int64_t>& infos) {
 #ifndef USE_MAGMA
 AT_ERROR("inverse: MAGMA library not found in "
     "compilation. Please rebuild with MAGMA.");
@@ -461,6 +482,78 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
   }
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <typename scalar_t>
+static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
+#ifndef USE_MAGMA
+AT_ERROR("btrifact: MAGMA library not found in "
+    "compilation. Please rebuild with MAGMA.");
+#else
+  auto self_data = self.data<scalar_t>();
+  auto self_matrix_stride = matrixStride(self);
+  magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
+  magma_int_t n = magma_int_cast(self.size(-1), "n");
+
+  scalar_t** self_array;
+  ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
+
+  // Set up the created arrays
+  for (int64_t i = 0; i < batch_size; i++) {
+    self_array[i] = &self_data[i * self_matrix_stride];
+  }
+
+  MAGMAQueue magma_queue(self.get_device());
+
+  // If `pivots` is defined, then we have to compute them.
+  // We will use the normal getrf function to compute the LU factorization
+  // and the pivots
+  if (pivots.defined()) {
+    auto pivots_data = pivots.data<magma_int_t>();
+    auto pivots_matrix_stride = pivots.size(-1);
+    magma_int_t** pivots_array;
+    ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
+    for (int64_t i = 0; i < batch_size; i++) {
+      pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
+    }
+
+    magmaGetrfBatched<scalar_t>(
+      n, n, self_array, n, pivots_array,
+      infos.data<magma_int_t>(), batch_size, magma_queue);
+  } else {
+    magmaGetrfNoPivBatched<scalar_t>(
+      n, n, self_array, n, infos.data<magma_int_t>(),
+      batch_size, magma_queue);
+  }
+#endif
+}
+
+std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, bool pivot) {
+  AT_CHECK(self.dim() > 2,
+           "expected tensor with more than 2 dimensions, got size: ", self.sizes(),
+           " instead");
+  squareCheckInputs(self);
+  auto req_size = self.sizes().vec();
+  req_size.pop_back();
+  Tensor pivots_tensor;
+  if (pivot) {
+    pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
+  }
+  req_size.pop_back();
+  auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
+
+  Tensor self_working_copy;
+  if (self.numel() == 0) {
+    self_working_copy = at::empty_like(self);
+  } else {
+    self_working_copy = cloneBatchedColumnMajor(self);
+    AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cuda", [&]{
+      apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
+    });
+  }
+  return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
+}
+
 template <typename scalar_t, bool upper>
 __global__
 void triu_tril_kernel(
index ce75d82..2da44a9 100644 (file)
   matches_jit_signature: True
   variants: method, function
 
+- func: _btrifact_helper(Tensor self, bool pivot) -> (Tensor, Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: _btrifact_helper_cpu
+    CUDA: _btrifact_helper_cuda
+
 - func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
index 1fbea8e..6187166 100644 (file)
@@ -879,83 +879,6 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co
   c10::raw::intrusive_ptr::decref(work);
 }
 
-void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinfo_, int pivot, THTensor *a)
-{
-  AT_CHECK(THTensor_(nDimension)(a) == 3, "expected 3D tensor, got size: ", a->sizes());
-  if (!pivot) {
-    THError("btrifact without pivoting is not implemented on the CPU");
-  }
-
-  if (ra_ != a) {
-    THTensor_(resizeAs)(ra_, a);
-    at::Tensor ra__wrap = THTensor_wrap(ra_);
-    at::Tensor a_wrap = THTensor_wrap(a);
-    at::_copy_same_type_(ra__wrap, a_wrap);
-  }
-
-  int m = a->size(1);
-  int n = a->size(2);
-  if (m != n) {
-    THError("btrifact is only implemented for square matrices");
-  }
-  int64_t num_batches = THTensor_(size)(a, 0);
-  THTensor *ra__;
-  int lda;
-
-  if (ra_->stride(1) == 1) {
-    // column ordered, what BLAS wants
-    lda = ra_->stride(2);
-    ra__ = ra_;
-  } else {
-    // not column ordered, need to make it such (requires copy)
-    THTensor *transp_r_ = THTensor_(newTranspose)(ra_, 1, 2);
-    ra__ = THTensor_(newClone)(transp_r_);
-    c10::raw::intrusive_ptr::decref(transp_r_);
-    THTensor_(transpose)(ra__, NULL, 1, 2);
-    lda = ra__->stride(2);
-  }
-
-  THTensor *ai = THTensor_(new)();
-  THTensor *rai = THTensor_(new)();
-  THIntTensor *rpivoti = THIntTensor_new();
-
-  int info = 0;
-  int *info_ptr = &info;
-  if (rinfo_) {
-    THIntTensor_resize1d(rinfo_, num_batches);
-    info_ptr = THIntTensor_data(rinfo_);
-  }
-
-  THIntTensor_resize2d(rpivots_, num_batches, n);
-
-  int64_t batch = 0;
-  for (; batch < num_batches; ++batch) {
-    THTensor_(select)(ai, a, 0, batch);
-    THTensor_(select)(rai, ra__, 0, batch);
-    THIntTensor_select(rpivoti, rpivots_, 0, batch);
-
-    THLapack_(getrf)(n, n, rai->data<scalar_t>(), lda,
-                     THIntTensor_data(rpivoti), info_ptr);
-    if (rinfo_) {
-      info_ptr++;
-    } else if (info != 0) {
-      break;
-    }
-  }
-
-  c10::raw::intrusive_ptr::decref(ai);
-  c10::raw::intrusive_ptr::decref(rai);
-  THIntTensor_free(rpivoti);
-
-  if (ra__ != ra_) {
-    THTensor_(freeCopyTo)(ra__, ra_);
-  }
-
-  if (!rinfo_ && info != 0) {
-    THError("failed to factorize batch element %ld (info == %d)", batch, info);
-  }
-}
-
 void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots)
 {
   AT_CHECK(!atf->is_empty() && THTensor_(nDimensionLegacyNoScalars)(atf) == 3, "expected non-empty 3D tensor, got size: ",
index 2444307..d337b82 100644 (file)
@@ -17,7 +17,6 @@ TH_API void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau);
 TH_API void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, const char *side, const char *trans);
 TH_API void THTensor_(pstrf)(THTensor *ra_, THIntTensor *rpiv_, THTensor*a, const char* uplo, scalar_t tol);
 
-TH_API void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinfo_, int pivot, THTensor *a);
 TH_API void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots);
 
 #endif
index 7092ddf..53dd565 100644 (file)
@@ -745,114 +745,6 @@ void THCTensor_(baddbmm)(THCState *state, THCTensor *result, scalar_t beta, THCT
 #endif
 }
 
-void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a)
-{
-#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
-  THAssert(THCTensor_(checkGPU)(state, 2, ra_, a));
-  THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, a) == 3, 3, "expected 3D tensor");
-  THArgCheck(THCTensor_(size)(state, a, 1) ==
-             THCTensor_(size)(state, a, 2), 3, "matrices must be square");
-
-  if (ra_ != a) {
-    THCTensor_(resizeAs)(state, ra_, a);
-    if (ra_->stride(2) == 1) {
-      THCTensor_(transpose)(state, ra_, NULL, 1, 2);
-    }
-    THCTensor_(copy)(state, ra_, a);
-  }
-
-
-  int n = a->size(1);
-  int lda;
-  THCTensor *ra__;
-
-  if (ra_->stride(1) == 1) {
-    // column ordered, what BLAS wants
-    lda = ra_->stride(2);
-    ra__ = ra_;
-  } else {
-    // not column ordered, need to make it such (requires copy)
-    THCTensor *transp_r_ = THCTensor_(newTranspose)(state, ra_, 1, 2);
-    ra__ = THCTensor_(newClone)(state, transp_r_);
-    THCTensor_(free)(state, transp_r_);
-    THCTensor_(transpose)(state, ra__, NULL, 1, 2);
-    lda = ra__->stride(2);
-  }
-
-  int64_t num_batches = ra__->size(0);
-
-  if (!pivot) {
-    THCudaIntTensor *t = THCudaIntTensor_new(state);
-    auto t_aten = THTensor_wrap(t);
-    at::range_out(t_aten, 1, n, 1);
-    THCudaIntTensor_unsqueeze1d(state, t, t, 0);
-    THCudaIntTensor** ptrs = (THCudaIntTensor**) THAlloc(sizeof(THCudaIntTensor*)*num_batches);
-    for (int64_t i=0; i<num_batches; i++) {
-      ptrs[i] = t;
-    }
-    THCudaIntTensor_catArray(state, rpivots_, ptrs, num_batches, 0);
-    THCudaIntTensor_free(state, t);
-    THFree(ptrs);
-  } else {
-    THCudaIntTensor_resize2d(state, rpivots_, num_batches, n);
-  }
-
-  bool free_rinfo_ = !rinfo_;
-  if (rinfo_ == NULL) rinfo_ = THCudaIntTensor_new(state);
-  THCudaIntTensor_resize1d(state, rinfo_, num_batches);
-  // THCudaBlas_Sgetrf,THCudaBlas_Dgetrf will NOT write out the info if the dimensionality is 0;
-  // we could check this explicitly, but this seems safer.
-  THCudaIntTensor_zero(state, rinfo_);
-  int *info_gpu = THCudaIntTensor_data(state, rinfo_);
-
-  // Copy pointers to device.
-  size_t matrices_size = num_batches * sizeof(scalar_t*);
-  auto d_result = static_cast<scalar_t**>(THCudaMalloc(state, matrices_size));
-
-  if (num_batches > 0) {
-    const int64_t block = 512;
-    const int64_t grid = (num_batches + block - 1) / block;
-    createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-      (const scalar_t**)d_result, THCTensor_(data)(state, ra__),
-      ra__->stride(0), num_batches);
-  }
-
-  int *pivots_gpu = NULL;
-  if (pivot) {
-    pivots_gpu = THCudaIntTensor_data(state, rpivots_);
-  }
-#ifdef THC_REAL_IS_FLOAT
-  THCudaBlas_Sgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches);
-#elif defined(THC_REAL_IS_DOUBLE)
-  THCudaBlas_Dgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches);
-#endif
-
-  THCudaFree(state, d_result);
-
-  if (ra__ != ra_) {
-    THCTensor_(freeCopyTo)(state, ra__, ra_);
-  }
-
-  if (free_rinfo_) {
-    if(THCTensor_nElement(state, rinfo_) != 0) {
-      int min = THCudaIntTensor_minall(state, rinfo_);
-      int max = THCudaIntTensor_maxall(state, rinfo_);
-      THCudaIntTensor_free(state, rinfo_);
-      if (min != 0 || max != 0) {
-        THError("failed to factorize some batch elements (min info == %d, max info == %d)",
-                min, max);
-      }
-    } else {
-      THCudaIntTensor_free(state, rinfo_);
-    }
-  }
-
-#else
-  THError("btrifact for CUDA tensors is only supported for floats and doubles");
-#endif
-}
-
-
 void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b,
                            THCTensor *atf, THCudaIntTensor *pivots)
 {
index 6f526b8..e112764 100644 (file)
@@ -9,7 +9,6 @@ THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, scalar_t beta, T
 THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *batch1, THCTensor *batch2);
 THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *batch1, THCTensor *batch2);
 
-THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a);
 THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *atf, THCudaIntTensor *pivots);
 
 
index 1a8787d..a50ab5c 100644 (file)
@@ -1728,24 +1728,41 @@ class _TestTorchMixin(object):
 
     @staticmethod
     def _test_btrifact(self, cast):
-        a = torch.FloatTensor((((1.3722, -0.9020),
-                                (1.8849, 1.9169)),
-                               ((0.7187, -1.1695),
-                                (-0.0139, 1.3572)),
-                               ((-1.6181, 0.7148),
-                                (1.3728, 0.1319))))
-        a = cast(a)
-        a_LU, pivots = a.btrifact()
-
-        a_LU_, pivots_, info_ = a.btrifact_with_info()
-        self.assertEqual(a_LU, a_LU_)
-        self.assertEqual(pivots, pivots_)
-        self.assertEqual(info_.abs().sum(), 0)
-        P, a_L, a_U = torch.btriunpack(a_LU, pivots)
-        a_ = torch.bmm(P, torch.bmm(a_L, a_U))
-        self.assertEqual(a_, a)
+        from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
+
+        def run_test(matrix_size, batches, cast):
+            a = cast(fullrank(matrix_size, *batches))
+            a_LU_info, pivots_info, info_ = a.btrifact_with_info()
+            self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
+            self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
+            self.assertEqual(info_.size(), torch.Size(batches))
+            self.assertEqual(info_.abs().sum(), 0)
+            a_LU, pivots = a.btrifact()
+            self.assertEqual(a_LU, a_LU_info)
+            self.assertEqual(pivots_info, pivots)
+            if a.is_cuda:
+                a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False)
+                self.assertIsNone(nopiv)
+                self.assertEqual(info_, info_nopiv)
+            P, L, U = torch.btriunpack(a_LU, pivots)
+            self.assertEqual(P.matmul(L.matmul(U)), a)
+
+        for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]):
+            run_test(ms, batch, cast)
+
+        # Info should be positive for rank deficient matrices
+        a = cast(fullrank(3, 5))
+        if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])):
+            a[0, 1] = 2 * a[0, 0]  # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency
+            self.assertGreater(a.btrifact_with_info()[2][0], 0)
+
+        # Error checking, no pivoting variant on CPU
+        with self.assertRaisesRegex(RuntimeError,
+                                    'btrifact without pivoting is not implemented on the CPU'):
+            torch.btrifact(torch.empty(1, 2, 2), pivot=False)
 
     @skipIfNoLapack
+    @skipIfRocm
     def test_btrifact(self):
         self._test_btrifact(self, lambda t: t)
 
@@ -5389,9 +5406,18 @@ class _TestTorchMixin(object):
         eye = conv_fn(torch.eye(5))
         test_single_det(eye, torch.tensor(1, dtype=eye.dtype), 'identity')
 
+        # TODO: Remove when MAGMA 2.5.0 is built for CUDA 8 and CUDA 9.2
+        is_cuda_8_92 = False
+        if torch.cuda.is_available() and torch.version.cuda is not None:
+            is_cuda_8_92 = any(x in torch.version.cuda for x in ['8.0', '9.2'])
+
         def test(M):
             assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
             M = conv_fn(M)
+
+            if M.is_cuda and is_cuda_8_92:
+                return
+
             M_det = M.det()
             ref_M_det = reference_det(M)
 
index bb48b6a..0066183 100644 (file)
@@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*', '_thnn_.*',
     'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*',
-    '_potrs.*', '_cholesky.*',
+    '_potrs.*', '_cholesky.*', '_btrifact.*',
     'slice', 'randint(_out)?',
     'item', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
index 79e704d..30311d1 100644 (file)
@@ -5411,6 +5411,11 @@ Batch LU factorization.
 Returns a tuple containing the LU factorization and pivots. Pivoting is done if
 :attr:`pivot` is set.
 
+.. note::
+    LU factorization with :attr:`pivot` = ``True`` is not available for CPU, and attempting
+    to do so will throw an error. However, LU factorization with :attr:`pivot` = ``True`` is
+    available for CUDA.
+
 Arguments:
     A (Tensor): the tensor to factor
     pivot (bool, optional): controls whether pivoting is done