extern "C" void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
extern "C" void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
-// trtrs
-extern "C" void ztrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
-extern "C" void ctrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
-extern "C" void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
-extern "C" void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
-
// geqrf
extern "C" void zgeqrf_(int *m, int *n, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *work, int *lwork, int *info);
extern "C" void cgeqrf_(int *m, int *n, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *work, int *lwork, int *info);
extern "C" void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
#endif
+#if AT_BUILD_WITH_BLAS()
+// trsm
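+// ?trsm solves op(A) X = alpha B (side 'L') or X op(A) = alpha B (side 'R') in-place in B,
+// where A is triangular and op(A) is one of A, A^T or A^H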
+extern "C" void ztrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *alpha, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb);
+extern "C" void ctrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *alpha, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb);
+extern "C" void dtrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, double *alpha, double *a, int *lda, double *b, int *ldb);
+extern "C" void strsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, float *alpha, float *a, int *lda, float *b, int *ldb);
+#endif
+
namespace at {
namespace native {
spotri_(&uplo, &n, a, &lda, info);
}
-template<> void lapackTriangularSolve<c10::complex<double>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) {
- ztrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
-}
-
-template<> void lapackTriangularSolve<c10::complex<float>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb, int *info) {
- ctrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb, info);
-}
-
-template<> void lapackTriangularSolve<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
- dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
-}
-
-template<> void lapackTriangularSolve<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
- strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
-}
-
template<> void lapackGeqrf<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *work, int lwork, int *info) {
zgeqrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(work), &lwork, info);
}
}
#endif
+#if AT_BUILD_WITH_BLAS()
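+// These wrappers always pass alpha = 1, so trsm overwrites B with the unscaled
+// solution op(A)^-1 B (side 'L') or B op(A)^-1 (side 'R')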
+template<> void blasTriangularSolve<c10::complex<double>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb) {
+ std::complex<double> one{1., 0.};
+ ztrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb);
+}
+
+template<> void blasTriangularSolve<c10::complex<float>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb) {
+ std::complex<float> one{1.f, 0.f};
+ ctrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb);
+}
+
+template<> void blasTriangularSolve<double>(char side, char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb) {
+ auto one = 1.;
+ dtrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
+}
+
+template<> void blasTriangularSolve<float>(char side, char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb) {
+ auto one = 1.f;
+ strsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
+}
+#endif
+
// Below are the definitions of the functions operating on a batch that are going to be dispatched
// in the main helper functions for the linear algebra operations
result.copy_(other);
clone_input.copy_(input);
- triangular_solve_stub(input.device().type(), clone_input, result, infos, upper, transpose, /*conjugate_transpose=*/false, unitriangular);
+ triangular_solve_stub(input.device().type(), clone_input, result, /*left=*/true, upper, transpose ? TransposeType::Transpose : TransposeType::NoTranspose, unitriangular);
return std::tuple<Tensor&, Tensor&>(result, clone_input);
}
DEFINE_DISPATCH(lu_solve_trans_stub);
// Supports arbitrary batch dimensions for self and LU_data (implicitly LU_pivots also)
-Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, const c10::string_view trans_str) {
- auto trans = std::toupper(trans_str[0]);
- switch (trans) {
- case 'N':
- case 'T':
- case 'C':
- break;
- default:
- TORCH_CHECK(false,
- "lu_solve: wrong `trans` parameter, it must be one of 'N', 'T' or 'C'");
- }
+Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, TransposeType trans) {
TORCH_CHECK(self.dim() >= 2,
"b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
TORCH_CHECK(LU_data.dim() >= 2,
}
Tensor lu_solve(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots) {
- return at::native::_lu_solve_trans(self, LU_data, LU_pivots, "N");
+ return at::native::_lu_solve_trans(self, LU_data, LU_pivots, TransposeType::NoTranspose);
}
Tensor& lu_solve_out(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, Tensor& result) {
auto lu_clone = lu.clone();
condition_diagonal(lu_clone);
- auto trans = self.is_complex() ? 'C' : 'T';
+ auto trans = self.is_complex() ? TransposeType::ConjTranspose : TransposeType::Transpose;
// d is modified in-place and will contain the result
lu_solve_trans_stub(self.device().type(), d, lu_clone, pivs, trans);
u.conj_physical_();
}
- auto infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt));
-
// triangular_solve_stub performs operations in-place.
// Tensor d will contain the result
condition_diagonal(u);
// Since u is conjugated in-place in the code above, it is sufficient
// to just run triangular_solve with upper=false.
triangular_solve_stub(
- self.device().type(), u, d, infos,
+ self.device().type(), u, d,
+ /*left=*/true,
/*upper=*/false,
- /*transpose=*/false,
- /*conjugate_transpose=*/false,
+ /*transpose=*/TransposeType::NoTranspose,
/*unitriangular=*/false);
// After this operation d will contain a row-wise permuted grad wrt to self
// The same notes as for the system involving u apply here.
triangular_solve_stub(
- self.device().type(), l, d, infos,
+ self.device().type(), l, d,
+ /*left=*/true,
/*upper=*/true,
- /*transpose=*/false,
- /*conjugate_transpose=*/false,
+ /*transpose=*/TransposeType::NoTranspose,
/*unitriangular=*/true);
// multiply by p to restore the row order
void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
template <class scalar_t>
-void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb, int* info);
-
-template <class scalar_t>
void lapackGels(char trans, int m, int n, int nrhs,
scalar_t *a, int lda, scalar_t *b, int ldb,
scalar_t *work, int lwork, int *info);
#endif
+#if AT_BUILD_WITH_BLAS()
+template <class scalar_t>
+void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
+#endif
+
using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
using triangular_solve_fn = void (*)(
Tensor& /*A*/,
- Tensor& /*b*/,
- Tensor& /*infos*/,
+ Tensor& /*B*/,
+ bool /*left*/,
bool /*upper*/,
- bool /*transpose*/,
- bool /*conjugate_transpose*/,
+ TransposeType /*transpose*/,
bool /*unitriangular*/);
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
const Tensor& /*b*/,
const Tensor& /*lu*/,
const Tensor& /*pivots*/,
- char /*trans*/);
+ TransposeType /*trans*/);
DECLARE_DISPATCH(lu_solve_trans_fn, lu_solve_trans_stub);
namespace at { namespace native {
namespace {
-
/*
Computes the Cholesky decomposition of matrices stored in `input`.
This is an in-place routine and the content of 'input' is overwritten with the result.
and op(A) is one of op(A) = A or op(A) = A^T or op(A) = A^H.
This is an in-place routine, content of 'B' is overwritten.
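+'left' if true then op(A) X = B is solved, otherwise X op(A) = B,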
'upper' controls the portion of input matrix to consider in computations,
-'transpose' if true then op(A) = A^T,
+'transpose' chooses op(A): NoTranspose (A), Transpose (A^T) or ConjTranspose (A^H),
'unitriangular' if true then the diagonal elements of A are assumed to be 1
and the actual diagonal values are not used.
-'infos' is an int Tensor containing error codes for each matrix in the batched input.
-For more information see LAPACK's documentation for TRTRS routine.
+For more information see the BLAS documentation for the TRSM routine.
*/
template<typename scalar_t>
-void apply_triangular_solve(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
-#if !AT_BUILD_WITH_LAPACK()
+void apply_triangular_solve(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
+#if !AT_BUILD_WITH_BLAS()
TORCH_CHECK(
false,
"Calling torch.triangular_solve on a CPU tensor requires compiling ",
- "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
+ "PyTorch with BLAS. Please use PyTorch built with BLAS support.");
#else
char uplo = upper ? 'U' : 'L';
- char trans = transpose ? 'T' : 'N';
- trans = conjugate_transpose ? 'C' : trans;
char diag = unitriangular ? 'U' : 'N';
+ char side = left ? 'L' : 'R';
+ const char trans = to_blas(transpose);
auto A_data = A.data_ptr<scalar_t>();
auto B_data = B.data_ptr<scalar_t>();
auto A_mat_stride = matrixStride(A);
auto B_mat_stride = matrixStride(B);
auto batch_size = batchCount(A);
- auto n = A.size(-2);
- auto nrhs = B.size(-1);
- auto lda = std::max<int64_t>(1, n);
- auto infos_data = infos.data_ptr<int>();
+  // This allows passing rectangular A and B when left = true
+ auto m = left ? A.size(-1) : B.size(-2);
+ auto n = B.size(-1);
+ auto lda = std::max<int64_t>(1, A.size(-2));
+ auto ldb = std::max<int64_t>(1, B.size(-2));
for (const auto i : c10::irange(batch_size)) {
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
- int* info_working_ptr = &infos_data[i];
- lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, lda, B_working_ptr, lda, info_working_ptr);
- // The current behaviour for linear algebra functions to raise an error if something goes wrong
- // or input doesn't satisfy some requirement
- // therefore return early since further computations will be wasted anyway
- if (*info_working_ptr != 0) {
- return;
- }
+ blasTriangularSolve<scalar_t>(side, uplo, trans, diag, m, n, A_working_ptr, lda, B_working_ptr, ldb);
}
#endif
}
-void triangular_solve_kernel(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+void triangular_solve_kernel(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cpu", [&]{
- apply_triangular_solve<scalar_t>(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+ apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
For further details, please see the LAPACK documentation for GETRS.
*/
template <typename scalar_t>
-void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans = 'N') {
+void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
#if !AT_BUILD_WITH_LAPACK()
TORCH_CHECK(
false,
#else
auto b_data = b.data_ptr<scalar_t>();
auto lu_data = lu.data_ptr<scalar_t>();
+ const auto trans = to_blas(transpose);
auto pivots_data = pivots.data_ptr<int>();
auto b_stride = matrixStride(b);
auto lu_stride = matrixStride(lu);
}
// This is a type dispatching helper function for 'apply_lu_solve'
-void lu_solve_trans_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans) {
- switch (trans) {
- case 'N':
- case 'T':
- case 'C':
- break;
- default:
- TORCH_INTERNAL_ASSERT(false, "lu_solve_trans_cpu: wrong value for `trans`, it must be one of 'N', 'T', 'C'");
- }
+void lu_solve_trans_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_cpu", [&]{
apply_lu_solve<scalar_t>(b, lu, pivots, trans);
});
}
void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
- lu_solve_trans_kernel(b, lu, pivots, 'N');
+ lu_solve_trans_kernel(b, lu, pivots, TransposeType::NoTranspose);
}
} // anonymous namespace
*ldc = m;
}
- if(transa != NoTranspose) {
+ if(transa != TransposeType::NoTranspose) {
if (m == 1) {
*lda = k;
}
*lda = m;
}
- if(transb != NoTranspose) {
+ if(transb != TransposeType::NoTranspose) {
if (k == 1) {
*ldb = n;
}
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
int64_t &lda, int64_t &ldb, int64_t &ldc) {
- const bool transa_ = transa != NoTranspose;
- const bool transb_ = transb != NoTranspose;
+ const bool transa_ = transa != TransposeType::NoTranspose;
+ const bool transb_ = transb != TransposeType::NoTranspose;
return (
(m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
(lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) &&
(ldc >= std::max(int64_t{1}, m)));
}
-#if AT_BUILD_WITH_BLAS()
-char to_blas(TransposeType trans) {
- switch (trans) {
- case Transpose: return 't';
- case NoTranspose: return 'n';
- case ConjTranspose: return 'c';
- }
- TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
-}
-#endif // AT_BUILD_WITH_BLAS
-
#ifdef USE_FBGEMM
fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
switch (trans) {
- case Transpose: return fbgemm::matrix_op_t::Transpose;
- case NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
- case ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
+ case TransposeType::Transpose: return fbgemm::matrix_op_t::Transpose;
+ case TransposeType::NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
+ case TransposeType::ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
}
TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
#pragma once
#include <ATen/native/DispatchStub.h>
+#include <ATen/native/LinearAlgebraUtils.h>
#include <c10/util/complex.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Scalar.h>
namespace native {
namespace cpublas {
-enum TransposeType {
- Transpose,
- NoTranspose,
- ConjTranspose,
-};
-
namespace internal {
void normalize_last_dims(
TransposeType transa, TransposeType transb,
result.scalar_type(), "addmm_impl_cpu_",
[&]{
at::native::cpublas::gemm(
- transpose_a ? a.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose,
- transpose_b ? b.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose,
+ transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
+ transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
m, n, k,
alpha.to<scalar_t>(),
a.data_ptr<scalar_t>(), lda,
namespace at { namespace native {
+// Used as an interface between the different BLAS-like libraries
+enum class TransposeType {
+ NoTranspose,
+ Transpose,
+ ConjTranspose,
+};
+
+// Transforms TransposeType into the BLAS / LAPACK format
+static char to_blas(TransposeType trans) {
+ switch (trans) {
+ case TransposeType::Transpose: return 'T';
+ case TransposeType::NoTranspose: return 'N';
+ case TransposeType::ConjTranspose: return 'C';
+ }
+ TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
+}
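+// e.g. to_blas(TransposeType::ConjTranspose) returns 'C', the character BLAS/LAPACK
+// routines such as trsm and getrs expect for op(A) = A^H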
+
/*
* Clones a Tensor so that the following conditions hold:
* If we think of a Tensor of having size (B, M, N), where B is any number
* - Each (M, N) matrix is in column major form
* - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
* Then when laid out in memory, the M by N matrix starting at
 * P.data_ptr()[b * M * N] (for each batch index 0 <= b < B) belongs to the same
 * batch as the M' by N' matrix starting at Q.data_ptr()[b * M' * N'].
*/
static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
// If src is already in batched column major format, then
// Do GEMM (note: this is a bit confusing because gemm assumes
// column-major matrices)
cpublas::gemm(
- cpublas::NoTranspose,
- cpublas::Transpose,
+ TransposeType::NoTranspose,
+ TransposeType::Transpose,
n,
m,
k,
// column-major matrices)
if (bias.defined()) {
cpublas::gemm(
- cpublas::Transpose,
- cpublas::NoTranspose,
+ TransposeType::Transpose,
+ TransposeType::NoTranspose,
n_,
m_,
k_,
? grad_columns.data_ptr<scalar_t>()
: grad_output_n.data_ptr<scalar_t>();
cpublas::gemm(
- cpublas::NoTranspose,
- cpublas::NoTranspose,
+ TransposeType::NoTranspose,
+ TransposeType::NoTranspose,
n,
m,
k,
? columns.data_ptr<scalar_t>()
: grad_output_n.data_ptr<scalar_t>();
cpublas::gemm(
- cpublas::Transpose,
- cpublas::NoTranspose,
+ TransposeType::Transpose,
+ TransposeType::NoTranspose,
n,
m,
k,
// Do GEMM (note: this is a bit confusing because gemm assumes
// column-major matrices)
cpublas::gemm(
- cpublas::NoTranspose,
- cpublas::Transpose,
+ TransposeType::NoTranspose,
+ TransposeType::Transpose,
n,
m,
k,
// column-major matrices)
if (bias.defined()) {
cpublas::gemm(
- cpublas::Transpose,
- cpublas::NoTranspose,
+ TransposeType::Transpose,
+ TransposeType::NoTranspose,
n_,
m_,
k_,
? grad_columns.data_ptr<scalar_t>()
: grad_output_n.data_ptr<scalar_t>();
cpublas::gemm(
- cpublas::NoTranspose,
- cpublas::NoTranspose,
+ TransposeType::NoTranspose,
+ TransposeType::NoTranspose,
n,
m,
k,
? columns.data_ptr<scalar_t>()
: grad_output_n.data_ptr<scalar_t>();
cpublas::gemm(
- cpublas::Transpose,
- cpublas::NoTranspose,
+ TransposeType::Transpose,
+ TransposeType::NoTranspose,
n,
m,
k,
-
-
#include <ATen/ATen.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/DilatedConvolutionUtils.h>
op(A) = 'n', op(B) = 'n', alpha=1, beta=1
*/
cpublas::gemm(
- /*transa=*/cpublas::NoTranspose,
- /*transb=*/cpublas::NoTranspose,
+ /*transa=*/TransposeType::NoTranspose,
+ /*transb=*/TransposeType::NoTranspose,
/* m=*/columns.size(1),
/* n=*/nOutputPlane,
/* k=*/columns.size(0),
op(A) = 'n', op(B) = 't', alpha=1, beta=0
*/
cpublas::gemm(
- /*transa=*/cpublas::NoTranspose,
- /*transb=*/cpublas::Transpose,
+ /*transa=*/TransposeType::NoTranspose,
+ /*transb=*/TransposeType::Transpose,
/* m=*/columns.size(1),
/* n=*/columns.size(0),
/* k=*/nOutputPlane,
op(B) = 'n', alpha=scale, beta=1
*/
cpublas::gemm(
- /*transa=*/cpublas::Transpose,
- /*transb=*/cpublas::NoTranspose,
+ /*transa=*/TransposeType::Transpose,
+ /*transb=*/TransposeType::NoTranspose,
/* m=*/columns.size(0),
/* n=*/nOutputPlane,
/* k=*/columns.size(1),
const scalar_t *b, int64_t ldb,
scalar_t beta,
scalar_t *c, int64_t ldc) {
- if(transa == NoTranspose && transb == NoTranspose) {
+ if(transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) {
return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
- } else if(transa == Transpose && transb != Transpose) {
+ } else if(transa == TransposeType::Transpose && transb != TransposeType::Transpose) {
gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
- } else if(transa == NoTranspose && transb == Transpose) {
+ } else if(transa == TransposeType::NoTranspose && transb == TransposeType::Transpose) {
gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
- } else { // transa == Transpose && transb == Transpose
+ } else { // transa == TransposeType::Transpose && transb == TransposeType::Transpose
gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
}
template<class scalar_t>
void magmaTriangularSolveBatched(
- magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+ magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue);
template<>
void magmaTriangularSolveBatched<double>(
- magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+ magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
- magmablas_dtrsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+ magmablas_dtrsm_batched(side, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
}
template<>
void magmaTriangularSolveBatched<float>(
- magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+ magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
- magmablas_strsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+ magmablas_strsm_batched(side, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
}
template<>
void magmaTriangularSolveBatched<c10::complex<double>>(
- magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+ magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
c10::complex<double>** dA_array, magma_int_t ldda, c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magmaDoubleComplex alpha({1, 0});
- magmablas_ztrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha,
+ magmablas_ztrsm_batched(side, uplo, trans, diag, m, n, alpha,
reinterpret_cast<magmaDoubleComplex**>(dA_array), ldda,
reinterpret_cast<magmaDoubleComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
template<>
void magmaTriangularSolveBatched<c10::complex<float>>(
- magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+ magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
c10::complex<float>** dA_array, magma_int_t ldda, c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
magmaFloatComplex alpha({1, 0});
- magmablas_ctrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha,
+ magmablas_ctrsm_batched(side, uplo, trans, diag, m, n, alpha,
reinterpret_cast<magmaFloatComplex**>(dA_array), ldda,
reinterpret_cast<magmaFloatComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
AT_CUDA_CHECK(cudaGetLastError());
", when calling ", magma_function_name);
}
+magma_trans_t to_magma(TransposeType trans) {
+ switch (trans) {
+ case TransposeType::NoTranspose: return MagmaNoTrans;
+ case TransposeType::Transpose: return MagmaTrans;
+ case TransposeType::ConjTranspose: return MagmaConjTrans;
+ }
+ TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
+}
} // anonymous namespace
-
#endif // USE_MAGMA
#define ALLOCATE_ARRAY(name, type, size) \
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
-static void apply_triangular_solve_batched(Tensor& A, Tensor& b, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+static void apply_triangular_solve_batched_magma(Tensor& A, Tensor& b, bool left, bool upper, TransposeType transpose, bool unitriangular) {
#ifndef USE_MAGMA
AT_ERROR("triangular_solve: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
- magma_trans_t trans = transpose ? MagmaTrans : MagmaNoTrans;
- trans = conjugate_transpose ? MagmaConjTrans : trans;
+ magma_trans_t trans = to_magma(transpose);
magma_diag_t diag = unitriangular ? MagmaUnit : MagmaNonUnit;
+ magma_side_t side = left ? MagmaLeft : MagmaRight;
auto A_data = A.data_ptr<scalar_t>();
auto b_data = b.data_ptr<scalar_t>();
- magma_int_t m = magma_int_cast(A.size(-2), "A.size(-2)");
- magma_int_t n = magma_int_cast(A.size(-1), "A.size(-1)");
- magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
+  // This allows passing rectangular A and b when left = true
+ magma_int_t m = magma_int_cast(left ? A.size(-1) : b.size(-2), "m");
+ magma_int_t n = magma_int_cast(b.size(-1), "n");
// magma returns early if m <= 0 || n <= 0 for magmaTriangularSolveBatched
// magmaTriangularSolve calls cuBLAS, which prints
// ** On entry to DTRSM parameter number 9 had an illegal value
// so let's use proper lda parameter here
- magma_int_t lda = std::max<magma_int_t>(1, m);
- magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
+ magma_int_t lda = std::max<magma_int_t>(1, A.size(-2));
+ magma_int_t ldb = std::max<magma_int_t>(1, b.size(-2));
+ magma_int_t batch_size = magma_int_cast(batchCount(A), "batch_size");
auto A_mat_stride = matrixStride(A);
auto b_mat_stride = matrixStride(b);
MAGMAQueue magma_queue(b.get_device());
constexpr int64_t batch_limit = 65535;
- // Compute as many batches of 65535 possible
+ // Compute as many batches of 65535 as possible
// The number of "mini"-batches is floor(batch_size / batch_limit)
// and these cover floor(batch_size / batch_limit) * batch_limit matrix solves
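+// (e.g. batch_size = 150000 gives two full mini-batches covering 131070 solves;
+// the remaining 150000 % 65535 = 18930 solves are handled by the tail call below)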
int64_t mini_batches = batch_size / batch_limit;
scalar_t** b_array_cur = &b_array[mini_idx];
magmaTriangularSolveBatched<scalar_t>(
- uplo, trans, diag, n, nrhs, A_array_cur,
- lda, b_array_cur, lda, batch_limit, magma_queue);
+ side, uplo, trans, diag, m, n, A_array_cur,
+ lda, b_array_cur, ldb, batch_limit, magma_queue);
}
// Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit
// which concisely is equal to batch_size % batch_limit
if (batch_size % batch_limit != 0) {
magmaTriangularSolveBatched<scalar_t>(
- uplo, trans, diag, n, nrhs, &A_array[mini_idx],
- lda, &b_array[mini_idx], lda, batch_size % batch_limit, magma_queue);
+ side, uplo, trans, diag, m, n, &A_array[mini_idx],
+ lda, &b_array[mini_idx], ldb, batch_size % batch_limit, magma_queue);
}
#endif
}
-void triangular_solve_batched_magma(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
- (void)infos; // unused
+void triangular_solve_batched_magma(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
- apply_triangular_solve_batched<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular);
+ apply_triangular_solve_batched_magma<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
-void triangular_solve_kernel(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+void triangular_solve_kernel(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
// For batches smaller than 8 and matrix sizes larger than 64x64 the cuBLAS for-loop is faster than the batched version
if (batchCount(A) <= 8 && A.size(-1) >= 64) {
- triangular_solve_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+ triangular_solve_cublas(A, B, left, upper, transpose, unitriangular);
} else {
#ifndef USE_MAGMA
- triangular_solve_batched_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+ triangular_solve_batched_cublas(A, B, left, upper, transpose, unitriangular);
#else
// cuBLAS batched is faster than MAGMA batched up until 512x512, after that MAGMA is faster
if (A.size(-1) <= 512) {
- triangular_solve_batched_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+ triangular_solve_batched_cublas(A, B, left, upper, transpose, unitriangular);
} else {
- triangular_solve_batched_magma(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+ triangular_solve_batched_magma(A, B, left, upper, transpose, unitriangular);
}
#endif // USE_MAGMA
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#ifdef USE_MAGMA
-magma_trans_t _get_magma_trans(char trans) {
- switch (trans) {
- case 'N':
- return MagmaNoTrans;
- case 'T':
- return MagmaTrans;
- case 'C':
- return MagmaConjTrans;
- default:
- return MagmaNoTrans;
- }
-}
-#endif
-
/*
Solves the matrix equation A X = B
X and B are n-by-nrhs matrices, A is represented using the LU factorization.
For further details, please see the MAGMA documentation for magma_dgetrs_gpu.
*/
template <typename scalar_t>
-static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
#ifndef USE_MAGMA
TORCH_CHECK(
false,
"Calling torch.lu_solve on a CUDA tensor requires compiling ",
"PyTorch with MAGMA. lease rebuild with MAGMA.");
#else
- auto trans = _get_magma_trans(lapack_trans);
+ auto trans = to_magma(transpose);
auto b_data = b.data_ptr<scalar_t>();
auto lu_data = lu.data_ptr<scalar_t>();
For further details, please see the MAGMA documentation for magma_dgetrs_batched.
*/
template <typename scalar_t>
-static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
#ifndef USE_MAGMA
TORCH_CHECK(
false,
"Calling torch.lu_solve on a CUDA tensor requires compiling ",
"PyTorch with MAGMA. lease rebuild with MAGMA.");
#else
- auto trans = _get_magma_trans(lapack_trans);
+ auto trans = to_magma(transpose);
auto b_data = b.data_ptr<scalar_t>();
auto lu_data = lu.data_ptr<scalar_t>();
#endif
}
-static void lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_batched_magma", [&]{
- apply_lu_solve_batched_magma<scalar_t>(b, lu, pivots, lapack_trans);
+ apply_lu_solve_batched_magma<scalar_t>(b, lu, pivots, trans);
});
}
-static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_looped_magma", [&]{
- apply_lu_solve_looped_magma<scalar_t>(b, lu, pivots, lapack_trans);
+ apply_lu_solve_looped_magma<scalar_t>(b, lu, pivots, trans);
});
}
-#if defined(USE_CUSOLVER) || defined(CUDART_VERSION)
-cublasOperation_t _get_cublas_trans(char trans) {
- switch (trans) {
- case 'N':
- return CUBLAS_OP_N;
- case 'T':
- return CUBLAS_OP_T;
- case 'C':
- return CUBLAS_OP_C;
- default:
- return CUBLAS_OP_N;
- }
-}
-#endif
-static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans) {
+static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
auto batch_size = batchCount(lu);
auto m = lu.size(-2);
auto b2 = b.size(-1);
// heuristics determined from tests discussed in https://github.com/pytorch/pytorch/pull/59148
#ifdef USE_CUSOLVER
if ((batch_size == 1 && m > 512) || (batch_size <= 8 && over_magma_dim_limit)) {
- lu_solve_looped_cusolver(b, lu, pivots, _get_cublas_trans(trans));
+ lu_solve_looped_cusolver(b, lu, pivots, trans);
}
#else
if (batch_size == 1) {
#endif // ifdef USE_CUSOLVER
#ifdef CUDART_VERSION
else if ((batch_size > 2 && m <= 128) || (batch_size > 8 && over_magma_dim_limit)) {
- lu_solve_batched_cublas(b, lu, pivots, _get_cublas_trans(trans));
+ lu_solve_batched_cublas(b, lu, pivots, trans);
}
#endif // ifdef CUDART_VERSION
else {
REGISTER_CUDA_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_dispatch);
static void lu_solve_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
- lu_solve_trans_dispatch(b, lu, pivots, 'N');
+ lu_solve_trans_dispatch(b, lu, pivots, TransposeType::NoTranspose);
}
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_dispatch);
});
}
-void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& infos) {
+void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/) {
// The steps for using the QR decomposition for solving least squares problems
// are outlined here https://en.wikipedia.org/wiki/QR_decomposition#Using_for_solution_to_linear_inverse_problems
auto m = A.size(-2);
ormqr_kernel(A_broadcasted, tau_broadcasted, B, /*left=*/true, /*transpose=*/true);
// Step 3: solve R X = B
- bool upper = true;
- bool transpose = false;
- bool conjugate_transpose = false;
- bool unitriangular = false;
triangular_solve_kernel(
const_cast<Tensor&>(A_broadcasted),
const_cast<Tensor&>(B),
- const_cast<Tensor&>(infos),
- upper, transpose, conjugate_transpose, unitriangular);
+ /*left=*/true,
+ /*upper=*/true,
+ /*transpose=*/TransposeType::NoTranspose,
+ /*unitriangular=*/false);
} else { // underdetermined case
Tensor Ah = cloneBatchedColumnMajor(A.conj().transpose(-2, -1));
Tensor Ah_broadcasted = is_fortran_contiguous ? Ah_expanded : cloneBatchedColumnMajor(Ah_expanded);
// Step 2: R^H Z = B
- bool upper = true;
- bool transpose = true;
- bool conjugate_transpose = true;
- bool unitriangular = false;
+ const auto trans = Ah_broadcasted.is_complex() ? TransposeType::ConjTranspose
+ : TransposeType::Transpose;
triangular_solve_kernel(
const_cast<Tensor&>(Ah_broadcasted),
const_cast<Tensor&>(B),
- const_cast<Tensor&>(infos),
- upper, transpose, conjugate_transpose, unitriangular);
+ /*left=*/true,
+ /*upper=*/true,
+ /*transpose=*/trans,
+ /*unitriangular=*/false);
// B matrix has the size max(m, n) x nrhs
// triangular_solve_kernel writes its output into the first m rows of B leaving the rest untouched
namespace at {
namespace native {
+cublasOperation_t to_cublas(TransposeType trans) {
+ switch (trans) {
+ case TransposeType::NoTranspose: return CUBLAS_OP_N;
+ case TransposeType::Transpose: return CUBLAS_OP_T;
+ case TransposeType::ConjTranspose: return CUBLAS_OP_C;
+ }
+ TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
+}
+
// Some cuBLAS and cuSOLVER batched routines require input to be a device array of pointers to device individual matrices
// 'input' must be a contiguous tensor
template <typename scalar_t>
}
template <typename scalar_t>
-static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) {
+static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
#ifndef CUDART_VERSION
TORCH_CHECK(false, "lu_solve: cuBLAS backend for lu_solve is not available.")
#else
+ const auto trans = to_cublas(transpose);
auto pivots_data = pivots.data_ptr<int>();
auto batch_size = cuda_int_cast(batchCount(lu), "batch_size");
#endif
}
-void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) {
+void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(lu.scalar_type(), "lu_solve_cublas", [&]{
apply_lu_solve_batched_cublas<scalar_t>(b, lu, pivots, trans);
});
}
template <typename scalar_t>
-static void apply_triangular_solve(Tensor& A, Tensor& B, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+static void apply_triangular_solve(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
- cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
- trans = conjugate_transpose ? CUBLAS_OP_C : trans;
+ const auto trans = to_cublas(transpose);
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
- cublasSideMode_t side = CUBLAS_SIDE_LEFT;
+ cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto A_data = A.data_ptr<scalar_t>();
auto B_data = B.data_ptr<scalar_t>();
auto A_mat_stride = matrixStride(A);
auto B_mat_stride = matrixStride(B);
auto batch_size = batchCount(A);
- auto m = cuda_int_cast(A.size(-2), "m");
- auto n = cuda_int_cast(A.size(-1), "n");
- auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
- auto lda = std::max<int>(1, m);
+  // This allows passing rectangular A and B when left = true
+ auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
+ auto n = cuda_int_cast(B.size(-1), "n");
+ auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
+ auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
auto alpha = scalar_t{1};
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
auto handle = at::cuda::getCurrentCUDABlasHandle();
- at::cuda::blas::trsm(handle, side, uplo, trans, diag, n, nrhs, &alpha, A_working_ptr, lda, B_working_ptr, lda);
+ at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
}
}
-void triangular_solve_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
- (void)infos; // unused
+void triangular_solve_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
- apply_triangular_solve<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular);
+ apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
template <typename scalar_t>
-static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
- cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
- trans = conjugate_transpose ? CUBLAS_OP_C : trans;
+ const auto trans = to_cublas(transpose);
cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
- cublasSideMode_t side = CUBLAS_SIDE_LEFT;
+ cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
- auto m = cuda_int_cast(A.size(-2), "m");
- auto n = cuda_int_cast(A.size(-1), "n");
- auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
- auto lda = std::max<int>(1, m);
+  // This allows passing rectangular A and B when left = true
+ auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
+ auto n = cuda_int_cast(B.size(-1), "n");
+ auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
+ auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
auto alpha = scalar_t{1};
auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
auto handle = at::cuda::getCurrentCUDABlasHandle();
- at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, n, nrhs, &alpha, A_ptr_array_data, lda, B_ptr_array_data, lda, batch_size);
+ at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
}
-void triangular_solve_batched_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
- (void)infos; // unused
+void triangular_solve_batched_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
- apply_triangular_solve_batched<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular);
+ apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
});
}
}
}
-void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) {
+void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_cusolver", [&] {
+ const auto trans = to_cublas(transpose);
int n = cuda_int_cast(lu.size(-2), "n");
int nrhs = cuda_int_cast(b.size(-1), "nrhs");
auto batch_size = batchCount(lu);
namespace native {
void geqrf_batched_cublas(const Tensor& input, const Tensor& tau);
-void triangular_solve_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular);
-void triangular_solve_batched_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular);
+void triangular_solve_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular);
+void triangular_solve_batched_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular);
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos);
-void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans);
+void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose);
#ifdef USE_CUSOLVER
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
-void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans);
+void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose);
void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A
run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b
- @onlyCPU
- @skipCPUIfNoLapack
- @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
- def test_triangular_solve_singular(self, device, dtype):
- b = torch.rand(3, 1, dtype=dtype, device=device)
- A = torch.eye(3, 3, dtype=dtype, device=device)
- A[-1, -1] = 0 # Now A is singular
- err_str = r"triangular_solve: The diagonal element 3 is zero"
- with self.assertRaisesRegex(RuntimeError, err_str):
- torch.triangular_solve(b, A)
-
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
batches of 2D matrices. If the inputs are batches, then returns
batched outputs `X`
+If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and
+:attr:`unitriangular`\ `= False` (default), or if the input matrix is badly conditioned,
+the result may contain ``NaN``\ s.
+
Supports input of float, double, cfloat and cdouble data types.
Args:
supports_out=False,
sample_inputs_func=sample_inputs_legacy_solve,
check_batched_gradgrad=False,
- decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
+ decorators=[skipCUDAIfNoMagma]),
UnaryUfuncInfo('trunc',
aliases=('fix', ),
ref=np.trunc,