From a519217ee72bdef12b5d133b7e941d92d6ccd3cf Mon Sep 17 00:00:00 2001
From: Vishwak Srinivasan
Date: Wed, 20 Mar 2019 11:06:56 -0700
Subject: [PATCH] Add batched version of trtrs (#18025)

Summary:
- Remove single batch TH/THC implementations
- Remove `_batch_trtrs_lower` from `multivariate_normal`
- Add tests for batched behavior
- Modify trtrs_backward to accommodate the batched case
- Modify docs

In a future PR, this will be renamed to `triangular_solve`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18025

Differential Revision: D14523004

Pulled By: ifedan

fbshipit-source-id: 11c6a967d107f969b60e5a5c73ce6bb8099ebbe1
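As a quick illustration of the new behavior for reviewers (shapes here are
made up; the tests added in test_torch.py below are the authoritative cases):

    import torch

    # Batched: each (4, 4) matrix in A is an independent triangular system.
    A = torch.randn(2, 3, 4, 4).triu()
    b = torch.randn(2, 3, 4, 6)
    X, M = torch.trtrs(b, A)   # X solves A @ X = b per batch element; M is a clone of A
    assert X.shape == torch.Size([2, 3, 4, 6])

    # Batch dimensions of b and A are broadcast against each other,
    # so a single A can be reused across a whole batch of right-hand sides.
    A2 = torch.randn(4, 4).triu()
    X2, _ = torch.trtrs(b, A2)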
---
 aten/src/ATen/Declarations.cwrap                   |  32 ------
 aten/src/ATen/native/BatchLinearAlgebra.cpp        | 102 ++++++++++++++++-
 aten/src/ATen/native/LegacyDefinitions.cpp         |   8 --
 aten/src/ATen/native/cuda/BatchLinearAlgebra.cu    | 124 +++++++++++++++++++--
 aten/src/ATen/native/native_functions.yaml         |   7 ++
 aten/src/TH/generic/THLapack.cpp                   |  17 ---
 aten/src/TH/generic/THLapack.h                     |   2 -
 aten/src/TH/generic/THTensorLapack.cpp             |  50 ---------
 aten/src/TH/generic/THTensorLapack.h               |   1 -
 aten/src/THC/generic/THCTensorMathMagma.cu         |  37 ------
 aten/src/THC/generic/THCTensorMathMagma.h          |   2 -
 test/test_autograd.py                              |  12 +-
 test/test_cuda.py                                  |   8 ++
 test/test_torch.py                                 |  87 +++++++++++++++
 tools/autograd/templates/Functions.cpp             |   2 +-
 torch/_torch_docs.py                               |  28 +++--
 torch/distributions/kl.py                          |  13 +--
 torch/distributions/lowrank_multivariate_normal.py |   5 +-
 torch/distributions/multivariate_normal.py         |  13 +--
 19 files changed, 347 insertions(+), 203 deletions(-)

diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 7063a1d..c9cec39 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2243,38 +2243,6 @@
     - THTensor* A
 ]]
 [[
-  name: _th_trtrs
-  cname: trtrs
-  types:
-    - Float
-    - Double
-  backends:
-    - CPU
-    - CUDA
-  variants:
-    - function
-  return: argument 0,1
-  arguments:
-    - arg: THTensor* res1
-      output: True
-    - arg: THTensor* res2
-      output: True
-    - THTensor* self
-    - THTensor* A
-    - arg: bool upper
-      if_true: U
-      if_false: L
-      default: U
-    - arg: bool transpose
-      if_true: T
-      if_false: N
-      default: N
-    - arg: bool unitriangular
-      if_true: U
-      if_false: N
-      default: N
-]]
-[[
   name: _th_symeig
   cname: syev
   types:
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index ea8b1d4..562e154 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -19,9 +19,11 @@
 extern "C" void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
 extern "C" void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
 
-// inverse
+// getrf
 extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
 extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
+
+// getri
 extern "C" void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
 extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
 
@@ -32,6 +34,10 @@ extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
 // potrf
 extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
 extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);
+
+// trtrs
+extern "C" void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
+extern "C" void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
 #endif
 
 namespace at {
@@ -64,6 +70,11 @@ void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) {
   AT_ERROR("cholesky only takes float or double Tensors");
 }
 
+template<class scalar_t>
+void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
+  AT_ERROR("trtrs only takes float or double Tensors");
+}
+
 #ifdef USE_LAPACK
 template<> void lapackSolve<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
   dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
@@ -104,6 +115,14 @@ template<> void lapackCholesky<double>(char uplo, int n, double *a, int lda, int *info) {
 template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *info) {
   spotrf_(&uplo, &n, a, &lda, info);
 }
+
+template<> void lapackTrtrs<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
+  dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
+}
+
+template<> void lapackTrtrs<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
+  strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
+}
 #endif
 
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
@@ -317,7 +336,9 @@ Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
   AT_CHECK(self.dim() == 2 && A.dim() == 2,
            "torch.cholesky_solve() with the `out` keyword does not support batching. "
            "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
-  result = at::_cholesky_solve_helper(self, A, upper);
+  Tensor result_tmp;
+  result_tmp = at::_cholesky_solve_helper(self, A, upper);
+  result.resize_as_(result_tmp).copy_(result_tmp);
   return result;
 }
 
@@ -480,6 +501,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
   return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 template <typename scalar_t>
 static void apply_triu_tril_single(
     scalar_t* result, scalar_t* self, bool inplace,
@@ -618,4 +641,79 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   return result;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trtrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<typename scalar_t>
+static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular, std::vector<int64_t>& infos) {
+#ifndef USE_LAPACK
+  AT_ERROR("trtrs: LAPACK library not found in compilation");
+#else
+  char uplo = upper ? 'U' : 'L';
+  char trans = transpose ? 'T' : 'N';
+  char diag = unitriangular ? 'U' : 'N';
+
+  auto A_data = A.data<scalar_t>();
+  auto b_data = b.data<scalar_t>();
+  auto n = A.size(-2);
+  auto nrhs = b.size(-1);
+
+  int info;
+  if (b.dim() == 2) {
+    lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
+    infos[0] = info;
+  } else {
+    auto A_mat_stride = matrixStride(A);
+    auto b_mat_stride = matrixStride(b);
+    auto batch_size = batchCount(A);
+    for (int64_t i = 0; i < batch_size; i++) {
+      scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+      scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+      lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+      infos[i] = info;
+      if (info != 0) {
+        return;
+      }
+    }
+  }
+#endif
+}
+
+std::tuple<Tensor, Tensor> _trtrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  std::vector<int64_t> infos(batchCount(self), 0);
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trtrs_cpu", [&]{
+    apply_trtrs<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular, infos);
+  });
+  if (self.dim() > 2) {
+    batchCheckErrors(infos, "trtrs_cpu");
+  } else {
+    singleCheckErrors(infos[0], "trtrs_cpu");
+  }
+  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
+}
+
+// Supports arbitrary batch dimensions for self and A
+std::tuple<Tensor, Tensor> trtrs(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+  AT_CHECK(self.dim() >= 2,
+           "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+  AT_CHECK(A.dim() >= 2,
+           "A should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
+  Tensor self_broadcasted, A_broadcasted;
+  std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
+  return at::_trtrs_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular);
+}
+
+std::tuple<Tensor&, Tensor&> trtrs_out(Tensor& result, Tensor& clone_A,
+                                       const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+  AT_CHECK(self.dim() == 2 && A.dim() == 2,
+           "torch.trtrs() with the `out` keyword does not support batching. "
+           "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
+  Tensor result_tmp, clone_A_tmp;
+  std::tie(result_tmp, clone_A_tmp) = at::_trtrs_helper(self, A, upper, transpose, unitriangular);
+  result.resize_as_(result_tmp).copy_(result_tmp);
+  clone_A.resize_as_(clone_A_tmp).copy_(clone_A_tmp);
+  return std::tuple<Tensor&, Tensor&>(result, clone_A);
+}
+
 }} // namespace at::native
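Reviewer note on the CPU path above (not part of the diff): LAPACK's ?trtrs
overwrites the right-hand-side buffer with the solution and reports info > 0
when a diagonal element is exactly zero, which is why the helper works on
column-major clones of b and A and loops over the batch. A rough Python
equivalent of that batch loop, assuming b and A share the same batch shape
(the real entry point additionally broadcasts them):

    import torch

    def trtrs_batched_reference(b, A, upper=True, transpose=False, unitriangular=False):
        # Solve each 2D system with the single-matrix torch.trtrs, then restack.
        flat_b = b.reshape((-1,) + b.shape[-2:])
        flat_A = A.reshape((-1,) + A.shape[-2:])
        xs = [torch.trtrs(bi, Ai, upper=upper, transpose=transpose,
                          unitriangular=unitriangular)[0]
              for bi, Ai in zip(flat_b, flat_A)]
        return torch.stack(xs).reshape(b.shape)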
" + "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2."); + Tensor result_tmp, clone_A_tmp; + std::tie(result_tmp, clone_A_tmp) = at::_trtrs_helper(self, A, upper, transpose, unitriangular); + result.resize_as_(result_tmp).copy_(result_tmp); + clone_A.resize_as_(clone_A_tmp).copy_(clone_A_tmp); + return std::tuple(result, clone_A); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp index 618af2d..d8b3daa 100644 --- a/aten/src/ATen/native/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/LegacyDefinitions.cpp @@ -424,14 +424,6 @@ std::tuple gels(const Tensor & self, const Tensor & A) { return at::legacy::th::_th_gels(self, A); } -std::tuple trtrs_out(Tensor & X, Tensor & M, const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) { - return at::legacy::th::_th_trtrs_out(X, M, self, A, upper, transpose, unitriangular); -} - -std::tuple trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) { - return at::legacy::th::_th_trtrs(self, A, upper, transpose, unitriangular); -} - std::tuple symeig_out(Tensor & e, Tensor & V, const Tensor & self, bool eigenvectors, bool upper) { return at::legacy::th::_th_symeig_out(e, V, self, eigenvectors, upper); } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index dcadc87..812f1ea 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -85,20 +85,19 @@ void magmaCholeskyBatched( AT_ERROR("cholesky only takes float or double Tensors"); } -template<> -void magmaSolveBatched( - magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda, - magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb, - magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { - magma_dgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue()); +template +void magmaTrsm( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb) { + AT_ERROR("trtrs only takes float or double Tensors"); } -template<> -void magmaSolveBatched( - magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda, - magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb, - magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { - magma_sgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue()); +template +void magmaTrsmBatched( + magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize, + const MAGMAQueue& magma_queue) { + AT_ERROR("trtrs only takes float or double Tensors"); } template<> @@ -116,6 +115,22 @@ void magmaSolve( } template<> +void magmaSolveBatched( + magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda, + magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb, + magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) { + magma_dgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue()); +} + +template<> +void magmaSolveBatched( + magma_int_t n, magma_int_t nrhs, float** 
+void magmaSolveBatched<float>(
+    magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
+    magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
+    magma_int_t* dinfo_array, magma_int_t batch_count, const MAGMAQueue& magma_queue) {
+  magma_sgesv_batched(n, nrhs, dA_array, ldda, dipiv_array, dB_array, lddb, dinfo_array, batch_count, magma_queue.get_queue());
+}
+
+template<>
 void magmaGetrfBatched<double>(
     magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
     magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
@@ -216,6 +231,36 @@ void magmaCholeskyBatched<float>(
     magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
   magma_spotrf_batched(uplo, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
 }
+
+template<>
+void magmaTrsm<double>(
+    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) {
+  magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
+}
+
+template<>
+void magmaTrsm<float>(
+    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) {
+  magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
+}
+
+template<>
+void magmaTrsmBatched<double>(
+    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize,
+    const MAGMAQueue& magma_queue) {
+  magmablas_dtrsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+}
+
+template<>
+void magmaTrsmBatched<float>(
+    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize,
+    const MAGMAQueue& magma_queue) {
+  magmablas_strsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+}
 #endif
 
 #define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
@@ -554,6 +599,8 @@ std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, bool pivot) {
   return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
 template <typename scalar_t, typename IndexType, bool upper>
 __global__ void triu_tril_kernel(
@@ -637,6 +684,59 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
   return triu_tril_cuda_template<true>(result, self_c, k, "triu");
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trsm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <typename scalar_t>
+static void apply_trsm(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
+#ifndef USE_MAGMA
+AT_ERROR("trtrs: MAGMA library not found in "
+         "compilation. Please rebuild with MAGMA.");
+#else
+  magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
+  magma_trans_t trans = transpose ? MagmaTrans : MagmaNoTrans;
+  magma_diag_t diag = unitriangular ? MagmaUnit : MagmaNonUnit;
+
+  auto A_data = A.data<scalar_t>();
+  auto b_data = b.data<scalar_t>();
+  magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
+  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
+
+  if (b.dim() == 2) {
+    magmaTrsm<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
+  } else {
+    auto A_mat_stride = matrixStride(A);
+    auto b_mat_stride = matrixStride(b);
+    magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
+
+    scalar_t** A_array;
+    scalar_t** b_array;
+
+    ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
+    ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
+
+    // Set up the created arrays
+    for (int64_t i = 0; i < batch_size; i++) {
+      A_array[i] = &A_data[i * A_mat_stride];
+      b_array[i] = &b_data[i * b_mat_stride];
+    }
+
+    MAGMAQueue magma_queue(b.get_device());
+    magmaTrsmBatched<scalar_t>(
+        uplo, trans, diag, n, nrhs, A_array, n,
+        b_array, n, batch_size, magma_queue);
+  }
+#endif
+}
+
+std::tuple<Tensor, Tensor> _trsm_helper_cuda(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  auto A_working_copy = cloneBatchedColumnMajor(A);
+  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trsm_cuda", [&]{
+    apply_trsm<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
+  });
+  return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
+}
+
 }} // namespace at::native
 
 #undef ALLOCATE_ARRAY
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f4146f8..337a696 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3713,6 +3713,13 @@
   matches_jit_signature: True
   variants: method, function
 
+- func: _trtrs_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor)
+  matches_jit_signature: True
+  variants: function
+  dispatch:
+    CPU: _trtrs_helper_cpu
+    CUDA: _trsm_helper_cuda
+
 - func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
   matches_jit_signature: True
diff --git a/aten/src/TH/generic/THLapack.cpp b/aten/src/TH/generic/THLapack.cpp
index 28a4f29..1c81ed2 100644
--- a/aten/src/TH/generic/THLapack.cpp
+++ b/aten/src/TH/generic/THLapack.cpp
@@ -3,8 +3,6 @@
 #else
 
-TH_EXTERNC void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
-TH_EXTERNC void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
 TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
 TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
 TH_EXTERNC void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
@@ -31,21 +29,6 @@
 TH_EXTERNC void spstrf_(char *uplo, int *n, float *a, int *lda, int *piv, int *rank, float *tol, float *work, int *info);
 TH_EXTERNC void dpstrf_(char *uplo, int *n, double *a, int *lda, int *piv, int *rank, double *tol, double *work, int *info);
 
-/* Solve a triangular system of the form A * X = B or A^T * X = B */
-void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int* info)
-{
-#ifdef USE_LAPACK
-#if defined(TH_REAL_IS_DOUBLE)
-  dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
-#else
-  strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
-#endif
-#else
-  THError("trtrs : Lapack library not found in compile time\n");
-#endif
-  return;
-}
-
 /* Solve overdetermined or underdetermined real linear systems involving an
    M-by-N matrix A, or its transpose, using a QR or LQ factorization of A */
 void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, scalar_t *work, int lwork, int *info)
diff --git a/aten/src/TH/generic/THLapack.h b/aten/src/TH/generic/THLapack.h
index 5c65140..0557834 100644
--- a/aten/src/TH/generic/THLapack.h
+++ b/aten/src/TH/generic/THLapack.h
@@ -2,8 +2,6 @@
 #define TH_GENERIC_FILE "TH/generic/THLapack.h"
 #else
 
-/* Solve a triangular system of the form A * X = B or A^T * X = B */
-TH_API void THLapack_(trtrs)(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int* info);
 /* ||AX-B|| */
 TH_API void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, scalar_t *work, int lwork, int *info);
 /* Eigenvals */
diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp
index 6187166..3ace8a6 100644
--- a/aten/src/TH/generic/THTensorLapack.cpp
+++ b/aten/src/TH/generic/THTensorLapack.cpp
@@ -106,56 +106,6 @@ static THTensor *THTensor_(cloneColumnMajor)(THTensor *self, THTensor *src)
 {
   return THTensor_(cloneColumnMajorNrows)(self, src, src->size(0));
 }
 
-void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a,
-                      const char *uplo, const char *trans, const char *diag)
-{
-  int free_b = 0;
-  if (a == NULL) a = ra_;
-  if (b == NULL) b = rb_;
-  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 2, "A should have 2 dimensions, but has %d",
-             THTensor_nDimensionLegacyAll(a));
-  THArgCheck(THTensor_nDimensionLegacyAll(b) == 1 || THTensor_nDimensionLegacyAll(b) == 2, 1, "B should have 1 or 2 "
-             "dimensions, but has %d", THTensor_nDimensionLegacyAll(b));
%ldx%ld", - a->size(0), a->size(1)); - THArgCheck(a->size(0) == b->size(0), 2, "A,B size incompatible - A has %ld " - "rows, B has %ld", a->size(0), b->size(0)); - - if (THTensor_nDimensionLegacyAll(b) == 1) { - b = THTensor_(newWithStorage2d)(THTensor_getStoragePtr(b), b->storage_offset(), b->size(0), - b->stride(0), 1, 0); - free_b = 1; - } - - int n, nrhs, lda, ldb, info; - THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS - THTensor *rb__; // working version of B matrix to be passed into lapack TRTRS - - ra__ = THTensor_(cloneColumnMajor)(ra_, a); - rb__ = THTensor_(cloneColumnMajor)(rb_, b); - - n = (int)ra__->size(0); - nrhs = (int)rb__->size(1); - lda = n; - ldb = n; - - THLapack_(trtrs)(uplo[0], trans[0], diag[0], n, nrhs, - ra__->data(), lda, - rb__->data(), ldb, &info); - - - THLapackCheckWithCleanup("Lapack Error in %s : A(%d,%d) is zero, singular A", - THCleanup( - c10::raw::intrusive_ptr::decref(ra__); - c10::raw::intrusive_ptr::decref(rb__); - if (free_b) c10::raw::intrusive_ptr::decref(b);), - "trtrs", info, info); - - THTensor_(freeCopyTo)(ra__, ra_); - THTensor_(freeCopyTo)(rb__, rb_); - if (free_b) c10::raw::intrusive_ptr::decref(b); -} - void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) { int free_b = 0; diff --git a/aten/src/TH/generic/THTensorLapack.h b/aten/src/TH/generic/THTensorLapack.h index d337b82..4c693a8 100644 --- a/aten/src/TH/generic/THTensorLapack.h +++ b/aten/src/TH/generic/THTensorLapack.h @@ -2,7 +2,6 @@ #define TH_GENERIC_FILE "TH/generic/THTensorLapack.h" #else -TH_API void THTensor_(trtrs)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_, const char *uplo, const char *trans, const char *diag); TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); TH_API void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobz, const char *uplo); TH_API void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobvr); diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu index 418bb7e..198495bb 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.cu +++ b/aten/src/THC/generic/THCTensorMathMagma.cu @@ -59,43 +59,6 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T return self; } -void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_, - const char *uplo, const char *trans, const char *diag) -{ -#ifdef USE_MAGMA - THArgCheck(!a_->is_empty() && a_->dim() == 2, 1, "A should be (non-empty) 2 dimensional"); - THArgCheck(!b_->is_empty() && b_->dim() == 2, 2, "b should be (non-empty) 2 dimensional"); - THArgCheck(a_->size(0) == a_->size(1), 1, "A should be square"); - THArgCheck(b_->size(0) == a_->size(0), 2, "A,b size incompatible"); - - magma_side_t sz = MagmaLeft; - magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; - magma_trans_t ts = trans[0] == 'N' ? MagmaNoTrans : MagmaTrans; - magma_diag_t dg = diag[0] == 'U' ? 
-  magma_diag_t dg = diag[0] == 'U' ? MagmaUnit : MagmaNonUnit;
-
-  scalar_t alpha = 1;
-
-  int64_t n = a_->size(0);
-  int64_t nrhs = b_->size(1);
-
-  THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
-  THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_);
-  scalar_t *a_data = THCTensor_(data)(state, a);
-  scalar_t *b_data = THCTensor_(data)(state, b);
-
-#if defined(THC_REAL_IS_FLOAT)
-  magma_strsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
-#else
-  magma_dtrsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
-#endif
-
-  THCTensor_(freeCopyTo)(state, a, ra_);
-  THCTensor_(freeCopyTo)(state, b, rb_);
-#else
-  THError(NoMagma(trtrs));
-#endif
-}
-
 void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
 {
 #ifdef USE_MAGMA
diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h
index e3870f2..f388f68 100644
--- a/aten/src/THC/generic/THCTensorMathMagma.h
+++ b/aten/src/THC/generic/THCTensorMathMagma.h
@@ -5,8 +5,6 @@
 #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
 
 // MAGMA (i.e. CUDA implementation of LAPACK functions)
-THC_API void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_,
-                               const char *uplo, const char *trans, const char *diag);
 THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
 THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo);
 THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr);
diff --git a/test/test_autograd.py b/test/test_autograd.py
index aa2dbbe..7108581 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2172,9 +2172,9 @@
     @skipIfNoLapack
     def test_trtrs(self):
-        def _test_with_size(N, C):
-            A = torch.rand(N, N, requires_grad=True)
-            b = torch.rand(N, C, requires_grad=True)
+        def _test_with_size(A_dims, B_dims):
+            A = torch.rand(*A_dims).requires_grad_()
+            b = torch.rand(*B_dims).requires_grad_()
 
             for upper, transpose, unitriangular in product((True, False), repeat=3):
                 def func(A, b):
@@ -2183,8 +2183,10 @@
                 gradcheck(func, [A, b])
                 gradgradcheck(func, [A, b])
 
-        _test_with_size(S, S + 1)
-        _test_with_size(S, S - 1)
+        _test_with_size((3, 3), (3, 4))
+        _test_with_size((3, 3), (3, 2))
+        _test_with_size((2, 3, 3), (2, 3, 4))
+        _test_with_size((2, 3, 3), (2, 3, 2))
 
     @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
     def test_fft_ifft_rfft_irfft(self):
diff --git a/test/test_cuda.py b/test/test_cuda.py
index b89de2a..59c50ce 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2567,6 +2567,14 @@
     def test_trtrs(self):
         _TestTorchMixin._test_trtrs(self, lambda t: t.cuda())
 
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_trtrs_batched(self):
+        _TestTorchMixin._test_trtrs_batched(self, lambda t: t.cuda())
+
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_trtrs_batched_dims(self):
+        _TestTorchMixin._test_trtrs_batched_dims(self, lambda t: t.cuda())
+
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_get_set_rng_state_all(self):
         states = torch.cuda.get_rng_state_all()
diff --git a/test/test_torch.py b/test/test_torch.py
index 6d2c5e1..46e6d98 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -4936,6 +4936,93 @@
     def test_trtrs(self):
         self._test_trtrs(self, lambda t: t)
 
+    @staticmethod
+    def _test_trtrs_batched(self, cast):
+        def trtrs_test_helper(A_dims, b_dims, cast, upper, unitriangular):
+            A = cast(torch.randn(*A_dims))
+            A = A.triu() if upper else A.tril()
+            if unitriangular:
+                A.diagonal(dim1=-2, dim2=-1).fill_(1.)
+            b = cast(torch.randn(*b_dims))
+            return A, b
+
+        for upper, transpose, unitriangular in product([True, False], repeat=3):
+            # test against trtrs: one batch with all possible arguments
+            A, b = trtrs_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
+            x_exp = torch.trtrs(b.squeeze(0), A.squeeze(0),
+                                upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+            x = torch.trtrs(b, A,
+                            upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+            self.assertEqual(x, x_exp.unsqueeze(0))
+
+            # test against trtrs in a loop: four batches with all possible arguments
+            A, b = trtrs_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
+            x_exp_list = []
+            for i in range(4):
+                x_exp = torch.trtrs(b[i], A[i],
+                                    upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+                x_exp_list.append(x_exp)
+            x_exp = torch.stack(x_exp_list)
+
+            x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+            self.assertEqual(x, x_exp)
+
+            # basic correctness test
+            A, b = trtrs_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
+            x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+            if transpose:
+                self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12)
+            else:
+                self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12)
+
+    @skipIfNoLapack
+    def test_trtrs_batched(self):
+        _TestTorchMixin._test_trtrs_batched(self, lambda t: t)
+
+    @staticmethod
+    def _test_trtrs_batched_dims(self, cast):
+        if not TEST_SCIPY:
+            return
+
+        from scipy.linalg import solve_triangular as tri_solve
+
+        def scipy_tri_solve_batched(A, B, upper, trans, diag):
+            batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2]
+            single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:]
+            expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A),
+                                                     torch.Size(batch_dims_B)))
+            expand_A = np.broadcast_to(A, expand_dims + single_dim_A)
+            expand_B = np.broadcast_to(B, expand_dims + single_dim_B)
+            flat_A = expand_A.reshape((-1,) + single_dim_A)
+            flat_B = expand_B.reshape((-1,) + single_dim_B)
+            flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag)
+                                for a, b in zip(flat_A, flat_B)])
+            return flat_X.reshape(expand_B.shape)
+
+        def run_test(A_dims, b_dims, cast, upper, transpose, unitriangular):
+            A = torch.randn(*A_dims)
+            A = A.triu() if upper else A.tril()
+            if unitriangular:
+                A.diagonal(dim1=-2, dim2=-1).fill_(1.)
+            b = torch.randn(*b_dims)
+            x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(),
+                                                         upper, transpose, unitriangular))
+            A, b = cast(A), cast(b)
+            x = torch.trtrs(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
+
+            self.assertEqual(x, cast(x_exp))
+
+        for upper, transpose, unitriangular in product([True, False], repeat=3):
+            # test against scipy.linalg.solve_triangular
+            run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, upper, transpose, unitriangular)  # no broadcasting
+            run_test((2, 1, 3, 4, 4), (4, 6), cast, upper, transpose, unitriangular)  # broadcasting b
+            run_test((4, 4), (2, 1, 3, 4, 2), cast, upper, transpose, unitriangular)  # broadcasting A
+            run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular)  # broadcasting A & b
+
+    @skipIfNoLapack
+    def test_trtrs_batched_dims(self):
+        self._test_trtrs_batched_dims(self, lambda t: t)
+
     @skipIfNoLapack
     def test_gels(self):
         def _test_underdetermined(a, b, expectedNorm):
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 8ed6eeb..e71d5fe 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -1748,7 +1748,7 @@ std::tuple<Tensor, Tensor> trtrs_backward(
   if (grad_x.defined()) {
     grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular));
     if (output_mask[1]) {
-      grad_a = transpose ? -x.mm(grad_b.t()) : -grad_b.mm(x.t());
+      grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
       if (upper) {
        grad_a = grad_a.triu((int) unitriangular);
       } else {
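Reviewer note on the hunk above: switching mm/t() to matmul/transpose(-1, -2)
is what lets the existing gradient formula cover the batched case unchanged. A
quick way to sanity-check it locally (illustrative only; double precision so
gradcheck's numerical Jacobian is stable, plus a diagonal shift to keep the
systems well-conditioned):

    import torch
    from torch.autograd import gradcheck

    A = (torch.rand(2, 3, 3, dtype=torch.float64)
         + torch.eye(3, dtype=torch.float64)).triu().requires_grad_()
    b = torch.rand(2, 3, 4, dtype=torch.float64).requires_grad_()
    assert gradcheck(lambda A, b: torch.trtrs(b, A)[0], (A, b))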
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index b6d8f61..75b6a32 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5203,29 +5203,33 @@
 and multiple right-hand sides :attr:`b`.
 
 In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
 with the default keyword arguments.
 
+`torch.trtrs(b, A)` can take in 2D inputs `b, A` or inputs that are
+batches of 2D matrices. If the inputs are batches, then the outputs
+`X` and `M` are batched as well.
+
+.. note::
+
+    The :attr:`out` keyword only supports 2D matrix inputs, that is,
+    `b, A` must be 2D matrices.
+
 Args:
-    A (Tensor): the input triangular coefficient matrix
-    b (Tensor): multiple right-hand sides. Each column of :math:`b` is a
-        right-hand side for the system of equations.
+    A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
+        where :math:`*` is zero or more batch dimensions
+    b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where
+        :math:`*` is zero or more batch dimensions
     upper (bool, optional): whether to solve the upper-triangular system
-        of equations (default) or the lower-triangular system of equations. Default: True.
+        of equations (default) or the lower-triangular system of equations. Default: ``True``.
     transpose (bool, optional): whether :math:`A` should be transposed before
-        being sent into the solver. Default: False.
+        being sent into the solver. Default: ``False``.
     unitriangular (bool, optional): whether :math:`A` is unit triangular.
         If True, the diagonal elements of :math:`A` are assumed to be
-        1 and not referenced from :math:`A`. Default: False.
+        1 and not referenced from :math:`A`. Default: ``False``.
 
 Returns:
     A tuple :math:`(X, M)` where :math:`M` is a clone of :math:`A` and :math:`X`
     is the solution to :math:`AX = b` (or whatever variant of the system of
     equations, depending on the keyword arguments.)
 
-Shape:
-    - A: :math:`(N, N)`
-    - b: :math:`(N, C)`
-    - output[0]: :math:`(N, C)`
-    - output[1]: :math:`(N, N)`
-
 Examples::
 
     >>> A = torch.randn(2, 2).triu()
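Suggestion for the docs (not part of this patch): the Examples block still
only shows the 2D case; a batched example along these lines could be added
next to it:

    >>> A = torch.randn(2, 3, 3).triu()
    >>> b = torch.randn(2, 3, 4)
    >>> X, M = torch.trtrs(b, A)
    >>> X.shape
    torch.Size([2, 3, 4])
    >>> torch.dist(b, torch.matmul(A, X))  # ~0 up to rounding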
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 650cf8a..b6a33dc 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -22,8 +22,7 @@
 from .laplace import Laplace
 from .logistic_normal import LogisticNormal
 from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
                                           _batch_lowrank_mahalanobis)
-from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis,
-                                  _batch_trtrs_lower)
+from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
 from .normal import Normal
 from .one_hot_categorical import OneHotCategorical
 from .pareto import Pareto
@@ -313,7 +312,7 @@
     # = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
     qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                  q._unbroadcasted_cov_diag.unsqueeze(-2))
-    A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril)
+    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
     term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
     term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
                               q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
@@ -340,7 +339,7 @@
     # = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
     qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                  q._unbroadcasted_cov_diag.unsqueeze(-2))
-    A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril)
+    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
     term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
                               q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
     term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
@@ -368,8 +367,8 @@
                                                       (n, p.cov_factor.size(-1)))
     p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt())
                   .expand(combined_batch_shape + (n, n)))
-    term21 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_factor, q_scale_tril))
-    term22 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_diag, q_scale_tril))
+    term21 = _batch_trace_XXT(torch.trtrs(p_cov_factor, q_scale_tril, upper=False)[0])
+    term22 = _batch_trace_XXT(torch.trtrs(p_cov_diag, q_scale_tril, upper=False)[0])
     term2 = term21 + term22
     return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
 
@@ -388,7 +387,7 @@
     n = p.event_shape[0]
     q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
     p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
-    term2 = _batch_trace_XXT(_batch_trtrs_lower(p_scale_tril, q_scale_tril))
+    term2 = _batch_trace_XXT(torch.trtrs(p_scale_tril, q_scale_tril, upper=False)[0])
     term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
     return half_term1 + 0.5 * (term2 + term3 - n)
diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py
index c1a5016..6ba31eb 100644
--- a/torch/distributions/lowrank_multivariate_normal.py
+++ b/torch/distributions/lowrank_multivariate_normal.py
@@ -3,8 +3,7 @@
 import torch
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
-from torch.distributions.multivariate_normal import (_batch_mahalanobis, _batch_mv,
-                                                     _batch_trtrs_lower)
+from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
 from torch.distributions.utils import _standard_normal, lazy_property
 
@@ -163,7 +162,7 @@
         # where :math:`C` is the capacitance matrix.
         Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2) /
                    self._unbroadcasted_cov_diag.unsqueeze(-2))
-        A = _batch_trtrs_lower(Wt_Dinv, self._capacitance_tril)
+        A = torch.trtrs(Wt_Dinv, self._capacitance_tril, upper=False)[0]
         precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal())
                             - torch.matmul(A.transpose(-1, -2), A))
         return precision_matrix.expand(self._batch_shape + self._event_shape +
diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py
index 6055459..6efd8c9 100644
--- a/torch/distributions/multivariate_normal.py
+++ b/torch/distributions/multivariate_normal.py
@@ -20,17 +20,6 @@
     return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
 
 
-def _batch_trtrs_lower(bb, bA):
-    """
-    Applies `torch.trtrs` for batches of matrices. `bb` and `bA` should have
-    the same batch shape.
-    """
-    flat_b = bb.reshape((-1,) + bb.shape[-2:])
-    flat_A = bA.reshape((-1,) + bA.shape[-2:])
-    flat_X = torch.stack([torch.trtrs(b, A, upper=False)[0] for b, A in zip(flat_b, flat_A)])
-    return flat_X.reshape(bb.shape)
-
-
 def _batch_mahalanobis(bL, bx):
     r"""
     Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
@@ -65,7 +54,7 @@
     flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
     flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
     flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
-    M_swap = _batch_trtrs_lower(flat_x_swap, flat_L).pow(2).sum(-2)  # shape = b x c
+    M_swap = torch.trtrs(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c
     M = M_swap.t()  # shape = c x b
 
     # Now we revert the above reshape and permute operators.
-- 
2.7.4