From f268370b42bc26d4e5d0fd88eb79b44ae162dec5 Mon Sep 17 00:00:00 2001 From: vishwakftw Date: Tue, 12 Mar 2019 01:42:28 -0700 Subject: [PATCH] torch.btrifact for tensors with greater than 3 dimensions (#14964) Summary: Motivation: - Earlier, `torch.btrifact` could not handle tensors with greater than 3 dimensions. This is because of the check: > AT_CHECK(THTensor_(nDimension)(a) == 3, "expected 3D tensor, got size: ", a->sizes()); What is in this PR?: - Move `btrifact` to ATen - Remove relation to TH/THC. - Handle tensors with more than three dimensions - Tests - Docs modifications: added a note about the non-pivoting variant. [blocked due to old magma-cuda binaries] Pull Request resolved: https://github.com/pytorch/pytorch/pull/14964 Differential Revision: D14405106 Pulled By: soumith fbshipit-source-id: f051f5d6aaa45f85836a2867176c065733563184 --- aten/src/ATen/Declarations.cwrap | 45 ---------- aten/src/ATen/native/BatchLinearAlgebra.cpp | 92 ++++++++++++++++++++ aten/src/ATen/native/LegacyDefinitions.cpp | 16 ---- aten/src/ATen/native/LinearAlgebra.cpp | 7 +- aten/src/ATen/native/LinearAlgebraUtils.h | 17 ++++ aten/src/ATen/native/cuda/BatchLinearAlgebra.cu | 95 ++++++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 7 ++ aten/src/TH/generic/THTensorLapack.cpp | 77 ----------------- aten/src/TH/generic/THTensorLapack.h | 1 - aten/src/THC/generic/THCTensorMathBlas.cu | 108 ------------------------ aten/src/THC/generic/THCTensorMathBlas.h | 1 - test/test_torch.py | 58 +++++++++---- tools/autograd/gen_python_functions.py | 2 +- torch/_torch_docs.py | 5 ++ 14 files changed, 262 insertions(+), 269 deletions(-) diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 9b07dfc..7063a1d 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -2499,51 +2499,6 @@ default: N ]] [[ - name: _th_btrifact - cname: btrifact - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - function - return: argument 0,1 - arguments: - - arg: THTensor* result - output: True - - arg: THIntegerTensor* pivots - output: True - - CONSTANT NULL - - arg: bool pivot - kwarg_only: True - default: "true" - - THTensor* self -]] -[[ - name: _th_btrifact_with_info - cname: btrifact - types: - - floating_point - backends: - - CPU - - CUDA - variants: - - function - return: argument 0,1,2 - arguments: - - arg: THTensor* result - output: True - - arg: THIntegerTensor* pivots - output: True - - arg: THIntegerTensor* info - output: True - - arg: bool pivot - kwarg_only: True - default: "true" - - THTensor* self -]] -[[ name: _th_btrisolve cname: btrisolve types: diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 767368a..31667be 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -391,6 +391,98 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) { return result; } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) { +#ifndef USE_LAPACK + AT_ERROR("btrifact: LAPACK library not found in compilation"); +#else + auto self_data = self.data(); + auto self_matrix_stride = matrixStride(self); + auto batch_size = batchCount(self); + + auto pivots_data = pivots.data(); + auto pivots_matrix_stride = pivots.size(-1); + auto infos_data = infos.data(); + + auto n = self.size(-1); + + for (int64_t i = 0; i 
< batch_size; i++) { + scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; + int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride]; + int* infos_working_ptr = &infos_data[i]; + lapackGetrf(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr); + } +#endif +} + +std::tuple _btrifact_helper_cpu(const Tensor& self, bool pivot) { + AT_CHECK(pivot, "btrifact without pivoting is not implemented on the CPU"); + AT_CHECK(self.dim() > 2, + "expected tensor with more than 2 dimensions, got size: ", self.sizes(), + " instead"); + squareCheckInputs(self); + auto req_size = self.sizes().vec(); + req_size.pop_back(); + auto pivots_tensor = at::zeros(req_size, self.options().dtype(kInt)); + req_size.pop_back(); + auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt)); + + Tensor self_working_copy; + if (self.numel() == 0) { + self_working_copy = at::empty_like(self); + } else { + self_working_copy = cloneBatchedColumnMajor(self); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cpu", [&]{ + apply_btrifact(self_working_copy, pivots_tensor, infos_tensor); + }); + } + return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor); +} + +std::tuple btrifact(const Tensor& self, bool pivot) { + Tensor LU_fact, pivots, infos; + std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot); + batchCheckErrors(infos, "btrifact"); + return std::make_tuple(LU_fact, pivots); +} + +std::tuple btrifact_out( + Tensor& A_LU, + Tensor& pivots, + const Tensor& self, + bool pivot) { + Tensor infos, A_LU_tmp, pivots_tmp; + std::tie(A_LU_tmp, pivots_tmp, infos) = at::_btrifact_helper(self, pivot); + batchCheckErrors(infos, "btrifact"); + A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp); + pivots.resize_as_(pivots_tmp).copy_(pivots_tmp); + return std::tuple(A_LU, pivots); +} + +std::tuple btrifact_with_info( + const Tensor& self, + bool pivot) { + Tensor LU_fact, pivots, infos; + std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot); + return std::make_tuple(LU_fact, pivots, infos); +} + +std::tuple btrifact_with_info_out( + Tensor& A_LU, + Tensor& pivots, + Tensor& info, + const Tensor& self, + bool pivot) { + Tensor info_tmp, A_LU_tmp, pivots_tmp; + std::tie(A_LU_tmp, pivots_tmp, info_tmp) = at::_btrifact_helper(self, pivot); + A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp); + pivots.resize_as_(pivots_tmp).copy_(pivots_tmp); + info.resize_as_(info_tmp).copy_(info_tmp); + return std::tuple(A_LU, pivots, info); +} + template static void apply_triu_tril_single( scalar_t* result, scalar_t* self, bool inplace, diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp index 15aeb33..618af2d 100644 --- a/aten/src/ATen/native/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/LegacyDefinitions.cpp @@ -504,22 +504,6 @@ Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, return at::legacy::th::_th_ormqr(self, input2, input3, left, transpose); } -std::tuple btrifact_out(Tensor & A_LU, Tensor & pivots, const Tensor & self, bool pivot) { - return at::legacy::th::_th_btrifact_out(A_LU, pivots, self, pivot); -} - -std::tuple btrifact(const Tensor & self, bool pivot) { - return at::legacy::th::_th_btrifact(self, pivot); -} - -std::tuple btrifact_with_info_out(Tensor & A_LU, Tensor & pivots, Tensor & info, const Tensor & self, bool pivot) { - return at::legacy::th::_th_btrifact_with_info_out(A_LU, pivots, info, self, pivot); -} - -std::tuple btrifact_with_info(const Tensor & self, bool pivot) { - 
return at::legacy::th::_th_btrifact_with_info(self, pivot); -} - Tensor & btrisolve_out(Tensor & result, const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) { return at::legacy::th::_th_btrisolve_out(result, self, LU_data, LU_pivots); } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index c39fcbb..9ec2165 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -67,7 +67,7 @@ Tensor logdet(const Tensor& self) { if (det.sign().item() <= 0) { return det.log_(); // in order to get proper -inf (det=0) or nan (det<0) } else { - return diag_U.abs().log().sum(); + return diag_U.abs_().log_().sum(); } } @@ -81,11 +81,12 @@ std::tuple slogdet(const Tensor& self) { int info; std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self); if (info > 0) { - det = at::zeros({}, self.type()); + return std::make_tuple(at::zeros({}, self.options()), + at::empty({}, self.options()).fill_(-INFINITY)); } else { det = diag_U.prod().mul_(det_P); + return std::make_tuple(det.sign(), diag_U.abs_().log_().sum()); } - return std::make_tuple(det.sign(), diag_U.abs_().log_().sum()); } Tensor pinverse(const Tensor& self, double rcond) { diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 08a0822..4758966 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -111,6 +111,23 @@ static inline void batchCheckErrors(std::vector& infos, const char* nam } /* + * This is an overloaded case of the previous function for a tensor of infos. + */ +static inline void batchCheckErrors(const Tensor& infos, const char* name) { + auto batch_size = infos.numel(); + auto infos_cpu = infos.to(at::kCPU); + auto infos_data = infos_cpu.data(); + for (size_t i = 0; i < batch_size; i++) { + auto info = infos_data[i]; + if (info < 0) { + AT_ERROR(name, ": For batch ", i, ": Argument ", -info, " has illegal value"); + } else if (info > 0) { + AT_ERROR(name, ": For batch ", i, ": U(", info, ",", info, ") is zero, singular U."); + } + } +} + +/* * Given a info int, obtained after a single operation, this function check if the computation * has been successful (info = 0) or not, and report in case of the latter. 
*/ diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 291e0b5..195df2d 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -43,6 +43,13 @@ void magmaGetrfBatched( } template +void magmaGetrfNoPivBatched( + magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda, + magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { + AT_ERROR("getrf only takes float or double Tensors"); +} + +template void magmaGetriBatched( magma_int_t n, scalar_t** dA_array, magma_int_t ldda, magma_int_t** ipiv_array, scalar_t** dinvA_array, magma_int_t lddia, @@ -125,6 +132,20 @@ void magmaGetrfBatched( } template<> +void magmaGetrfNoPivBatched( + magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda, + magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { + magma_dgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue()); +} + +template<> +void magmaGetrfNoPivBatched( + magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda, + magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) { + magma_sgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue()); +} + +template<> void magmaGetriBatched( magma_int_t n, double** dA_array, magma_int_t ldda, magma_int_t** ipiv_array, double** dinvA_array, magma_int_t lddia, @@ -274,7 +295,7 @@ std::tuple _gesv_helper_cuda(const Tensor& self, const Tensor& A // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template -static void apply_inverse(Tensor &self, Tensor &self_inv, std::vector& infos) { +static void apply_inverse(Tensor& self, Tensor& self_inv, std::vector& infos) { #ifndef USE_MAGMA AT_ERROR("inverse: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); @@ -461,6 +482,78 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) { } } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) { +#ifndef USE_MAGMA +AT_ERROR("btrifact: MAGMA library not found in " + "compilation. Please rebuild with MAGMA."); +#else + auto self_data = self.data(); + auto self_matrix_stride = matrixStride(self); + magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount"); + magma_int_t n = magma_int_cast(self.size(-1), "n"); + + scalar_t** self_array; + ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self); + + // Set up the created arrays + for (int64_t i = 0; i < batch_size; i++) { + self_array[i] = &self_data[i * self_matrix_stride]; + } + + MAGMAQueue magma_queue(self.get_device()); + + // If `pivots` is defined, then we have to compute them. 
+ // We will use the normal getrf function to compute the LU factorization + // and the pivots + if (pivots.defined()) { + auto pivots_data = pivots.data(); + auto pivots_matrix_stride = pivots.size(-1); + magma_int_t** pivots_array; + ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots); + for (int64_t i = 0; i < batch_size; i++) { + pivots_array[i] = &pivots_data[i * pivots_matrix_stride]; + } + + magmaGetrfBatched( + n, n, self_array, n, pivots_array, + infos.data(), batch_size, magma_queue); + } else { + magmaGetrfNoPivBatched( + n, n, self_array, n, infos.data(), + batch_size, magma_queue); + } +#endif +} + +std::tuple _btrifact_helper_cuda(const Tensor& self, bool pivot) { + AT_CHECK(self.dim() > 2, + "expected tensor with more than 2 dimensions, got size: ", self.sizes(), + " instead"); + squareCheckInputs(self); + auto req_size = self.sizes().vec(); + req_size.pop_back(); + Tensor pivots_tensor; + if (pivot) { + pivots_tensor = at::zeros(req_size, self.options().dtype(kInt)); + } + req_size.pop_back(); + auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt)); + + Tensor self_working_copy; + if (self.numel() == 0) { + self_working_copy = at::empty_like(self); + } else { + self_working_copy = cloneBatchedColumnMajor(self); + AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cuda", [&]{ + apply_btrifact(self_working_copy, pivots_tensor, infos_tensor); + }); + } + return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor); +} + template __global__ void triu_tril_kernel( diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ce75d82..2da44a9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3777,6 +3777,13 @@ matches_jit_signature: True variants: method, function +- func: _btrifact_helper(Tensor self, bool pivot) -> (Tensor, Tensor, Tensor) + matches_jit_signature: True + variants: function + dispatch: + CPU: _btrifact_helper_cpu + CUDA: _btrifact_helper_cuda + - func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) 
matches_jit_signature: True diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index 1fbea8e..6187166 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -879,83 +879,6 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co c10::raw::intrusive_ptr::decref(work); } -void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinfo_, int pivot, THTensor *a) -{ - AT_CHECK(THTensor_(nDimension)(a) == 3, "expected 3D tensor, got size: ", a->sizes()); - if (!pivot) { - THError("btrifact without pivoting is not implemented on the CPU"); - } - - if (ra_ != a) { - THTensor_(resizeAs)(ra_, a); - at::Tensor ra__wrap = THTensor_wrap(ra_); - at::Tensor a_wrap = THTensor_wrap(a); - at::_copy_same_type_(ra__wrap, a_wrap); - } - - int m = a->size(1); - int n = a->size(2); - if (m != n) { - THError("btrifact is only implemented for square matrices"); - } - int64_t num_batches = THTensor_(size)(a, 0); - THTensor *ra__; - int lda; - - if (ra_->stride(1) == 1) { - // column ordered, what BLAS wants - lda = ra_->stride(2); - ra__ = ra_; - } else { - // not column ordered, need to make it such (requires copy) - THTensor *transp_r_ = THTensor_(newTranspose)(ra_, 1, 2); - ra__ = THTensor_(newClone)(transp_r_); - c10::raw::intrusive_ptr::decref(transp_r_); - THTensor_(transpose)(ra__, NULL, 1, 2); - lda = ra__->stride(2); - } - - THTensor *ai = THTensor_(new)(); - THTensor *rai = THTensor_(new)(); - THIntTensor *rpivoti = THIntTensor_new(); - - int info = 0; - int *info_ptr = &info; - if (rinfo_) { - THIntTensor_resize1d(rinfo_, num_batches); - info_ptr = THIntTensor_data(rinfo_); - } - - THIntTensor_resize2d(rpivots_, num_batches, n); - - int64_t batch = 0; - for (; batch < num_batches; ++batch) { - THTensor_(select)(ai, a, 0, batch); - THTensor_(select)(rai, ra__, 0, batch); - THIntTensor_select(rpivoti, rpivots_, 0, batch); - - THLapack_(getrf)(n, n, rai->data(), lda, - THIntTensor_data(rpivoti), info_ptr); - if (rinfo_) { - info_ptr++; - } else if (info != 0) { - break; - } - } - - c10::raw::intrusive_ptr::decref(ai); - c10::raw::intrusive_ptr::decref(rai); - THIntTensor_free(rpivoti); - - if (ra__ != ra_) { - THTensor_(freeCopyTo)(ra__, ra_); - } - - if (!rinfo_ && info != 0) { - THError("failed to factorize batch element %ld (info == %d)", batch, info); - } -} - void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots) { AT_CHECK(!atf->is_empty() && THTensor_(nDimensionLegacyNoScalars)(atf) == 3, "expected non-empty 3D tensor, got size: ", diff --git a/aten/src/TH/generic/THTensorLapack.h b/aten/src/TH/generic/THTensorLapack.h index 2444307..d337b82 100644 --- a/aten/src/TH/generic/THTensorLapack.h +++ b/aten/src/TH/generic/THTensorLapack.h @@ -17,7 +17,6 @@ TH_API void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau); TH_API void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, const char *side, const char *trans); TH_API void THTensor_(pstrf)(THTensor *ra_, THIntTensor *rpiv_, THTensor*a, const char* uplo, scalar_t tol); -TH_API void THTensor_(btrifact)(THTensor *ra_, THIntTensor *rpivots_, THIntTensor *rinfo_, int pivot, THTensor *a); TH_API void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots); #endif diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu index 7092ddf..53dd565 100644 --- 
a/aten/src/THC/generic/THCTensorMathBlas.cu +++ b/aten/src/THC/generic/THCTensorMathBlas.cu @@ -745,114 +745,6 @@ void THCTensor_(baddbmm)(THCState *state, THCTensor *result, scalar_t beta, THCT #endif } -void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a) -{ -#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) - THAssert(THCTensor_(checkGPU)(state, 2, ra_, a)); - THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, a) == 3, 3, "expected 3D tensor"); - THArgCheck(THCTensor_(size)(state, a, 1) == - THCTensor_(size)(state, a, 2), 3, "matrices must be square"); - - if (ra_ != a) { - THCTensor_(resizeAs)(state, ra_, a); - if (ra_->stride(2) == 1) { - THCTensor_(transpose)(state, ra_, NULL, 1, 2); - } - THCTensor_(copy)(state, ra_, a); - } - - - int n = a->size(1); - int lda; - THCTensor *ra__; - - if (ra_->stride(1) == 1) { - // column ordered, what BLAS wants - lda = ra_->stride(2); - ra__ = ra_; - } else { - // not column ordered, need to make it such (requires copy) - THCTensor *transp_r_ = THCTensor_(newTranspose)(state, ra_, 1, 2); - ra__ = THCTensor_(newClone)(state, transp_r_); - THCTensor_(free)(state, transp_r_); - THCTensor_(transpose)(state, ra__, NULL, 1, 2); - lda = ra__->stride(2); - } - - int64_t num_batches = ra__->size(0); - - if (!pivot) { - THCudaIntTensor *t = THCudaIntTensor_new(state); - auto t_aten = THTensor_wrap(t); - at::range_out(t_aten, 1, n, 1); - THCudaIntTensor_unsqueeze1d(state, t, t, 0); - THCudaIntTensor** ptrs = (THCudaIntTensor**) THAlloc(sizeof(THCudaIntTensor*)*num_batches); - for (int64_t i=0; i(THCudaMalloc(state, matrices_size)); - - if (num_batches > 0) { - const int64_t block = 512; - const int64_t grid = (num_batches + block - 1) / block; - createBatchGemmBuffer<<>>( - (const scalar_t**)d_result, THCTensor_(data)(state, ra__), - ra__->stride(0), num_batches); - } - - int *pivots_gpu = NULL; - if (pivot) { - pivots_gpu = THCudaIntTensor_data(state, rpivots_); - } -#ifdef THC_REAL_IS_FLOAT - THCudaBlas_Sgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches); -#elif defined(THC_REAL_IS_DOUBLE) - THCudaBlas_Dgetrf(state, n, d_result, lda, pivots_gpu, info_gpu, num_batches); -#endif - - THCudaFree(state, d_result); - - if (ra__ != ra_) { - THCTensor_(freeCopyTo)(state, ra__, ra_); - } - - if (free_rinfo_) { - if(THCTensor_nElement(state, rinfo_) != 0) { - int min = THCudaIntTensor_minall(state, rinfo_); - int max = THCudaIntTensor_maxall(state, rinfo_); - THCudaIntTensor_free(state, rinfo_); - if (min != 0 || max != 0) { - THError("failed to factorize some batch elements (min info == %d, max info == %d)", - min, max); - } - } else { - THCudaIntTensor_free(state, rinfo_); - } - } - -#else - THError("btrifact for CUDA tensors is only supported for floats and doubles"); -#endif -} - - void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *atf, THCudaIntTensor *pivots) { diff --git a/aten/src/THC/generic/THCTensorMathBlas.h b/aten/src/THC/generic/THCTensorMathBlas.h index 6f526b8..e112764 100644 --- a/aten/src/THC/generic/THCTensorMathBlas.h +++ b/aten/src/THC/generic/THCTensorMathBlas.h @@ -9,7 +9,6 @@ THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, scalar_t beta, T THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *batch1, THCTensor *batch2); THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor 
*t, scalar_t alpha, THCTensor *batch1, THCTensor *batch2); -THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTensor *rpivots_, THCudaIntTensor *rinfo_, int pivot, THCTensor *a); THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *atf, THCudaIntTensor *pivots); diff --git a/test/test_torch.py b/test/test_torch.py index 1a8787d..a50ab5c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1728,24 +1728,41 @@ class _TestTorchMixin(object): @staticmethod def _test_btrifact(self, cast): - a = torch.FloatTensor((((1.3722, -0.9020), - (1.8849, 1.9169)), - ((0.7187, -1.1695), - (-0.0139, 1.3572)), - ((-1.6181, 0.7148), - (1.3728, 0.1319)))) - a = cast(a) - a_LU, pivots = a.btrifact() - - a_LU_, pivots_, info_ = a.btrifact_with_info() - self.assertEqual(a_LU, a_LU_) - self.assertEqual(pivots, pivots_) - self.assertEqual(info_.abs().sum(), 0) - P, a_L, a_U = torch.btriunpack(a_LU, pivots) - a_ = torch.bmm(P, torch.bmm(a_L, a_U)) - self.assertEqual(a_, a) + from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank + + def run_test(matrix_size, batches, cast): + a = cast(fullrank(matrix_size, *batches)) + a_LU_info, pivots_info, info_ = a.btrifact_with_info() + self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size))) + self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,))) + self.assertEqual(info_.size(), torch.Size(batches)) + self.assertEqual(info_.abs().sum(), 0) + a_LU, pivots = a.btrifact() + self.assertEqual(a_LU, a_LU_info) + self.assertEqual(pivots_info, pivots) + if a.is_cuda: + a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False) + self.assertIsNone(nopiv) + self.assertEqual(info_, info_nopiv) + P, L, U = torch.btriunpack(a_LU, pivots) + self.assertEqual(P.matmul(L.matmul(U)), a) + + for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]): + run_test(ms, batch, cast) + + # Info should be positive for rank deficient matrices + a = cast(fullrank(3, 5)) + if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])): + a[0, 1] = 2 * a[0, 0] # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency + self.assertGreater(a.btrifact_with_info()[2][0], 0) + + # Error checking, no pivoting variant on CPU + with self.assertRaisesRegex(RuntimeError, + 'btrifact without pivoting is not implemented on the CPU'): + torch.btrifact(torch.empty(1, 2, 2), pivot=False) @skipIfNoLapack + @skipIfRocm def test_btrifact(self): self._test_btrifact(self, lambda t: t) @@ -5389,9 +5406,18 @@ class _TestTorchMixin(object): eye = conv_fn(torch.eye(5)) test_single_det(eye, torch.tensor(1, dtype=eye.dtype), 'identity') + # TODO: Remove when MAGMA 2.5.0 is built for CUDA 8 and CUDA 9.2 + is_cuda_8_92 = False + if torch.cuda.is_available() and torch.version.cuda is not None: + is_cuda_8_92 = any(x in torch.version.cuda for x in ['8.0', '9.2']) + def test(M): assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5' M = conv_fn(M) + + if M.is_cuda and is_cuda_8_92: + return + M_det = M.det() ref_M_det = reference_det(M) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index bb48b6a..0066183 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -27,7 +27,7 @@ SKIP_PYTHON_BINDINGS = [ '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_.*', '_thnn_.*', 'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', - '_potrs.*', 
'_cholesky.*',
+    '_potrs.*', '_cholesky.*', '_btrifact.*',
     'slice', 'randint(_out)?',
     'item', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 79e704d..30311d1 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5411,6 +5411,11 @@
 Batch LU factorization.
 
 Returns a tuple containing the LU factorization and pivots. Pivoting is done if :attr:`pivot` is set.
 
+.. note::
+    LU factorization with :attr:`pivot` = ``False`` is not available for CPU, and attempting
+    to do so will throw an error. However, LU factorization with :attr:`pivot` = ``False`` is
+    available for CUDA.
+
 Arguments:
     A (Tensor): the tensor to factor
     pivot (bool, optional): controls whether pivoting is done
-- 
2.7.4
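
For context on the behavior this patch enables, the sketch below exercises the Python-level API on a tensor with more than three dimensions. It is illustrative only: the shapes, values, and tolerance are arbitrary, and it assumes a build containing this change together with a `btriunpack` that accepts arbitrary batch dimensions, as the updated `test_btrifact` does.

    import torch

    # A (3, 5) batch of 4x4 matrices: more than 3 dimensions, which the old
    # TH/THC kernels rejected with "expected 3D tensor".
    A = torch.randn(3, 5, 4, 4)

    A_LU, pivots = torch.btrifact(A)   # A_LU: (3, 5, 4, 4), pivots: (3, 5, 4)

    # Unpack the packed LU factors and check that P @ L @ U reconstructs A,
    # mirroring the assertion in the updated test.
    P, L, U = torch.btriunpack(A_LU, pivots)
    assert torch.allclose(P.matmul(L.matmul(U)), A, atol=1e-5)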
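
The `btrifact_with_info` variant reports per-matrix `getrf` status codes instead of raising, which is what the new `batchCheckErrors` overload inspects: 0 means success, a positive value k means U(k, k) is exactly zero (a singular factor), and a negative value flags an illegal argument. A small sketch, reusing the rank-deficiency trick from the added test; values are illustrative.

    import torch

    A = torch.randn(5, 3, 3)
    A[0, 1] = 2 * A[0, 0]                  # row 1 of the first matrix duplicates row 0
    A_LU, pivots, info = torch.btrifact_with_info(A)

    assert info.dtype == torch.int32
    assert info[0].item() > 0              # first matrix is singular
    assert int(info[1:].abs().sum()) == 0  # remaining matrices factorized cleanly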
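
On the pivoting flag itself: the CPU path only implements partially pivoted `getrf`, while the CUDA path can also dispatch to MAGMA's no-pivot batched `getrf`. The sketch below mirrors the new test's expectations; the CUDA branch assumes a MAGMA-enabled build and is otherwise hypothetical.

    import torch

    # CPU: pivot=False is rejected, as documented in the note above.
    try:
        torch.btrifact(torch.empty(1, 2, 2), pivot=False)
    except RuntimeError as e:
        assert 'btrifact without pivoting is not implemented on the CPU' in str(e)

    # CUDA: the no-pivot variant returns None for the pivots.
    if torch.cuda.is_available():
        B = torch.randn(2, 4, 4).cuda()
        B_LU, piv, info = torch.btrifact_with_info(B, pivot=False)
        assert piv is None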