- THTensor* self
]]
[[
- name: _th_potrs_single
- cname: potrs
- types:
- - Float
- - Double
- backends:
- - CPU
- - CUDA
- variants:
- - function
- return: argument 0
- arguments:
- - arg: THTensor* result
- output: True
- - THTensor* self
- - THTensor* input2
- - arg: bool upper
- if_true: U
- if_false: L
- default: U
-]]
-[[
name: _th_potri
cname: potri
types:
auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
- auto A_mat_stride = matrixStride(A);
- auto b_mat_stride = matrixStride(b);
-
- auto batch_size = batchCount(A);
auto n = A.size(-2);
auto nrhs = b.size(-1);
- for (int64_t i = 0; i < batch_size; i++) {
- int info;
- scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
- scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
- lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
- infos[i] = info;
- if (info != 0) {
- return;
+ int info;
+ if (b.dim() == 2) {
+ lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_data, n, b_data, n, &info);
+ infos[0] = info;
+ } else {
+ auto A_mat_stride = matrixStride(A);
+ auto b_mat_stride = matrixStride(b);
+ auto batch_size = batchCount(A);
+ for (int64_t i = 0; i < batch_size; i++) {
+ scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+ scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+ lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+ infos[i] = info;
+ if (info != 0) {
+ return;
+ }
}
}
#endif
}
Tensor _cholesky_solve_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
- std::vector<int64_t> infos(batchCount(self), 0);
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
+ std::vector<int64_t> infos(batchCount(self), 0);
AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{
apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, infos);
});
- batchCheckErrors(infos, "cholesky_solve");
+ if (self.dim() > 2) {
+ batchCheckErrors(infos, "cholesky_solve");
+ } else {
+ singleCheckErrors(infos[0], "cholesky_solve");
+ }
return self_working_copy;
}
// Supports arbitrary batch dimensions for self and A
Tensor cholesky_solve(const Tensor& self, const Tensor& A, bool upper) {
- if (self.dim() <= 2 && A.dim() <= 2) {
- return at::legacy::th::_th_potrs_single(self, A, upper);
- }
-
+ AT_CHECK(self.dim() >= 2,
+ "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+ AT_CHECK(A.dim() >= 2,
+ "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
Tensor self_broadcasted, A_broadcasted;
std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
return at::_cholesky_solve_helper(self_broadcasted, A_broadcasted, upper);
AT_CHECK(self.dim() == 2 && A.dim() == 2,
"torch.cholesky_solve() with the `out` keyword does not support batching. "
"b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
- return at::legacy::th::_th_potrs_single_out(result, self, A, upper);
+ result = at::_cholesky_solve_helper(self, A, upper);
+ return result;
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
}
template<class scalar_t>
+void magmaCholeskySolve(
+ magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, scalar_t* dA, magma_int_t ldda,
+ scalar_t* dB, magma_int_t lddb, magma_int_t* info) {
+ AT_ERROR("cholesky_solve only takes float or double Tensors");
+}
+
+template<class scalar_t>
void magmaCholeskySolveBatched(
magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
scalar_t** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
}
template<>
+void magmaCholeskySolve<double>(
+ magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda,
+ double* dB, magma_int_t lddb, magma_int_t* info) {
+ magma_dpotrs_gpu(uplo, n, nrhs, dA, ldda, dB, lddb, info);
+}
+
+template<>
+void magmaCholeskySolve<float>(
+ magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, float* dA, magma_int_t ldda,
+ float* dB, magma_int_t lddb, magma_int_t* info) {
+ magma_spotrs_gpu(uplo, n, nrhs, dA, ldda, dB, lddb, info);
+}
+
+template<>
void magmaCholeskySolveBatched<double>(
magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
double** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
auto A_data = A.data<scalar_t>();
auto b_data = b.data<scalar_t>();
- auto A_mat_stride = matrixStride(A);
- auto b_mat_stride = matrixStride(b);
-
- magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
- magma_int_t info_tmp;
- scalar_t** A_array;
- scalar_t** b_array;
+  magma_int_t info_tmp;
+ if (b.dim() == 2) {
+ magmaCholeskySolve<scalar_t>(uplo, n, nrhs, A_data, n,
+ b_data, n, &info_tmp);
+ info = info_tmp;
+ } else {
+ auto A_mat_stride = matrixStride(A);
+ auto b_mat_stride = matrixStride(b);
+ magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
- ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
- ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
+ scalar_t** A_array;
+ scalar_t** b_array;
- // Set up the created arrays
- for (int64_t i = 0; i < batch_size; i++) {
- A_array[i] = &A_data[i * A_mat_stride];
- b_array[i] = &b_data[i * b_mat_stride];
- }
+ ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
+ ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
- MAGMAQueue magma_queue(b.get_device());
- magmaCholeskySolveBatched<scalar_t>(
- uplo, n, nrhs, A_array, n, b_array, n,
- info_tmp, batch_size, magma_queue);
+ // Set up the created arrays
+ for (int64_t i = 0; i < batch_size; i++) {
+ A_array[i] = &A_data[i * A_mat_stride];
+ b_array[i] = &b_data[i * b_mat_stride];
+ }
- info = info_tmp;
+ MAGMAQueue magma_queue(b.get_device());
+ magmaCholeskySolveBatched<scalar_t>(
+ uplo, n, nrhs, A_array, n, b_array, n,
+ info_tmp, batch_size, magma_queue);
+
+ info = info_tmp;
+ }
#endif
}
TH_EXTERNC void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
-TH_EXTERNC void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
-TH_EXTERNC void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
TH_EXTERNC void dgeqrf_(int *m, int *n, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
TH_EXTERNC void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
#endif
}
-/* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */
-void THLapack_(potrs)(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info)
-{
-#ifdef USE_LAPACK
-#if defined(TH_REAL_IS_DOUBLE)
- dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
-#else
- spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
-#endif
-#else
- THError("potrs: Lapack library not found in compile time\n");
-#endif
-}
-
/* Cholesky factorization based Matrix Inverse */
void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info)
{
/* Positive Definite matrices */
/* Matrix inverse based on Cholesky factorization */
TH_API void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info);
-/* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */
-TH_API void THLapack_(potrs)(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info);
/* Cholesky factorization with complete pivoting. */
TH_API void THLapack_(pstrf)(char uplo, int n, scalar_t *a, int lda, int *piv, int *rank, scalar_t tol, scalar_t *work, int *info);
}
}
-void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo)
-{
- int free_b = 0;
- if (b == NULL) b = rb_;
-
- THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 2, "A should have 2 dimensions, but has %d",
- THTensor_nDimensionLegacyAll(a));
- THArgCheck(THTensor_nDimensionLegacyAll(b) == 1 || THTensor_nDimensionLegacyAll(b) == 2, 1, "B should have 1 or 2 "
- "dimensions, but has %d", THTensor_nDimensionLegacyAll(b));
- THArgCheck(a->size(0) == a->size(1), 2, "A should be square, but is %ldx%ld",
- a->size(0), a->size(1));
- THArgCheck(a->size(0) == b->size(0), 2, "A,B size incompatible - A has %ld "
- "rows, B has %ld", a->size(0), b->size(0));
-
- if (THTensor_nDimensionLegacyAll(b) == 1) {
- b = THTensor_(newWithStorage2d)(THTensor_getStoragePtr(b), b->storage_offset(), b->size(0),
- b->stride(0), 1, 0);
- free_b = 1;
- }
-
- int n, nrhs, lda, ldb, info;
- THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS
- THTensor *rb__; // working version of B matrix to be passed into lapack TRTRS
-
- ra__ = THTensor_(cloneColumnMajor)(NULL, a);
- rb__ = THTensor_(cloneColumnMajor)(rb_, b);
-
- n = (int)ra__->size(0);
- nrhs = (int)rb__->size(1);
- lda = n;
- ldb = n;
-
- THLapack_(potrs)(uplo[0], n, nrhs, ra__->data<scalar_t>(),
- lda, rb__->data<scalar_t>(), ldb, &info);
-
-
- THLapackCheckWithCleanup("Lapack Error in %s : A(%d,%d) is zero, singular A",
- THCleanup(
- c10::raw::intrusive_ptr::decref(ra__);
- c10::raw::intrusive_ptr::decref(rb__);
- if (free_b) c10::raw::intrusive_ptr::decref(b);),
- "potrs", info, info);
-
- if (free_b) c10::raw::intrusive_ptr::decref(b);
- c10::raw::intrusive_ptr::decref(ra__);
- THTensor_(freeCopyTo)(rb__, rb_);
-}
-
void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo)
{
if (a == NULL) a = ra_;
TH_API void THTensor_(gesdd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a,
const char *some, const char* compute_uv);
TH_API void THTensor_(getri)(THTensor *ra_, THTensor *a);
-TH_API void THTensor_(potrs)(THTensor *rb_, THTensor *b_, THTensor *a_, const char *uplo);
TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo);
TH_API void THTensor_(qr)(THTensor *rq_, THTensor *rr_, THTensor *a);
TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a);
#endif
}
-void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo)
-{
-#ifdef USE_MAGMA
- THArgCheck(a->size(0) == a->size(1), 2, "A should be square");
-
- int64_t n = a->size(0);
- int64_t nrhs = b->size(1);
- magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
-
- THCTensor *b_ = THCTensor_(newColumnMajor)(state, rb_, b);
- scalar_t *b_data = THCTensor_(data)(state, b_);
- THCTensor *a_ = THCTensor_(newColumnMajor)(state, a, a);
- scalar_t *a_data = THCTensor_(data)(state, a_);
-
- int info;
-#if defined(THC_REAL_IS_FLOAT)
- magma_spotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info);
-#else
- magma_dpotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info);
-#endif
-
- // check error value
- if (info < 0)
- THError("MAGMA potrs : Argument %d : illegal value", -info);
-
- THCTensor_(freeCopyTo)(state, b_, rb_);
- THCTensor_(free)(state, a_);
-#else
- THError(NoMagma(potrs));
-#endif
-}
-
void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_)
{
#ifdef USE_MAGMA
const char *some, const char* compute_uv);
THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
-THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo);
THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_);
THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a);