From 566ee1217fc45958e647a103f0295bcf0b62bc7d Mon Sep 17 00:00:00 2001 From: lezcano Date: Tue, 7 Sep 2021 17:22:49 -0700 Subject: [PATCH] Use trsm for triangular_solve in CPU (#63567) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63567 The current implementation called trtrs for CPU and trsm for CUDA. See https://github.com/pytorch/pytorch/issues/56326#issuecomment-825496115 for a discussion on the differences between these two functions and why we prefer trsm over trtrs, as CUDA already does. This PR also exposes the `side` argument of this function, which is used in the second PR of this stack to optimise the number of copies one needs to make when preparing the arguments to be sent to the backends. It also replaces the use of `bool`s with a common enum type that represents whether a matrix is transposed / conj transposed, etc. This makes the API consistent: before, the behaviour of these functions with `transpose=True` and `conjugate_transpose=True` was not well defined. Functions that transform this type into the specific types / chars for the different libraries are provided under the names `to_blas`, `to_lapack`, `to_magma`, etc. This is the first of a stack of PRs that aim to improve the performance of `linalg.solve_triangular`. `trsm` has an extra parameter (`side`), which allows one to elide the copy of the triangular matrix in many cases. Fixes https://github.com/pytorch/pytorch/issues/56326 Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D30566479 Pulled By: mruberry fbshipit-source-id: 3831af9b51e09fbfe272c17c88c21ecf45413212 --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 84 ++++++------ aten/src/ATen/native/BatchLinearAlgebra.h | 17 +-- aten/src/ATen/native/BatchLinearAlgebraKernel.cpp | 52 +++----- aten/src/ATen/native/CPUBlas.cpp | 25 +--- aten/src/ATen/native/CPUBlas.h | 7 +- aten/src/ATen/native/LinearAlgebra.cpp | 4 +- aten/src/ATen/native/LinearAlgebraUtils.h | 21 ++- .../ATen/native/NaiveConvolutionTranspose2d.cpp | 16 +-- .../ATen/native/NaiveConvolutionTranspose3d.cpp | 16 +-- aten/src/ATen/native/NaiveDilatedConvolution.cpp | 14 +- aten/src/ATen/native/cpu/BlasKernel.cpp | 8 +- aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp | 146 +++++++++------------ .../src/ATen/native/cuda/BatchLinearAlgebraLib.cpp | 63 +++++---- aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h | 8 +- test/test_linalg.py | 11 -- torch/_torch_docs.py | 4 + .../_internal/common_methods_invocations.py | 2 +- 17 files changed, 228 insertions(+), 270 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 498b51b..0471cdb 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -55,12 +55,6 @@ extern "C" void cpotri_(char *uplo, int *n, std::complex<float> *a, int *lda, in extern "C" void dpotri_(char *uplo, int *n, double *a, int *lda, int *info); extern "C" void spotri_(char *uplo, int *n, float *a, int *lda, int *info); -// trtrs -extern "C" void ztrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info); -extern "C" void ctrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info); -extern "C" void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info); -extern "C" void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info); - // geqrf extern "C" void zgeqrf_(int *m, int *n, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *work, int *lwork, int *info); extern "C" void cgeqrf_(int *m, int *n, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *work, int *lwork, int *info); @@ -200,6 +194,14 @@ extern "C" void sgelss_(int *m, int *n, int *nrhs, float *work, int *lwork, int *info); #endif +#if AT_BUILD_WITH_BLAS() +// trsm +extern "C" void ztrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *alpha, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb); +extern "C" void ctrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *alpha, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb); +extern "C" void dtrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, double *alpha, double *a, int *lda, double *b, int *ldb); +extern "C" void strsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, float *alpha, float *a, int *lda, float *b, int *ldb); +#endif + namespace at { namespace native { @@ -318,22 +320,6 @@ template<> void lapackCholeskyInverse<float>(char uplo, int n, float *a, int lda spotri_(&uplo, &n, a, &lda, info); } -template<> void lapackTriangularSolve<c10::complex<double>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) { - ztrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb, info); -} - -template<> void lapackTriangularSolve<c10::complex<float>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb, int *info) { - ctrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb, info); -} - -template<> void lapackTriangularSolve<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) { - dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info); -} - -template<> void lapackTriangularSolve<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) { - strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info); -} - template<> void lapackGeqrf<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *work, int lwork, int *info) { zgeqrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(work), &lwork, info); } @@ -687,6 +673,28 @@ template<> void lapackGelss( } #endif +#if AT_BUILD_WITH_BLAS() +template<> void blasTriangularSolve<c10::complex<double>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb) { + std::complex<double> one{1., 0.}; + ztrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb); +} + +template<> void blasTriangularSolve<c10::complex<float>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb) { + std::complex<float> one{1.f, 0.f}; + ctrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb); +} + +template<> void blasTriangularSolve<double>(char side, char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb) { + auto one = 1.; + dtrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb); +} + +template<> void blasTriangularSolve<float>(char side, char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb) { + auto one = 1.f; +
strsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb); +} +#endif + // Below are the definitions of the functions operating on a batch that are going to be dispatched // in the main helper functions for the linear algebra operations @@ -1635,7 +1643,7 @@ static std::tuple<Tensor&, Tensor&> triangular_solve_out_info( result.copy_(other); clone_input.copy_(input); - triangular_solve_stub(input.device().type(), clone_input, result, infos, upper, transpose, /*conjugate_transpose=*/false, unitriangular); + triangular_solve_stub(input.device().type(), clone_input, result, /*left=*/true, upper, transpose ? TransposeType::Transpose : TransposeType::NoTranspose, unitriangular); return std::tuple<Tensor&, Tensor&>(result, clone_input); } @@ -3545,17 +3553,7 @@ DEFINE_DISPATCH(lu_solve_stub); DEFINE_DISPATCH(lu_solve_trans_stub); // Supports arbitrary batch dimensions for self and LU_data (implicitly LU_pivots also) -Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, const c10::string_view trans_str) { - auto trans = std::toupper(trans_str[0]); - switch (trans) { - case 'N': - case 'T': - case 'C': - break; - default: - TORCH_CHECK(false, - "lu_solve: wrong `trans` parameter, it must be one of 'N', 'T' or 'C'"); - } +Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, TransposeType trans) { TORCH_CHECK(self.dim() >= 2, "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead"); TORCH_CHECK(LU_data.dim() >= 2, @@ -3598,7 +3596,7 @@ Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor& } Tensor lu_solve(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots) { - return at::native::_lu_solve_trans(self, LU_data, LU_pivots, "N"); + return at::native::_lu_solve_trans(self, LU_data, LU_pivots, TransposeType::NoTranspose); } Tensor& lu_solve_out(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, Tensor& result) { @@ -3748,7 +3746,7 @@ Tensor _det_lu_based_helper_backward_helper( auto lu_clone = lu.clone(); condition_diagonal(lu_clone); - auto trans = self.is_complex() ? 'C' : 'T'; + auto trans = self.is_complex() ? TransposeType::ConjTranspose : TransposeType::Transpose; // d is modified in-place and will contain the result lu_solve_trans_stub(self.device().type(), d, lu_clone, pivs, trans); @@ -3766,8 +3764,6 @@ Tensor _det_lu_based_helper_backward_helper( u.conj_physical_(); } - auto infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt)); - // triangular_solve_stub performs operations in-place. // Tensor d will contain the result condition_diagonal(u); @@ -3780,19 +3776,19 @@ // Since u is conjugated in-place in the code above, it is sufficient // to just run triangular_solve with upper=false. triangular_solve_stub( - self.device().type(), u, d, infos, + self.device().type(), u, d, + /*left=*/true, /*upper=*/false, - /*transpose=*/false, - /*conjugate_transpose=*/false, + /*transpose=*/TransposeType::NoTranspose, /*unitriangular=*/false); // After this operation d will contain a row-wise permuted grad wrt to self // The same notes as for the system involving u apply here.
triangular_solve_stub( - self.device().type(), l, d, infos, + self.device().type(), l, d, + /*left=*/true, /*upper=*/true, - /*transpose=*/false, - /*conjugate_transpose=*/false, + /*transpose=*/TransposeType::NoTranspose, /*unitriangular=*/true); // multiply by p to restore the row order diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index 1d239b9..dba3a41 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -37,9 +37,6 @@ template <class scalar_t, class value_t = scalar_t> void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info); template <class scalar_t> -void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb, int* info); - -template <class scalar_t> void lapackGels(char trans, int m, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, scalar_t *work, int lwork, int *info); @@ -166,6 +163,11 @@ void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info); #endif +#if AT_BUILD_WITH_BLAS() +template <class scalar_t> +void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb); +#endif + using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/); DECLARE_DISPATCH(cholesky_fn, cholesky_stub); @@ -210,11 +212,10 @@ DECLARE_DISPATCH(lstsq_fn, lstsq_stub); using triangular_solve_fn = void (*)( Tensor& /*A*/, - Tensor& /*b*/, - Tensor& /*infos*/, + Tensor& /*B*/, + bool /*left*/, bool /*upper*/, - bool /*transpose*/, - bool /*conjugate_transpose*/, + TransposeType /*transpose*/, bool /*unitriangular*/); DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub); @@ -235,7 +236,7 @@ using lu_solve_trans_fn = void (*)( const Tensor& /*b*/, const Tensor& /*lu*/, const Tensor& /*pivots*/, - char /*trans*/); + TransposeType /*trans*/); DECLARE_DISPATCH(lu_solve_trans_fn, lu_solve_trans_stub); diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index db67e93..dee835a 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -10,7 +10,6 @@ namespace at { namespace native { namespace { - /* Computes the Cholesky decomposition of matrices stored in `input`. This is an in-place routine and the content of 'input' is overwritten with the result. @@ -791,53 +790,45 @@ X and B are n-by-nrhs matrices, A is a unit, or non-unit, upper or lower triangu and op(A) is one of op(A) = A or op(A) = A^T or op(A) = A^H. This is an in-place routine, content of 'B' is overwritten. 'upper' controls the portion of input matrix to consider in computations, -'transpose' if true then op(A) = A^T, +'transpose' chooses op(A), 'unitriangular' if true then the diagonal elements of A are assumed to be 1 and the actual diagonal values are not used. -'infos' is an int Tensor containing error codes for each matrix in the batched input. -For more information see LAPACK's documentation for TRTRS routine.
*/ template <typename scalar_t> -void apply_triangular_solve(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { -#if !AT_BUILD_WITH_LAPACK() +void apply_triangular_solve(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) { +#if !AT_BUILD_WITH_BLAS() TORCH_CHECK( false, "Calling torch.triangular_solve on a CPU tensor requires compiling ", - "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); + "PyTorch with BLAS. Please use PyTorch built with BLAS support."); #else char uplo = upper ? 'U' : 'L'; - char trans = transpose ? 'T' : 'N'; - trans = conjugate_transpose ? 'C' : trans; char diag = unitriangular ? 'U' : 'N'; + char side = left ? 'L' : 'R'; + const char trans = to_blas(transpose); auto A_data = A.data_ptr<scalar_t>(); auto B_data = B.data_ptr<scalar_t>(); auto A_mat_stride = matrixStride(A); auto B_mat_stride = matrixStride(B); auto batch_size = batchCount(A); - auto n = A.size(-2); - auto nrhs = B.size(-1); - auto lda = std::max<int64_t>(1, n); - auto infos_data = infos.data_ptr<int>(); + // This allows one to pass rectangular A and B when left = True + auto m = left ? A.size(-1) : B.size(-2); + auto n = B.size(-1); + auto lda = std::max<int64_t>(1, A.size(-2)); + auto ldb = std::max<int64_t>(1, B.size(-2)); for (const auto i : c10::irange(batch_size)) { scalar_t* A_working_ptr = &A_data[i * A_mat_stride]; scalar_t* B_working_ptr = &B_data[i * B_mat_stride]; - int* info_working_ptr = &infos_data[i]; - lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, lda, B_working_ptr, lda, info_working_ptr); - // The current behaviour for linear algebra functions is to raise an error if something goes wrong - // or input doesn't satisfy some requirement - // therefore return early since further computations will be wasted anyway - if (*info_working_ptr != 0) { - return; - } + blasTriangularSolve<scalar_t>(side, uplo, trans, diag, m, n, A_working_ptr, lda, B_working_ptr, ldb); } #endif } -void triangular_solve_kernel(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { +void triangular_solve_kernel(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cpu", [&]{ - apply_triangular_solve<scalar_t>(A, B, infos, upper, transpose, conjugate_transpose, unitriangular); + apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular); }); } @@ -904,7 +895,7 @@ void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, b For further details, please see the LAPACK documentation for GETRS.
*/ template <typename scalar_t> -void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans = 'N') { +void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) { #if !AT_BUILD_WITH_LAPACK() TORCH_CHECK( false, @@ -913,6 +904,7 @@ void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, cha #else auto b_data = b.data_ptr<scalar_t>(); auto lu_data = lu.data_ptr<scalar_t>(); + const auto trans = to_blas(transpose); auto pivots_data = pivots.data_ptr<int>(); auto b_stride = matrixStride(b); auto lu_stride = matrixStride(lu); @@ -940,22 +932,14 @@ void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, cha } // This is a type dispatching helper function for 'apply_lu_solve' -void lu_solve_trans_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans) { - switch (trans) { - case 'N': - case 'T': - case 'C': - break; - default: - TORCH_INTERNAL_ASSERT(false, "lu_solve_trans_cpu: wrong value for `trans`, it must be one of 'N', 'T', 'C'"); - } +void lu_solve_trans_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_cpu", [&]{ apply_lu_solve<scalar_t>(b, lu, pivots, trans); }); } void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) { - lu_solve_trans_kernel(b, lu, pivots, 'N'); + lu_solve_trans_kernel(b, lu, pivots, TransposeType::NoTranspose); } } // anonymous namespace diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index f14e4dc..9932aa0 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -39,7 +39,7 @@ void normalize_last_dims( *ldc = m; } - if(transa != NoTranspose) { + if(transa != TransposeType::NoTranspose) { if (m == 1) { *lda = k; } @@ -47,7 +47,7 @@ void normalize_last_dims( *lda = m; } - if(transb != NoTranspose) { + if(transb != TransposeType::NoTranspose) { if (k == 1) { *ldb = n; } @@ -63,8 +63,8 @@ bool use_blas_gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, int64_t &lda, int64_t &ldb, int64_t &ldc) { - const bool transa_ = transa != NoTranspose; - const bool transb_ = transb != NoTranspose; + const bool transa_ = transa != TransposeType::NoTranspose; + const bool transb_ = transb != TransposeType::NoTranspose; return ( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) && (ldc >= std::max(int64_t{1}, m))); } -#if AT_BUILD_WITH_BLAS() -char to_blas(TransposeType trans) { - switch (trans) { - case Transpose: return 't'; - case NoTranspose: return 'n'; - case ConjTranspose: return 'c'; - } - TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); -} -#endif // AT_BUILD_WITH_BLAS - #ifdef USE_FBGEMM fbgemm::matrix_op_t to_fbgemm(TransposeType trans) { switch (trans) { - case Transpose: return fbgemm::matrix_op_t::Transpose; - case NoTranspose: return fbgemm::matrix_op_t::NoTranspose; - case ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm"); + case TransposeType::Transpose: return fbgemm::matrix_op_t::Transpose; + case TransposeType::NoTranspose: return fbgemm::matrix_op_t::NoTranspose; + case TransposeType::ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm"); } TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); } diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index
3a483e4..7002a57 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -9,12 +10,6 @@ namespace at { namespace native { namespace cpublas { -enum TransposeType { - Transpose, - NoTranspose, - ConjTranspose, -}; - namespace internal { void normalize_last_dims( TransposeType transa, TransposeType transb, diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 0576bd6..57346ae 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1074,8 +1074,8 @@ static void addmm_impl_cpu_( result.scalar_type(), "addmm_impl_cpu_", [&]{ at::native::cpublas::gemm( - transpose_a ? a.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose, - transpose_b ? b.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose, + transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose, + transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose, m, n, k, alpha.to(), a.data_ptr(), lda, diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index abbf82c..d4adb10 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -12,6 +12,23 @@ namespace at { namespace native { +// Used as an interface between the different BLAS-like libraries +enum class TransposeType { + NoTranspose, + Transpose, + ConjTranspose, +}; + +// Transforms TransposeType into the BLAS / LAPACK format +static char to_blas(TransposeType trans) { + switch (trans) { + case TransposeType::Transpose: return 'T'; + case TransposeType::NoTranspose: return 'N'; + case TransposeType::ConjTranspose: return 'C'; + } + TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); +} + /* * Clones a Tensor so that the following conditions hold: * If we think of a Tensor of having size (B, M, N), where B is any number @@ -19,8 +36,8 @@ namespace at { namespace native { * - Each (M, N) matrix is in column major form * - Let Tensor P have size (B, M, N) and Q have size (B, M', N'). * Then when laid out in memory, the M by N matrix starting at - * P.data_ptr()[b * M * N] is of the same corresponding batch as the M' by N' - * matrix starting at Q.data_ptr()[b * M' * N']. + * P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N' + * matrix starting at Q.data_ptr()[B * M' * N']. 
*/ static inline Tensor cloneBatchedColumnMajor(const Tensor& src) { // If src is already in batched column major format, then diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp index 6ae81e8..f121117 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp @@ -317,8 +317,8 @@ void slow_conv_transpose2d_out_cpu_template( // Do GEMM (note: this is a bit confusing because gemm assumes // column-major matrices) cpublas::gemm( - cpublas::NoTranspose, - cpublas::Transpose, + TransposeType::NoTranspose, + TransposeType::Transpose, n, m, k, @@ -360,8 +360,8 @@ void slow_conv_transpose2d_out_cpu_template( // column-major matrices) if (bias.defined()) { cpublas::gemm( - cpublas::Transpose, - cpublas::NoTranspose, + TransposeType::Transpose, + TransposeType::NoTranspose, n_, m_, k_, @@ -536,8 +536,8 @@ static void slow_conv_transpose2d_backward_out_cpu_template( ? grad_columns.data_ptr() : grad_output_n.data_ptr(); cpublas::gemm( - cpublas::NoTranspose, - cpublas::NoTranspose, + TransposeType::NoTranspose, + TransposeType::NoTranspose, n, m, k, @@ -738,8 +738,8 @@ void slow_conv_transpose2d_acc_grad_parameters_cpu( ? columns.data_ptr() : grad_output_n.data_ptr(); cpublas::gemm( - cpublas::Transpose, - cpublas::NoTranspose, + TransposeType::Transpose, + TransposeType::NoTranspose, n, m, k, diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp index dcde12e..9266047 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp @@ -317,8 +317,8 @@ void slow_conv_transpose3d_out_cpu_template( // Do GEMM (note: this is a bit confusing because gemm assumes // column-major matrices) cpublas::gemm( - cpublas::NoTranspose, - cpublas::Transpose, + TransposeType::NoTranspose, + TransposeType::Transpose, n, m, k, @@ -366,8 +366,8 @@ void slow_conv_transpose3d_out_cpu_template( // column-major matrices) if (bias.defined()) { cpublas::gemm( - cpublas::Transpose, - cpublas::NoTranspose, + TransposeType::Transpose, + TransposeType::NoTranspose, n_, m_, k_, @@ -579,8 +579,8 @@ void slow_conv_transpose3d_backward_out_cpu_template( ? grad_columns.data_ptr() : grad_output_n.data_ptr(); cpublas::gemm( - cpublas::NoTranspose, - cpublas::NoTranspose, + TransposeType::NoTranspose, + TransposeType::NoTranspose, n, m, k, @@ -819,8 +819,8 @@ void slow_conv_transpose3d_acc_grad_parameters_cpu( ? 
columns.data_ptr() : grad_output_n.data_ptr(); cpublas::gemm( - cpublas::Transpose, - cpublas::NoTranspose, + TransposeType::Transpose, + TransposeType::NoTranspose, n, m, k, diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp index 3274f13..ab99d23 100644 --- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp +++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp @@ -1,5 +1,3 @@ - - #include #include #include @@ -273,8 +271,8 @@ void slow_conv_dilated_all_cpu_template( op(A) = 'n', op(B) = 'n', alpha=1, beta=1 */ cpublas::gemm( - /*transa=*/cpublas::NoTranspose, - /*transb=*/cpublas::NoTranspose, + /*transa=*/TransposeType::NoTranspose, + /*transb=*/TransposeType::NoTranspose, /* m=*/columns.size(1), /* n=*/nOutputPlane, /* k=*/columns.size(0), @@ -317,8 +315,8 @@ void slow_conv_dilated_all_cpu_template( op(A) = 'n', op(B) = 't', alpha=1, beta=0 */ cpublas::gemm( - /*transa=*/cpublas::NoTranspose, - /*transb=*/cpublas::Transpose, + /*transa=*/TransposeType::NoTranspose, + /*transb=*/TransposeType::Transpose, /* m=*/columns.size(1), /* n=*/columns.size(0), /* k=*/nOutputPlane, @@ -382,8 +380,8 @@ void slow_conv_dilated_all_cpu_template( op(B) = 'n', alpha=scale, beta=1 */ cpublas::gemm( - /*transa=*/cpublas::Transpose, - /*transb=*/cpublas::NoTranspose, + /*transa=*/TransposeType::Transpose, + /*transb=*/TransposeType::NoTranspose, /* m=*/columns.size(0), /* n=*/nOutputPlane, /* k=*/columns.size(1), diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 4bf2327..1897ff8 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -156,13 +156,13 @@ void gemm_core_( const scalar_t *b, int64_t ldb, scalar_t beta, scalar_t *c, int64_t ldc) { - if(transa == NoTranspose && transb == NoTranspose) { + if(transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) { return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } else if(transa == Transpose && transb != Transpose) { + } else if(transa == TransposeType::Transpose && transb != TransposeType::Transpose) { gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } else if(transa == NoTranspose && transb == Transpose) { + } else if(transa == TransposeType::NoTranspose && transb == TransposeType::Transpose) { gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } else { // transa == Transpose && transb == Transpose + } else { // transa == TransposeType::Transpose && transb == TransposeType::Transpose gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } } diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp index 7fdc55d..c6225c6 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp @@ -97,7 +97,7 @@ void magmaCholeskyBatched( template void magmaTriangularSolveBatched( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue); @@ -667,29 +667,29 @@ void magmaCholeskyBatched>( template<> void magmaTriangularSolveBatched( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + magma_side_t side, magma_uplo_t uplo, 
magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue) { - magmablas_dtrsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue()); + magmablas_dtrsm_batched(side, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue()); AT_CUDA_CHECK(cudaGetLastError()); } template<> void magmaTriangularSolveBatched<float>( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue) { - magmablas_strsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue()); + magmablas_strsm_batched(side, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue()); AT_CUDA_CHECK(cudaGetLastError()); } template<> void magmaTriangularSolveBatched<c10::complex<double>>( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, c10::complex<double>** dA_array, magma_int_t ldda, c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue) { magmaDoubleComplex alpha({1, 0}); - magmablas_ztrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha, + magmablas_ztrsm_batched(side, uplo, trans, diag, m, n, alpha, reinterpret_cast<magmaDoubleComplex**>(dA_array), ldda, reinterpret_cast<magmaDoubleComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue()); AT_CUDA_CHECK(cudaGetLastError()); @@ -697,11 +697,11 @@ void magmaTriangularSolveBatched<c10::complex<double>>( template<> void magmaTriangularSolveBatched<c10::complex<float>>( - magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, + magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n, c10::complex<float>** dA_array, magma_int_t ldda, c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t batchsize, const MAGMAQueue& magma_queue) { magmaFloatComplex alpha({1, 0}); - magmablas_ctrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha, + magmablas_ctrsm_batched(side, uplo, trans, diag, m, n, alpha, reinterpret_cast<magmaFloatComplex**>(dA_array), ldda, reinterpret_cast<magmaFloatComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue()); AT_CUDA_CHECK(cudaGetLastError()); @@ -1224,8 +1224,15 @@ void checkMagmaInternalError(magma_int_t info, const std::string& magma_function ", when calling ", magma_function_name); } +magma_trans_t to_magma(TransposeType trans) { + switch (trans) { + case TransposeType::NoTranspose: return MagmaNoTrans; + case TransposeType::Transpose: return MagmaTrans; + case TransposeType::ConjTranspose: return MagmaConjTrans; + } + TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); +} } // anonymous namespace - #endif // USE_MAGMA #define ALLOCATE_ARRAY(name, type, size) \ @@ -1950,27 +1957,28 @@ REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <typename scalar_t> -static void apply_triangular_solve_batched(Tensor& A, Tensor& b, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { +static void apply_triangular_solve_batched_magma(Tensor& A, Tensor& b, bool left, bool upper,
TransposeType transpose, bool unitriangular) { #ifndef USE_MAGMA AT_ERROR("triangular_solve: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower; - magma_trans_t trans = transpose ? MagmaTrans : MagmaNoTrans; - trans = conjugate_transpose ? MagmaConjTrans : trans; + magma_trans_t trans = to_magma(transpose); magma_diag_t diag = unitriangular ? MagmaUnit : MagmaNonUnit; + magma_side_t side = left ? MagmaLeft : MagmaRight; auto A_data = A.data_ptr<scalar_t>(); auto b_data = b.data_ptr<scalar_t>(); - magma_int_t m = magma_int_cast(A.size(-2), "A.size(-2)"); - magma_int_t n = magma_int_cast(A.size(-1), "A.size(-1)"); - magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)"); + // This allows one to pass rectangular A and b when left = True + magma_int_t m = magma_int_cast(left ? A.size(-1) : b.size(-2), "m"); + magma_int_t n = magma_int_cast(b.size(-1), "n"); // magma returns early if m <= 0 || n <= 0 for magmaTriangularSolveBatched // magmaTriangularSolve is calling cuBLAS and it prints // ** On entry to DTRSM parameter number 9 had an illegal value // so let's use proper lda parameter here - magma_int_t lda = std::max(1, m); + magma_int_t lda = std::max<magma_int_t>(1, A.size(-2)); + magma_int_t ldb = std::max<magma_int_t>(1, b.size(-2)); - magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount"); + magma_int_t batch_size = magma_int_cast(batchCount(A), "batch_size"); auto A_mat_stride = matrixStride(A); auto b_mat_stride = matrixStride(b); @@ -1990,7 +1998,7 @@ AT_ERROR("triangular_solve: MAGMA library not found in " MAGMAQueue magma_queue(b.get_device()); constexpr int64_t batch_limit = 65535; - // Compute as many batches of 65535 possible + // Compute as many batches of 65535 as possible // The number of "mini"-batches are floor(batch_size / batch_limit) // and these cover floor(batch_size / batch_limit) * batch_limit matrix solves int64_t mini_batches = batch_size / batch_limit; @@ -2000,40 +2008,39 @@ AT_ERROR("triangular_solve: MAGMA library not found in " scalar_t** b_array_cur = &b_array[mini_idx]; magmaTriangularSolveBatched<scalar_t>( - uplo, trans, diag, n, nrhs, A_array_cur, - lda, b_array_cur, lda, batch_limit, magma_queue); + side, uplo, trans, diag, m, n, A_array_cur, + lda, b_array_cur, ldb, batch_limit, magma_queue); } // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit // which concisely is equal to batch_size % batch_limit if (batch_size % batch_limit != 0) { magmaTriangularSolveBatched<scalar_t>( - uplo, trans, diag, n, nrhs, &A_array[mini_idx], - lda, &b_array[mini_idx], lda, batch_size % batch_limit, magma_queue); + side, uplo, trans, diag, m, n, &A_array[mini_idx], + lda, &b_array[mini_idx], ldb, batch_size % batch_limit, magma_queue); } #endif } -void triangular_solve_batched_magma(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { - (void)infos; // unused +void triangular_solve_batched_magma(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{ - apply_triangular_solve_batched<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular); + apply_triangular_solve_batched_magma<scalar_t>(A, B, left, upper, transpose, unitriangular); }); } -void triangular_solve_kernel(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { +void triangular_solve_kernel(Tensor& A, Tensor& B, bool
left, bool upper, TransposeType transpose, bool unitriangular) { // For batches smaller than 8 and matrix sizes larger than 64x64 cuBLAS forloop is faster than batched version if (batchCount(A) <= 8 && A.size(-1) >= 64) { - triangular_solve_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular); + triangular_solve_cublas(A, B, left, upper, transpose, unitriangular); } else { #ifndef USE_MAGMA - triangular_solve_batched_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular); + triangular_solve_batched_cublas(A, B, left, upper, transpose, unitriangular); #else // cuBLAS batched is faster than MAGMA batched up until 512x512, after that MAGMA is faster if (A.size(-1) <= 512) { - triangular_solve_batched_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular); + triangular_solve_batched_cublas(A, B, left, upper, transpose, unitriangular); } else { - triangular_solve_batched_magma(A, B, infos, upper, transpose, conjugate_transpose, unitriangular); + triangular_solve_batched_magma(A, B, left, upper, transpose, unitriangular); } #endif // USE_MAGMA } @@ -2724,21 +2731,6 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -#ifdef USE_MAGMA -magma_trans_t _get_magma_trans(char trans) { - switch (trans) { - case 'N': - return MagmaNoTrans; - case 'T': - return MagmaTrans; - case 'C': - return MagmaConjTrans; - default: - return MagmaNoTrans; - } -} -#endif - /* Solves the matrix equation A X = B X and B are n-by-nrhs matrices, A is represented using the LU factorization. For further details, please see the MAGMA documentation for magma_dgetrs_gpu. */ template <typename scalar_t> -static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) { +static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) { #ifndef USE_MAGMA TORCH_CHECK( false, "Calling torch.lu_solve on a CUDA tensor requires compiling ", "PyTorch with MAGMA. Please rebuild with MAGMA."); #else - auto trans = _get_magma_trans(lapack_trans); + auto trans = to_magma(transpose); auto b_data = b.data_ptr<scalar_t>(); auto lu_data = lu.data_ptr<scalar_t>(); @@ -2808,14 +2800,14 @@ static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const For further details, please see the MAGMA documentation for magma_dgetrs_batched. */ template <typename scalar_t> -static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) { +static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) { #ifndef USE_MAGMA TORCH_CHECK( false, "Calling torch.lu_solve on a CUDA tensor requires compiling ", "PyTorch with MAGMA. Please rebuild with MAGMA."); #else - auto trans = _get_magma_trans(lapack_trans); + auto trans = to_magma(transpose); auto b_data = b.data_ptr<scalar_t>(); auto lu_data = lu.data_ptr<scalar_t>(); @@ -2869,34 +2861,20 @@ static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, cons #endif } -static void lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) { +static void lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_batched_magma", [&]{ - apply_lu_solve_batched_magma<scalar_t>(b, lu, pivots, lapack_trans); + apply_lu_solve_batched_magma<scalar_t>(b, lu, pivots, trans); }); } -static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) { +static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_looped_magma", [&]{ - apply_lu_solve_looped_magma<scalar_t>(b, lu, pivots, lapack_trans); + apply_lu_solve_looped_magma<scalar_t>(b, lu, pivots, trans); }); } -#if defined(USE_CUSOLVER) || defined(CUDART_VERSION) -cublasOperation_t _get_cublas_trans(char trans) { - switch (trans) { - case 'N': - return CUBLAS_OP_N; - case 'T': - return CUBLAS_OP_T; - case 'C': - return CUBLAS_OP_C; - default: - return CUBLAS_OP_N; - } -} -#endif -static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans) { +static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) { auto batch_size = batchCount(lu); auto m = lu.size(-2); auto b2 = b.size(-1); @@ -2904,7 +2882,7 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten // heuristics determined from tests discussed in https://github.com/pytorch/pytorch/pull/59148 #ifdef USE_CUSOLVER if ((batch_size == 1 && m > 512) || (batch_size <= 8 && over_magma_dim_limit)) { - lu_solve_looped_cusolver(b, lu, pivots, _get_cublas_trans(trans)); + lu_solve_looped_cusolver(b, lu, pivots, trans); } #else if (batch_size == 1) { @@ -2913,7 +2891,7 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten #endif // ifdef USE_CUSOLVER #ifdef CUDART_VERSION else if ((batch_size > 2 && m <= 128) || (batch_size > 8 && over_magma_dim_limit)) { - lu_solve_batched_cublas(b, lu, pivots, _get_cublas_trans(trans)); + lu_solve_batched_cublas(b, lu, pivots, trans); } #endif // ifdef CUDART_VERSION else { @@ -2924,7 +2902,7 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten REGISTER_CUDA_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_dispatch); static void lu_solve_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots) { - lu_solve_trans_dispatch(b, lu, pivots, 'N'); + lu_solve_trans_dispatch(b, lu, pivots, TransposeType::NoTranspose); } REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_dispatch); @@ -2975,7 +2953,7 @@ void gels_magma(const Tensor& a, Tensor& b, Tensor& infos) { }); } -void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& infos) { +void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/) { // The steps for using the QR decomposition for solving least squares problems // are outlined here https://en.wikipedia.org/wiki/QR_decomposition#Using_for_solution_to_linear_inverse_problems auto m = A.size(-2); @@ -3012,15 +2990,13 @@ void linalg_lstsq_gels(const Tensor&
A, const Tensor& B, const Tensor& infos) { ormqr_kernel(A_broadcasted, tau_broadcasted, B, /*left=*/true, /*transpose=*/true); // Step 3: solve R X = B - bool upper = true; - bool transpose = false; - bool conjugate_transpose = false; - bool unitriangular = false; triangular_solve_kernel( const_cast<Tensor&>(A_broadcasted), const_cast<Tensor&>(B), - const_cast<Tensor&>(infos), - upper, transpose, conjugate_transpose, unitriangular); + /*left=*/true, + /*upper=*/true, + /*transpose=*/TransposeType::NoTranspose, + /*unitriangular=*/false); } else { // underdetermined case Tensor Ah = cloneBatchedColumnMajor(A.conj().transpose(-2, -1)); @@ -3036,15 +3012,15 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& infos) { Tensor Ah_broadcasted = is_fortran_contiguous ? Ah_expanded : cloneBatchedColumnMajor(Ah_expanded); // Step 2: R^H Z = B - bool upper = true; - bool transpose = true; - bool conjugate_transpose = true; - bool unitriangular = false; + const auto trans = Ah_broadcasted.is_complex() ? TransposeType::ConjTranspose + : TransposeType::Transpose; triangular_solve_kernel( const_cast<Tensor&>(Ah_broadcasted), const_cast<Tensor&>(B), - const_cast<Tensor&>(infos), - upper, transpose, conjugate_transpose, unitriangular); + /*left=*/true, + /*upper=*/true, + /*transpose=*/trans, + /*unitriangular=*/false); // B matrix has the size max(m, n) x nrhs // triangular_solve_kernel writes its output into the first m rows of B leaving the rest untouched diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp index 13d67e5..f9b4322 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp @@ -15,6 +15,15 @@ namespace at { namespace native { +cublasOperation_t to_cublas(TransposeType trans) { + switch (trans) { + case TransposeType::NoTranspose: return CUBLAS_OP_N; + case TransposeType::Transpose: return CUBLAS_OP_T; + case TransposeType::ConjTranspose: return CUBLAS_OP_C; + } + TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); +} + // Some cuBLAS and cuSOLVER batched routines require input to be a device array of pointers to device individual matrices // 'input' must be a contiguous tensor template <typename scalar_t> @@ -70,10 +79,11 @@ void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) { } template <typename scalar_t> -static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) { +static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) { #ifndef CUDART_VERSION TORCH_CHECK(false, "lu_solve: cuBLAS backend for lu_solve is not available.") #else + const auto trans = to_cublas(transpose); auto pivots_data = pivots.data_ptr<int>(); auto batch_size = cuda_int_cast(batchCount(lu), "batch_size"); @@ -94,29 +104,29 @@ static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, con #endif } -void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) { +void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(lu.scalar_type(), "lu_solve_cublas", [&]{ apply_lu_solve_batched_cublas<scalar_t>(b, lu, pivots, trans); }); } template <typename scalar_t> -static void apply_triangular_solve(Tensor& A, Tensor& B, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { +static void apply_triangular_solve(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose,
bool unitriangular) { cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - trans = conjugate_transpose ? CUBLAS_OP_C : trans; + const auto trans = to_cublas(transpose); cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - cublasSideMode_t side = CUBLAS_SIDE_LEFT; + cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; auto A_data = A.data_ptr<scalar_t>(); auto B_data = B.data_ptr<scalar_t>(); auto A_mat_stride = matrixStride(A); auto B_mat_stride = matrixStride(B); auto batch_size = batchCount(A); - auto m = cuda_int_cast(A.size(-2), "m"); - auto n = cuda_int_cast(A.size(-1), "n"); - auto nrhs = cuda_int_cast(B.size(-1), "nrhs"); - auto lda = std::max(1, m); + // This allows one to pass rectangular A and B when left = True + auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m"); + auto n = cuda_int_cast(B.size(-1), "n"); + auto lda = std::max(1, cuda_int_cast(A.size(-2), "lda")); + auto ldb = std::max(1, cuda_int_cast(B.size(-2), "ldb")); auto alpha = scalar_t{1}; @@ -124,30 +134,29 @@ static void apply_triangular_solve(Tensor& A, Tensor& B, bool upper, bool transp scalar_t* A_working_ptr = &A_data[i * A_mat_stride]; scalar_t* B_working_ptr = &B_data[i * B_mat_stride]; auto handle = at::cuda::getCurrentCUDABlasHandle(); - at::cuda::blas::trsm(handle, side, uplo, trans, diag, n, nrhs, &alpha, A_working_ptr, lda, B_working_ptr, lda); + at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb); } } -void triangular_solve_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { - (void)infos; // unused +void triangular_solve_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{ - apply_triangular_solve<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular); + apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular); }); } template <typename scalar_t> -static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { +static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) { cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - trans = conjugate_transpose ? CUBLAS_OP_C : trans; + const auto trans = to_cublas(transpose); cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; - cublasSideMode_t side = CUBLAS_SIDE_LEFT; + cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; auto batch_size = cuda_int_cast(batchCount(A), "batch_size"); - auto m = cuda_int_cast(A.size(-2), "m"); - auto n = cuda_int_cast(A.size(-1), "n"); - auto nrhs = cuda_int_cast(B.size(-1), "nrhs"); - auto lda = std::max(1, m); + // This allows one to pass rectangular A and B when left = True + auto m = cuda_int_cast(left ?
A.size(-1) : B.size(-2), "m"); + auto n = cuda_int_cast(B.size(-1), "n"); + auto lda = std::max(1, cuda_int_cast(A.size(-2), "lda")); + auto ldb = std::max(1, cuda_int_cast(B.size(-2), "ldb")); auto alpha = scalar_t{1}; @@ -158,13 +167,12 @@ static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool upper, boo auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr()); auto handle = at::cuda::getCurrentCUDABlasHandle(); - at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, n, nrhs, &alpha, A_ptr_array_data, lda, B_ptr_array_data, lda, batch_size); + at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size); } -void triangular_solve_batched_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) { - (void)infos; // unused +void triangular_solve_batched_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{ - apply_triangular_solve_batched<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular); + apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular); }); } @@ -1297,8 +1305,9 @@ void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& } } -void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) { +void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) { AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_cusolver", [&] { + const auto trans = to_cublas(transpose); int n = cuda_int_cast(lu.size(-2), "n"); int nrhs = cuda_int_cast(b.size(-1), "nrhs"); auto batch_size = batchCount(lu); diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h index 1a30187..72d2f65 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h @@ -36,10 +36,10 @@ namespace at { namespace native { void geqrf_batched_cublas(const Tensor& input, const Tensor& tau); -void triangular_solve_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular); -void triangular_solve_batched_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular); +void triangular_solve_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular); +void triangular_solve_batched_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular); void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos); -void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans); +void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose); #ifdef USE_CUSOLVER @@ -60,7 +60,7 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau); void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors); -void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans); +void
lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose); void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots); diff --git a/test/test_linalg.py b/test/test_linalg.py index 5912111..6ff0ce5 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -4926,17 +4926,6 @@ class TestLinalg(TestCase): run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b - @onlyCPU - @skipCPUIfNoLapack - @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_triangular_solve_singular(self, device, dtype): - b = torch.rand(3, 1, dtype=dtype, device=device) - A = torch.eye(3, 3, dtype=dtype, device=device) - A[-1, -1] = 0 # Now A is singular - err_str = r"triangular_solve: The diagonal element 3 is zero" - with self.assertRaisesRegex(RuntimeError, err_str): - torch.triangular_solve(b, A) - @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index fe50bad..d0728dd 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -9621,6 +9621,10 @@ with the default keyword arguments. batches of 2D matrices. If the inputs are batches, then returns batched outputs `X` +If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and +:attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned, +the result may contain `NaN` s. + Supports input of float, double, cfloat and cdouble data types. Args: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 00bbeb3..72af4fc 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8156,7 +8156,7 @@ op_db: List[OpInfo] = [ supports_out=False, sample_inputs_func=sample_inputs_legacy_solve, check_batched_gradgrad=False, - decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]), + decorators=[skipCUDAIfNoMagma]), UnaryUfuncInfo('trunc', aliases=('fix', ), ref=np.trunc, -- 2.7.4
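To see the new CPU code path in isolation, here is a minimal standalone sketch (not part of the patch): it mirrors the `TransposeType` enum and the `to_blas` helper that this PR adds to LinearAlgebraUtils.h, and drives the same reference-BLAS dtrsm_ entry point that the new blasTriangularSolve<double> specialization wraps. The file name and build line are illustrative only; it assumes a Fortran BLAS is installed (compile with e.g. g++ trsm_demo.cpp -lblas).

// trsm_demo.cpp -- standalone sketch, not part of the PR.
#include <cstdio>

// Reference-BLAS prototype, as declared in BatchLinearAlgebra.cpp above.
extern "C" void dtrsm_(char *side, char *uplo, char *trans, char *diag,
                       int *m, int *n, double *alpha,
                       double *a, int *lda, double *b, int *ldb);

// Mirrors the enum the PR moves into LinearAlgebraUtils.h.
enum class TransposeType { NoTranspose, Transpose, ConjTranspose };

// Same mapping as the to_blas helper introduced by the PR.
static char to_blas(TransposeType trans) {
  switch (trans) {
    case TransposeType::Transpose: return 'T';
    case TransposeType::NoTranspose: return 'N';
    case TransposeType::ConjTranspose: return 'C';
  }
  return 'N';
}

int main() {
  // Solve A X = B with A = [[2, 1], [0, 4]] (upper triangular, column-major)
  // and B = [[4], [8]]; the expected solution is X = [[1], [2]].
  double A[4] = {2., 0., 1., 4.};  // column-major 2x2
  double B[2] = {4., 8.};          // overwritten in place with X, as in the wrappers
  char side = 'L', uplo = 'U', diag = 'N';
  char trans = to_blas(TransposeType::NoTranspose);
  int m = 2, n = 1, lda = 2, ldb = 2;
  double alpha = 1.;  // trsm solves op(A) X = alpha * B; the PR's wrappers fix alpha = 1
  dtrsm_(&side, &uplo, &trans, &diag, &m, &n, &alpha, A, &lda, B, &ldb);
  std::printf("x = [%g, %g]\n", B[0], B[1]);  // prints x = [1, 2]
  return 0;
}

Passing side = 'R' instead solves X op(A) = alpha * B, which is what the newly exposed `side` argument enables: solving from the right directly, rather than transposing (and copying) the triangular matrix first.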