Use trsm for triangular_solve in CPU (#63567)
authorlezcano <lezcano-93@hotmail.com>
Wed, 8 Sep 2021 00:22:49 +0000 (17:22 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 00:26:17 +0000 (17:26 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63567

The current implementation called trtrs for CPU and trsm for CUDA.
See https://github.com/pytorch/pytorch/issues/56326#issuecomment-825496115 for a discussion on the differences between
these two functions and why we prefer trsm vs trtrs on CUDA.

This PR also exposes the `side` argument of this function which is used
in the second PR of this stack to optimise the number copies one needs to make
when preparing the arguments to be sent to the backends.

It also changes the use of `bool`s to a common enum type to represent
whether a matrix is transposed / conj transposed, etc. This makes the API
consistent, as before, the behaviour of these functions with `transpose=True`
and `conjugate_transpose=True` it was not well defined.
Functions to transform this type into the specific types / chars for the different
libraries are provided under the names `to_blas`, `to_lapack`, `to_magma`, etc.

This is the first of a stack of PRs that aim to improve the performance of
`linalg.solve_triangular`. `trsm` has an extra parameter (`side`), which allows to
ellide the copy of the triangular matrix in many cases.

Fixes https://github.com/pytorch/pytorch/issues/56326

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D30566479

Pulled By: mruberry

fbshipit-source-id: 3831af9b51e09fbfe272c17c88c21ecf45413212

17 files changed:
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/BatchLinearAlgebra.h
aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
aten/src/ATen/native/CPUBlas.cpp
aten/src/ATen/native/CPUBlas.h
aten/src/ATen/native/LinearAlgebra.cpp
aten/src/ATen/native/LinearAlgebraUtils.h
aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp
aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp
aten/src/ATen/native/NaiveDilatedConvolution.cpp
aten/src/ATen/native/cpu/BlasKernel.cpp
aten/src/ATen/native/cuda/BatchLinearAlgebra.cpp
aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cpp
aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h
test/test_linalg.py
torch/_torch_docs.py
torch/testing/_internal/common_methods_invocations.py

index 498b51b..0471cdb 100644 (file)
@@ -55,12 +55,6 @@ extern "C" void cpotri_(char *uplo, int *n, std::complex<float> *a, int *lda, in
 extern "C" void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
 extern "C" void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
 
-// trtrs
-extern "C" void ztrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
-extern "C" void ctrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
-extern "C" void dtrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
-extern "C" void strtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
-
 // geqrf
 extern "C" void zgeqrf_(int *m, int *n, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *work, int *lwork, int *info);
 extern "C" void cgeqrf_(int *m, int *n, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *work, int *lwork, int *info);
@@ -200,6 +194,14 @@ extern "C" void sgelss_(int *m, int *n, int *nrhs,
     float *work, int *lwork, int *info);
 #endif
 
+#if AT_BUILD_WITH_BLAS()
+// trsm
+extern "C" void ztrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *alpha, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb);
+extern "C" void ctrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *alpha, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb);
+extern "C" void dtrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, double *alpha, double *a, int *lda, double *b, int *ldb);
+extern "C" void strsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, float *alpha, float *a, int *lda, float *b, int *ldb);
+#endif
+
 namespace at {
 namespace native {
 
@@ -318,22 +320,6 @@ template<> void lapackCholeskyInverse<float>(char uplo, int n, float *a, int lda
   spotri_(&uplo, &n, a, &lda, info);
 }
 
-template<> void lapackTriangularSolve<c10::complex<double>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) {
-  ztrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
-}
-
-template<> void lapackTriangularSolve<c10::complex<float>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb, int *info) {
-  ctrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb, info);
-}
-
-template<> void lapackTriangularSolve<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
-  dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
-}
-
-template<> void lapackTriangularSolve<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
-  strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
-}
-
 template<> void lapackGeqrf<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *work, int lwork, int *info) {
   zgeqrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(work), &lwork, info);
 }
@@ -687,6 +673,28 @@ template<> void lapackGelss<float>(
 }
 #endif
 
+#if AT_BUILD_WITH_BLAS()
+template<> void blasTriangularSolve<c10::complex<double>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb) {
+  std::complex<double> one{1., 0.};
+  ztrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb);
+}
+
+template<> void blasTriangularSolve<c10::complex<float>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb) {
+  std::complex<float> one{1.f, 0.f};
+  ctrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb);
+}
+
+template<> void blasTriangularSolve<double>(char side, char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb) {
+  auto one = 1.;
+  dtrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
+}
+
+template<> void blasTriangularSolve<float>(char side, char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb) {
+  auto one = 1.f;
+  strsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
+}
+#endif
+
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
 // in the main helper functions for the linear algebra operations
 
@@ -1635,7 +1643,7 @@ static std::tuple<Tensor&, Tensor&> triangular_solve_out_info(
   result.copy_(other);
   clone_input.copy_(input);
 
-  triangular_solve_stub(input.device().type(), clone_input, result, infos, upper, transpose, /*conjugate_transpose=*/false, unitriangular);
+  triangular_solve_stub(input.device().type(), clone_input, result, /*left=*/true, upper, transpose ? TransposeType::Transpose : TransposeType::NoTranspose, unitriangular);
 
   return std::tuple<Tensor&, Tensor&>(result, clone_input);
 }
@@ -3545,17 +3553,7 @@ DEFINE_DISPATCH(lu_solve_stub);
 DEFINE_DISPATCH(lu_solve_trans_stub);
 
 // Supports arbitrary batch dimensions for self and LU_data (implicitly LU_pivots also)
-Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, const c10::string_view trans_str) {
-  auto trans = std::toupper(trans_str[0]);
-  switch (trans) {
-    case 'N':
-    case 'T':
-    case 'C':
-      break;
-    default:
-      TORCH_CHECK(false,
-                  "lu_solve: wrong `trans` parameter, it must be one of 'N', 'T' or 'C'");
-  }
+Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, TransposeType trans) {
   TORCH_CHECK(self.dim() >= 2,
               "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   TORCH_CHECK(LU_data.dim() >= 2,
@@ -3598,7 +3596,7 @@ Tensor _lu_solve_trans(const Tensor& self, const Tensor& LU_data, const Tensor&
 }
 
 Tensor lu_solve(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots) {
-  return at::native::_lu_solve_trans(self, LU_data, LU_pivots, "N");
+  return at::native::_lu_solve_trans(self, LU_data, LU_pivots, TransposeType::NoTranspose);
 }
 
 Tensor& lu_solve_out(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, Tensor& result) {
@@ -3748,7 +3746,7 @@ Tensor _det_lu_based_helper_backward_helper(
     auto lu_clone = lu.clone();
     condition_diagonal(lu_clone);
 
-    auto trans = self.is_complex() ? 'C' : 'T';
+    auto trans = self.is_complex() ? TransposeType::ConjTranspose : TransposeType::Transpose;
 
     // d is modified in-place and will contain the result
     lu_solve_trans_stub(self.device().type(), d, lu_clone, pivs, trans);
@@ -3766,8 +3764,6 @@ Tensor _det_lu_based_helper_backward_helper(
       u.conj_physical_();
     }
 
-    auto infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt));
-
     // triangular_solve_stub performs operations in-place.
     // Tensor d will contain the result
     condition_diagonal(u);
@@ -3780,19 +3776,19 @@ Tensor _det_lu_based_helper_backward_helper(
     // Since u is conjugated in-place in the code above, it is sufficient
     // to just run triangular_solve with upper=false.
     triangular_solve_stub(
-      self.device().type(), u, d, infos,
+      self.device().type(), u, d,
+      /*left=*/true,
       /*upper=*/false,
-      /*transpose=*/false,
-      /*conjugate_transpose=*/false,
+      /*transpose=*/TransposeType::NoTranspose,
       /*unitriangular=*/false);
 
     // After this operation d will contain a row-wise permuted grad wrt to self
     // The same notes as for the system involving u apply here.
     triangular_solve_stub(
-      self.device().type(), l, d, infos,
+      self.device().type(), l, d,
+      /*left=*/true,
       /*upper=*/true,
-      /*transpose=*/false,
-      /*conjugate_transpose=*/false,
+      /*transpose=*/TransposeType::NoTranspose,
       /*unitriangular=*/true);
 
     // multiply by p to restore the row order
index 1d239b9..dba3a41 100644 (file)
@@ -37,9 +37,6 @@ template <class scalar_t, class value_t = scalar_t>
 void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info);
 
 template <class scalar_t>
-void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb, int* info);
-
-template <class scalar_t>
 void lapackGels(char trans, int m, int n, int nrhs,
     scalar_t *a, int lda, scalar_t *b, int ldb,
     scalar_t *work, int lwork, int *info);
@@ -166,6 +163,11 @@ void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
 
 #endif
 
+#if AT_BUILD_WITH_BLAS()
+template <class scalar_t>
+void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb);
+#endif
+
 using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
 DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
 
@@ -210,11 +212,10 @@ DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
 
 using triangular_solve_fn = void (*)(
     Tensor& /*A*/,
-    Tensor& /*b*/,
-    Tensor& /*infos*/,
+    Tensor& /*B*/,
+    bool /*left*/,
     bool /*upper*/,
-    bool /*transpose*/,
-    bool /*conjugate_transpose*/,
+    TransposeType /*transpose*/,
     bool /*unitriangular*/);
 DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
 
@@ -235,7 +236,7 @@ using lu_solve_trans_fn = void (*)(
     const Tensor& /*b*/,
     const Tensor& /*lu*/,
     const Tensor& /*pivots*/,
-    char /*trans*/);
+    TransposeType /*trans*/);
 DECLARE_DISPATCH(lu_solve_trans_fn, lu_solve_trans_stub);
 
 
index db67e93..dee835a 100644 (file)
@@ -10,7 +10,6 @@
 namespace at { namespace native {
 
 namespace {
-
 /*
   Computes the Cholesky decomposition of matrices stored in `input`.
   This is an in-place routine and the content of 'input' is overwritten with the result.
@@ -791,53 +790,45 @@ X and B are n-by-nrhs matrices, A is a unit, or non-unit, upper or lower triangu
 and op(A) is one of op(A) = A or op(A) = A^T or op(A) = A^H.
 This is an in-place routine, content of 'B' is overwritten.
 'upper' controls the portion of input matrix to consider in computations,
-'transpose' if true then op(A) = A^T,
+'transpose' chooses op(A)
 'unitriangular' if true then the diagonal elements of A are assumed to be 1
 and the actual diagonal values are not used.
-'infos' is an int Tensor containing error codes for each matrix in the batched input.
-For more information see LAPACK's documentation for TRTRS routine.
 */
 template<typename scalar_t>
-void apply_triangular_solve(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
-#if !AT_BUILD_WITH_LAPACK()
+void apply_triangular_solve(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
+#if !AT_BUILD_WITH_BLAS()
   TORCH_CHECK(
       false,
       "Calling torch.triangular_solve on a CPU tensor requires compiling ",
-      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
+      "PyTorch with BLAS. Please use PyTorch built with BLAS support.");
 #else
   char uplo = upper ? 'U' : 'L';
-  char trans = transpose ? 'T' : 'N';
-  trans = conjugate_transpose ? 'C' : trans;
   char diag = unitriangular ? 'U' : 'N';
+  char side = left ? 'L' : 'R';
+  const char trans = to_blas(transpose);
 
   auto A_data = A.data_ptr<scalar_t>();
   auto B_data = B.data_ptr<scalar_t>();
   auto A_mat_stride = matrixStride(A);
   auto B_mat_stride = matrixStride(B);
   auto batch_size = batchCount(A);
-  auto n = A.size(-2);
-  auto nrhs = B.size(-1);
-  auto lda = std::max<int64_t>(1, n);
-  auto infos_data = infos.data_ptr<int>();
+  // This allows to pass rectangular A and B when left = True
+  auto m = left ? A.size(-1) : B.size(-2);
+  auto n = B.size(-1);
+  auto lda = std::max<int64_t>(1, A.size(-2));
+  auto ldb = std::max<int64_t>(1, B.size(-2));
 
   for (const auto i : c10::irange(batch_size)) {
     scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
     scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
-    int* info_working_ptr = &infos_data[i];
-    lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, lda, B_working_ptr, lda, info_working_ptr);
-    // The current behaviour for linear algebra functions to raise an error if something goes wrong
-    // or input doesn't satisfy some requirement
-    // therefore return early since further computations will be wasted anyway
-    if (*info_working_ptr != 0) {
-      return;
-    }
+    blasTriangularSolve<scalar_t>(side, uplo, trans, diag, m, n, A_working_ptr, lda, B_working_ptr, ldb);
   }
 #endif
 }
 
-void triangular_solve_kernel(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+void triangular_solve_kernel(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cpu", [&]{
-    apply_triangular_solve<scalar_t>(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+    apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
   });
 }
 
@@ -904,7 +895,7 @@ void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, b
   For further details, please see the LAPACK documentation for GETRS.
 */
 template <typename scalar_t>
-void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans = 'N') {
+void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
 #if !AT_BUILD_WITH_LAPACK()
   TORCH_CHECK(
       false,
@@ -913,6 +904,7 @@ void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, cha
 #else
   auto b_data = b.data_ptr<scalar_t>();
   auto lu_data = lu.data_ptr<scalar_t>();
+  const auto trans = to_blas(transpose);
   auto pivots_data = pivots.data_ptr<int>();
   auto b_stride = matrixStride(b);
   auto lu_stride = matrixStride(lu);
@@ -940,22 +932,14 @@ void apply_lu_solve(const Tensor& b, const Tensor& lu, const Tensor& pivots, cha
 }
 
 // This is a type dispatching helper function for 'apply_lu_solve'
-void lu_solve_trans_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans) {
-  switch (trans) {
-    case 'N':
-    case 'T':
-    case 'C':
-      break;
-    default:
-      TORCH_INTERNAL_ASSERT(false, "lu_solve_trans_cpu: wrong value for `trans`, it must be one of 'N', 'T', 'C'");
-  }
+void lu_solve_trans_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_cpu", [&]{
     apply_lu_solve<scalar_t>(b, lu, pivots, trans);
   });
 }
 
 void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
-  lu_solve_trans_kernel(b, lu, pivots, 'N');
+  lu_solve_trans_kernel(b, lu, pivots, TransposeType::NoTranspose);
 }
 
 } // anonymous namespace
index f14e4dc..9932aa0 100644 (file)
@@ -39,7 +39,7 @@ void normalize_last_dims(
     *ldc = m;
   }
 
-  if(transa != NoTranspose) {
+  if(transa != TransposeType::NoTranspose) {
     if (m == 1) {
       *lda = k;
     }
@@ -47,7 +47,7 @@ void normalize_last_dims(
     *lda = m;
   }
 
-  if(transb != NoTranspose) {
+  if(transb != TransposeType::NoTranspose) {
     if (k == 1) {
       *ldb = n;
     }
@@ -63,8 +63,8 @@ bool use_blas_gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
     int64_t &lda, int64_t &ldb, int64_t &ldc) {
-  const bool transa_ = transa != NoTranspose;
-  const bool transb_ = transb != NoTranspose;
+  const bool transa_ = transa != TransposeType::NoTranspose;
+  const bool transb_ = transb != TransposeType::NoTranspose;
   return (
       (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
       (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) &&
@@ -73,23 +73,12 @@ bool use_blas_gemm(
       (ldc >= std::max(int64_t{1}, m)));
 }
 
-#if AT_BUILD_WITH_BLAS()
-char to_blas(TransposeType trans) {
-  switch (trans) {
-  case Transpose: return 't';
-  case NoTranspose: return 'n';
-  case ConjTranspose: return 'c';
-  }
-  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
-}
-#endif  // AT_BUILD_WITH_BLAS
-
 #ifdef USE_FBGEMM
 fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
   switch (trans) {
-  case Transpose: return fbgemm::matrix_op_t::Transpose;
-  case NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
-  case ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
+    case TransposeType::Transpose: return fbgemm::matrix_op_t::Transpose;
+    case TransposeType::NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
+    case TransposeType::ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
   }
   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
 }
index 3a483e4..7002a57 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <ATen/native/DispatchStub.h>
+#include <ATen/native/LinearAlgebraUtils.h>
 #include <c10/util/complex.h>
 #include <c10/core/ScalarType.h>
 #include <c10/core/Scalar.h>
@@ -9,12 +10,6 @@ namespace at {
 namespace native {
 namespace cpublas {
 
-enum TransposeType {
-  Transpose,
-  NoTranspose,
-  ConjTranspose,
-};
-
 namespace internal {
 void normalize_last_dims(
   TransposeType transa, TransposeType transb,
index 0576bd6..57346ae 100644 (file)
@@ -1074,8 +1074,8 @@ static void addmm_impl_cpu_(
       result.scalar_type(), "addmm_impl_cpu_",
       [&]{
         at::native::cpublas::gemm(
-            transpose_a ? a.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose,
-            transpose_b ? b.is_conj() ? cpublas::ConjTranspose : cpublas::Transpose : cpublas::NoTranspose,
+            transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
+            transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
             m, n, k,
             alpha.to<scalar_t>(),
             a.data_ptr<scalar_t>(), lda,
index abbf82c..d4adb10 100644 (file)
 
 namespace at { namespace native {
 
+// Used as an interface between the different BLAS-like libraries
+enum class TransposeType {
+  NoTranspose,
+  Transpose,
+  ConjTranspose,
+};
+
+// Transforms TransposeType into the BLAS / LAPACK format
+static char to_blas(TransposeType trans) {
+  switch (trans) {
+    case TransposeType::Transpose: return 'T';
+    case TransposeType::NoTranspose: return 'N';
+    case TransposeType::ConjTranspose: return 'C';
+  }
+  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
+}
+
 /*
  * Clones a Tensor so that the following conditions hold:
  * If we think of a Tensor of having size (B, M, N), where B is any number
@@ -19,8 +36,8 @@ namespace at { namespace native {
  * - Each (M, N) matrix is in column major form
  * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
  *   Then when laid out in memory, the M by N matrix starting at
- *   P.data_ptr()[b * M * N] is of the same corresponding batch as the M' by N'
- *   matrix starting at Q.data_ptr()[b * M' * N'].
+ *   P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
+ *   matrix starting at Q.data_ptr()[B * M' * N'].
  */
 static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
   // If src is already in batched column major format, then
index 6ae81e8..f121117 100644 (file)
@@ -317,8 +317,8 @@ void slow_conv_transpose2d_out_cpu_template(
           // Do GEMM (note: this is a bit confusing because gemm assumes
           // column-major matrices)
           cpublas::gemm(
-              cpublas::NoTranspose,
-              cpublas::Transpose,
+              TransposeType::NoTranspose,
+              TransposeType::Transpose,
               n,
               m,
               k,
@@ -360,8 +360,8 @@ void slow_conv_transpose2d_out_cpu_template(
           // column-major matrices)
           if (bias.defined()) {
             cpublas::gemm(
-                cpublas::Transpose,
-                cpublas::NoTranspose,
+                TransposeType::Transpose,
+                TransposeType::NoTranspose,
                 n_,
                 m_,
                 k_,
@@ -536,8 +536,8 @@ static void slow_conv_transpose2d_backward_out_cpu_template(
               ? grad_columns.data_ptr<scalar_t>()
               : grad_output_n.data_ptr<scalar_t>();
           cpublas::gemm(
-              cpublas::NoTranspose,
-              cpublas::NoTranspose,
+              TransposeType::NoTranspose,
+              TransposeType::NoTranspose,
               n,
               m,
               k,
@@ -738,8 +738,8 @@ void slow_conv_transpose2d_acc_grad_parameters_cpu(
                 ? columns.data_ptr<scalar_t>()
                 : grad_output_n.data_ptr<scalar_t>();
             cpublas::gemm(
-                cpublas::Transpose,
-                cpublas::NoTranspose,
+                TransposeType::Transpose,
+                TransposeType::NoTranspose,
                 n,
                 m,
                 k,
index dcde12e..9266047 100644 (file)
@@ -317,8 +317,8 @@ void slow_conv_transpose3d_out_cpu_template(
           // Do GEMM (note: this is a bit confusing because gemm assumes
           // column-major matrices)
           cpublas::gemm(
-              cpublas::NoTranspose,
-              cpublas::Transpose,
+              TransposeType::NoTranspose,
+              TransposeType::Transpose,
               n,
               m,
               k,
@@ -366,8 +366,8 @@ void slow_conv_transpose3d_out_cpu_template(
           // column-major matrices)
           if (bias.defined()) {
             cpublas::gemm(
-                cpublas::Transpose,
-                cpublas::NoTranspose,
+                TransposeType::Transpose,
+                TransposeType::NoTranspose,
                 n_,
                 m_,
                 k_,
@@ -579,8 +579,8 @@ void slow_conv_transpose3d_backward_out_cpu_template(
               ? grad_columns.data_ptr<scalar_t>()
               : grad_output_n.data_ptr<scalar_t>();
           cpublas::gemm(
-              cpublas::NoTranspose,
-              cpublas::NoTranspose,
+              TransposeType::NoTranspose,
+              TransposeType::NoTranspose,
               n,
               m,
               k,
@@ -819,8 +819,8 @@ void slow_conv_transpose3d_acc_grad_parameters_cpu(
                 ? columns.data_ptr<scalar_t>()
                 : grad_output_n.data_ptr<scalar_t>();
             cpublas::gemm(
-                cpublas::Transpose,
-                cpublas::NoTranspose,
+                TransposeType::Transpose,
+                TransposeType::NoTranspose,
                 n,
                 m,
                 k,
index 3274f13..ab99d23 100644 (file)
@@ -1,5 +1,3 @@
-
-
 #include <ATen/ATen.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/DilatedConvolutionUtils.h>
@@ -273,8 +271,8 @@ void slow_conv_dilated_all_cpu_template(
             op(A) = 'n', op(B) = 'n', alpha=1, beta=1
         */
         cpublas::gemm(
-            /*transa=*/cpublas::NoTranspose,
-            /*transb=*/cpublas::NoTranspose,
+            /*transa=*/TransposeType::NoTranspose,
+            /*transb=*/TransposeType::NoTranspose,
             /*     m=*/columns.size(1),
             /*     n=*/nOutputPlane,
             /*     k=*/columns.size(0),
@@ -317,8 +315,8 @@ void slow_conv_dilated_all_cpu_template(
             op(A) = 'n', op(B) = 't', alpha=1, beta=0
          */
         cpublas::gemm(
-            /*transa=*/cpublas::NoTranspose,
-            /*transb=*/cpublas::Transpose,
+            /*transa=*/TransposeType::NoTranspose,
+            /*transb=*/TransposeType::Transpose,
             /*     m=*/columns.size(1),
             /*     n=*/columns.size(0),
             /*     k=*/nOutputPlane,
@@ -382,8 +380,8 @@ void slow_conv_dilated_all_cpu_template(
           op(B) = 'n', alpha=scale, beta=1
         */
         cpublas::gemm(
-            /*transa=*/cpublas::Transpose,
-            /*transb=*/cpublas::NoTranspose,
+            /*transa=*/TransposeType::Transpose,
+            /*transb=*/TransposeType::NoTranspose,
             /*     m=*/columns.size(0),
             /*     n=*/nOutputPlane,
             /*     k=*/columns.size(1),
index 4bf2327..1897ff8 100644 (file)
@@ -156,13 +156,13 @@ void gemm_core_(
     const scalar_t *b, int64_t ldb,
     scalar_t beta,
     scalar_t *c, int64_t ldc) {
-  if(transa == NoTranspose && transb == NoTranspose) {
+  if(transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) {
     return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
-  } else if(transa == Transpose && transb != Transpose) {
+  } else if(transa == TransposeType::Transpose && transb != TransposeType::Transpose) {
     gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
-  } else if(transa == NoTranspose && transb == Transpose) {
+  } else if(transa == TransposeType::NoTranspose && transb == TransposeType::Transpose) {
     gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
-  } else {  // transa == Transpose && transb == Transpose
+  } else {  // transa == TransposeType::Transpose && transb == TransposeType::Transpose
     gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
   }
 }
index 7fdc55d..c6225c6 100644 (file)
@@ -97,7 +97,7 @@ void magmaCholeskyBatched(
 
 template<class scalar_t>
 void magmaTriangularSolveBatched(
-    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue);
 
@@ -667,29 +667,29 @@ void magmaCholeskyBatched<c10::complex<float>>(
 
 template<>
 void magmaTriangularSolveBatched<double>(
-    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
-  magmablas_dtrsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+  magmablas_dtrsm_batched(side, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
 template<>
 void magmaTriangularSolveBatched<float>(
-    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
-  magmablas_strsm_batched(MagmaLeft, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
+  magmablas_strsm_batched(side, uplo, trans, diag, m, n, 1, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
 template<>
 void magmaTriangularSolveBatched<c10::complex<double>>(
-    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     c10::complex<double>** dA_array, magma_int_t ldda, c10::complex<double>** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
   magmaDoubleComplex alpha({1, 0});
-  magmablas_ztrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha,
+  magmablas_ztrsm_batched(side, uplo, trans, diag, m, n, alpha,
     reinterpret_cast<magmaDoubleComplex**>(dA_array), ldda,
     reinterpret_cast<magmaDoubleComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
   AT_CUDA_CHECK(cudaGetLastError());
@@ -697,11 +697,11 @@ void magmaTriangularSolveBatched<c10::complex<double>>(
 
 template<>
 void magmaTriangularSolveBatched<c10::complex<float>>(
-    magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
+    magma_side_t side, magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
     c10::complex<float>** dA_array, magma_int_t ldda, c10::complex<float>** dB_array, magma_int_t lddb, magma_int_t batchsize,
     const MAGMAQueue& magma_queue) {
   magmaFloatComplex alpha({1, 0});
-  magmablas_ctrsm_batched(MagmaLeft, uplo, trans, diag, m, n, alpha,
+  magmablas_ctrsm_batched(side, uplo, trans, diag, m, n, alpha,
     reinterpret_cast<magmaFloatComplex**>(dA_array), ldda,
     reinterpret_cast<magmaFloatComplex**>(dB_array), lddb, batchsize, magma_queue.get_queue());
   AT_CUDA_CHECK(cudaGetLastError());
@@ -1224,8 +1224,15 @@ void checkMagmaInternalError(magma_int_t info, const std::string& magma_function
       ", when calling ", magma_function_name);
 }
 
+magma_trans_t to_magma(TransposeType trans) {
+  switch (trans) {
+    case TransposeType::NoTranspose: return MagmaNoTrans;
+    case TransposeType::Transpose: return MagmaTrans;
+    case TransposeType::ConjTranspose: return MagmaConjTrans;
+  }
+  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
+}
 } // anonymous namespace
-
 #endif // USE_MAGMA
 
 #define ALLOCATE_ARRAY(name, type, size) \
@@ -1950,27 +1957,28 @@ REGISTER_CUDA_DISPATCH(lu_stub, &apply_lu);
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
-static void apply_triangular_solve_batched(Tensor& A, Tensor& b, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+static void apply_triangular_solve_batched_magma(Tensor& A, Tensor& b, bool left, bool upper, TransposeType transpose, bool unitriangular) {
 #ifndef USE_MAGMA
 AT_ERROR("triangular_solve: MAGMA library not found in "
          "compilation. Please rebuild with MAGMA.");
 #else
   magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
-  magma_trans_t trans = transpose ? MagmaTrans : MagmaNoTrans;
-  trans = conjugate_transpose ? MagmaConjTrans : trans;
+  magma_trans_t trans = to_magma(transpose);
   magma_diag_t diag = unitriangular ? MagmaUnit : MagmaNonUnit;
+  magma_side_t side = left ? MagmaLeft : MagmaRight;
 
   auto A_data = A.data_ptr<scalar_t>();
   auto b_data = b.data_ptr<scalar_t>();
-  magma_int_t m = magma_int_cast(A.size(-2), "A.size(-2)");
-  magma_int_t n = magma_int_cast(A.size(-1), "A.size(-1)");
-  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
+  // This allows to pass rectangular A and b when left = True
+  magma_int_t m = magma_int_cast(left ? A.size(-1) : b.size(-2), "m");
+  magma_int_t n = magma_int_cast(b.size(-1), "n");
   // magma returns early if m <= 0 || n <= 0 for magmaTriangularSolveBatched
   // magmaTriangularSolve is calling cuBLAS and it prints
   // ** On entry to DTRSM  parameter number 9 had an illegal value
   // so let's use proper lda parameter here
-  magma_int_t lda = std::max<magma_int_t>(1, m);
-  magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
+  magma_int_t lda = std::max<magma_int_t>(1, A.size(-2));
+  magma_int_t ldb = std::max<magma_int_t>(1, b.size(-2));
+  magma_int_t batch_size = magma_int_cast(batchCount(A), "batch_size");
 
   auto A_mat_stride = matrixStride(A);
   auto b_mat_stride = matrixStride(b);
@@ -1990,7 +1998,7 @@ AT_ERROR("triangular_solve: MAGMA library not found in "
   MAGMAQueue magma_queue(b.get_device());
 
   constexpr int64_t batch_limit = 65535;
-  // Compute as many batches of 65535 possible
+  // Compute as many batches of 65535 as possible
   // The number of "mini"-batches are floor(batch_size / batch_limit)
   // and these cover floor(batch_size / batch_limit) * batch_limit matrix solves
   int64_t mini_batches = batch_size / batch_limit;
@@ -2000,40 +2008,39 @@ AT_ERROR("triangular_solve: MAGMA library not found in "
     scalar_t** b_array_cur = &b_array[mini_idx];
 
     magmaTriangularSolveBatched<scalar_t>(
-        uplo, trans, diag, n, nrhs, A_array_cur,
-        lda, b_array_cur, lda, batch_limit, magma_queue);
+        side, uplo, trans, diag, m, n, A_array_cur,
+        lda, b_array_cur, ldb, batch_limit, magma_queue);
   }
 
   // Compute whatever is left = batch_size - floor(batch_size / batch_limit) * batch_limit
   // which concisely is equal to batch_size % batch_limit
   if (batch_size % batch_limit != 0) {
     magmaTriangularSolveBatched<scalar_t>(
-        uplo, trans, diag, n, nrhs, &A_array[mini_idx],
-        lda, &b_array[mini_idx], lda, batch_size % batch_limit, magma_queue);
+        side, uplo, trans, diag, m, n, &A_array[mini_idx],
+        lda, &b_array[mini_idx], ldb, batch_size % batch_limit, magma_queue);
   }
 #endif
 }
 
-void triangular_solve_batched_magma(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
-  (void)infos; // unused
+void triangular_solve_batched_magma(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
-    apply_triangular_solve_batched<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular);
+    apply_triangular_solve_batched_magma<scalar_t>(A, B, left, upper, transpose, unitriangular);
   });
 }
 
-void triangular_solve_kernel(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+void triangular_solve_kernel(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
   // For batches smaller than 8 and matrix sizes larger than 64x64 cuBLAS forloop is faster than batched version
   if (batchCount(A) <= 8 && A.size(-1) >= 64) {
-    triangular_solve_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+    triangular_solve_cublas(A, B, left, upper, transpose, unitriangular);
   } else {
 #ifndef USE_MAGMA
-    triangular_solve_batched_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+    triangular_solve_batched_cublas(A, B, left, upper, transpose, unitriangular);
 #else
     // cuBLAS batched is faster than MAGMA batched up until 512x512, after that MAGMA is faster
     if (A.size(-1) <= 512) {
-      triangular_solve_batched_cublas(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+      triangular_solve_batched_cublas(A, B, left, upper, transpose, unitriangular);
     } else {
-      triangular_solve_batched_magma(A, B, infos, upper, transpose, conjugate_transpose, unitriangular);
+      triangular_solve_batched_magma(A, B, left, upper, transpose, unitriangular);
     }
 #endif // USE_MAGMA
   }
@@ -2724,21 +2731,6 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda(const Tensor& self, bool som
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-#ifdef USE_MAGMA
-magma_trans_t _get_magma_trans(char trans) {
-  switch (trans) {
-    case 'N':
-      return MagmaNoTrans;
-    case 'T':
-      return MagmaTrans;
-    case 'C':
-      return MagmaConjTrans;
-    default:
-      return MagmaNoTrans;
-  }
-}
-#endif
-
 /*
   Solves the matrix equation A X = B
   X and B are n-by-nrhs matrices, A is represented using the LU factorization.
@@ -2754,14 +2746,14 @@ magma_trans_t _get_magma_trans(char trans) {
   For further details, please see the MAGMA documentation for magma_dgetrs_gpu.
 */
 template <typename scalar_t>
-static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
 #ifndef USE_MAGMA
   TORCH_CHECK(
       false,
       "Calling torch.lu_solve on a CUDA tensor requires compiling ",
       "PyTorch with MAGMA. lease rebuild with MAGMA.");
 #else
-  auto trans = _get_magma_trans(lapack_trans);
+  auto trans = to_magma(transpose);
   auto b_data = b.data_ptr<scalar_t>();
   auto lu_data = lu.data_ptr<scalar_t>();
 
@@ -2808,14 +2800,14 @@ static void apply_lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const
   For further details, please see the MAGMA documentation for magma_dgetrs_batched.
 */
 template <typename scalar_t>
-static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
 #ifndef USE_MAGMA
   TORCH_CHECK(
       false,
       "Calling torch.lu_solve on a CUDA tensor requires compiling ",
       "PyTorch with MAGMA. lease rebuild with MAGMA.");
 #else
-  auto trans = _get_magma_trans(lapack_trans);
+  auto trans = to_magma(transpose);
   auto b_data = b.data_ptr<scalar_t>();
   auto lu_data = lu.data_ptr<scalar_t>();
 
@@ -2869,34 +2861,20 @@ static void apply_lu_solve_batched_magma(const Tensor& b, const Tensor& lu, cons
 #endif
 }
 
-static void lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void lu_solve_batched_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_batched_magma", [&]{
-    apply_lu_solve_batched_magma<scalar_t>(b, lu, pivots, lapack_trans);
+    apply_lu_solve_batched_magma<scalar_t>(b, lu, pivots, trans);
   });
 }
 
-static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, char lapack_trans) {
+static void lu_solve_looped_magma(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_looped_magma", [&]{
-    apply_lu_solve_looped_magma<scalar_t>(b, lu, pivots, lapack_trans);
+    apply_lu_solve_looped_magma<scalar_t>(b, lu, pivots, trans);
   });
 }
 
-#if defined(USE_CUSOLVER) || defined(CUDART_VERSION)
-cublasOperation_t _get_cublas_trans(char trans) {
-  switch (trans) {
-    case 'N':
-      return CUBLAS_OP_N;
-    case 'T':
-      return CUBLAS_OP_T;
-    case 'C':
-      return CUBLAS_OP_C;
-    default:
-      return CUBLAS_OP_N;
-  }
-}
-#endif
 
-static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, char trans) {
+static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
   auto batch_size = batchCount(lu);
   auto m = lu.size(-2);
   auto b2 = b.size(-1);
@@ -2904,7 +2882,7 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten
   // heuristics determined from tests dicussed in https://github.com/pytorch/pytorch/pull/59148
 #ifdef USE_CUSOLVER
   if ((batch_size == 1 && m > 512) || (batch_size <= 8 && over_magma_dim_limit)) {
-    lu_solve_looped_cusolver(b, lu, pivots, _get_cublas_trans(trans));
+    lu_solve_looped_cusolver(b, lu, pivots, trans);
   }
 #else
   if (batch_size == 1) {
@@ -2913,7 +2891,7 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten
 #endif // ifdef USE_CUSOLVER
 #ifdef CUDART_VERSION
   else if ((batch_size > 2 && m <= 128) || (batch_size > 8 && over_magma_dim_limit)) {
-    lu_solve_batched_cublas(b, lu, pivots, _get_cublas_trans(trans));
+    lu_solve_batched_cublas(b, lu, pivots, trans);
   }
 #endif // ifdef CUDART_VERSION
   else {
@@ -2924,7 +2902,7 @@ static void lu_solve_trans_dispatch(const Tensor& b, const Tensor& lu, const Ten
 REGISTER_CUDA_DISPATCH(lu_solve_trans_stub, &lu_solve_trans_dispatch);
 
 static void lu_solve_dispatch(const Tensor& b, const Tensor& lu, const Tensor& pivots) {
-  lu_solve_trans_dispatch(b, lu, pivots, 'N');
+  lu_solve_trans_dispatch(b, lu, pivots, TransposeType::NoTranspose);
 }
 
 REGISTER_CUDA_DISPATCH(lu_solve_stub, &lu_solve_dispatch);
@@ -2975,7 +2953,7 @@ void gels_magma(const Tensor& a, Tensor& b, Tensor& infos) {
   });
 }
 
-void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& infos) {
+void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& /*infos*/) {
   // The steps for using the QR decomposition for solving least squares problems
   // are outlined here https://en.wikipedia.org/wiki/QR_decomposition#Using_for_solution_to_linear_inverse_problems
   auto m = A.size(-2);
@@ -3012,15 +2990,13 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& infos) {
     ormqr_kernel(A_broadcasted, tau_broadcasted, B, /*left=*/true, /*transpose=*/true);
 
     // Step 3: solve R X = B
-    bool upper = true;
-    bool transpose = false;
-    bool conjugate_transpose = false;
-    bool unitriangular = false;
     triangular_solve_kernel(
         const_cast<Tensor&>(A_broadcasted),
         const_cast<Tensor&>(B),
-        const_cast<Tensor&>(infos),
-        upper, transpose, conjugate_transpose, unitriangular);
+        /*left=*/true,
+        /*upper=*/true,
+        /*transpose=*/TransposeType::NoTranspose,
+        /*unitriangular=*/false);
   } else { // underdetermined case
     Tensor Ah = cloneBatchedColumnMajor(A.conj().transpose(-2, -1));
 
@@ -3036,15 +3012,15 @@ void linalg_lstsq_gels(const Tensor& A, const Tensor& B, const Tensor& infos) {
     Tensor Ah_broadcasted = is_fortran_contiguous ? Ah_expanded : cloneBatchedColumnMajor(Ah_expanded);
 
     // Step 2: R^H Z = B
-    bool upper = true;
-    bool transpose = true;
-    bool conjugate_transpose = true;
-    bool unitriangular = false;
+    const auto trans = Ah_broadcasted.is_complex() ? TransposeType::ConjTranspose
+                                                   : TransposeType::Transpose;
     triangular_solve_kernel(
         const_cast<Tensor&>(Ah_broadcasted),
         const_cast<Tensor&>(B),
-        const_cast<Tensor&>(infos),
-        upper, transpose, conjugate_transpose, unitriangular);
+        /*left=*/true,
+        /*upper=*/true,
+        /*transpose=*/trans,
+        /*unitriangular=*/false);
 
     // B matrix has the size max(m, n) x nrhs
     // triangular_solve_kernel writes its output into the first m rows of B leaving the rest untouched
index 13d67e5..f9b4322 100644 (file)
 namespace at {
 namespace native {
 
+cublasOperation_t to_cublas(TransposeType trans) {
+  switch (trans) {
+    case TransposeType::NoTranspose: return CUBLAS_OP_N;
+    case TransposeType::Transpose: return CUBLAS_OP_T;
+    case TransposeType::ConjTranspose: return CUBLAS_OP_C;
+  }
+  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
+}
+
 // Some cuBLAS and cuSOLVER batched routines require input to be a device array of pointers to device individual matrices
 // 'input' must be a contiguous tensor
 template <typename scalar_t>
@@ -70,10 +79,11 @@ void geqrf_batched_cublas(const Tensor& input, const Tensor& tau) {
 }
 
 template <typename scalar_t>
-static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) {
+static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
 #ifndef CUDART_VERSION
   TORCH_CHECK(false, "lu_solve: cuBLAS backend for lu_solve is not available.")
 #else
+  const auto trans = to_cublas(transpose);
 
   auto pivots_data = pivots.data_ptr<int>();
   auto batch_size = cuda_int_cast(batchCount(lu), "batch_size");;
@@ -94,29 +104,29 @@ static void apply_lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, con
 #endif
 }
 
-void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) {
+void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType trans) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(lu.scalar_type(), "lu_solve_cublas", [&]{
     apply_lu_solve_batched_cublas<scalar_t>(b, lu, pivots, trans);
   });
 }
 
 template <typename scalar_t>
-static void apply_triangular_solve(Tensor& A, Tensor& B, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+static void apply_triangular_solve(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
   cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
-  cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
-  trans = conjugate_transpose ? CUBLAS_OP_C : trans;
+  const auto trans = to_cublas(transpose);
   cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
-  cublasSideMode_t side = CUBLAS_SIDE_LEFT;
+  cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
 
   auto A_data = A.data_ptr<scalar_t>();
   auto B_data = B.data_ptr<scalar_t>();
   auto A_mat_stride = matrixStride(A);
   auto B_mat_stride = matrixStride(B);
   auto batch_size = batchCount(A);
-  auto m = cuda_int_cast(A.size(-2), "m");
-  auto n = cuda_int_cast(A.size(-1), "n");
-  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
-  auto lda = std::max<int>(1, m);
+  // This allows to pass rectangular A and B when left = True
+  auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
+  auto n = cuda_int_cast(B.size(-1), "n");
+  auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
+  auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
 
   auto alpha = scalar_t{1};
 
@@ -124,30 +134,29 @@ static void apply_triangular_solve(Tensor& A, Tensor& B, bool upper, bool transp
     scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
     scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
     auto handle = at::cuda::getCurrentCUDABlasHandle();
-    at::cuda::blas::trsm(handle, side, uplo, trans, diag, n, nrhs, &alpha, A_working_ptr, lda, B_working_ptr, lda);
+    at::cuda::blas::trsm(handle, side, uplo, trans, diag, m, n, &alpha, A_working_ptr, lda, B_working_ptr, ldb);
   }
 }
 
-void triangular_solve_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
-  (void)infos; // unused
+void triangular_solve_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
-    apply_triangular_solve<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular);
+    apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
   });
 }
 
 template <typename scalar_t>
-static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
+static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
   cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
-  cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
-  trans = conjugate_transpose ? CUBLAS_OP_C : trans;
+  const auto trans = to_cublas(transpose);
   cublasDiagType_t diag = unitriangular ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
-  cublasSideMode_t side = CUBLAS_SIDE_LEFT;
+  cublasSideMode_t side = left ? CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT;
 
   auto batch_size = cuda_int_cast(batchCount(A), "batch_size");
-  auto m = cuda_int_cast(A.size(-2), "m");
-  auto n = cuda_int_cast(A.size(-1), "n");
-  auto nrhs = cuda_int_cast(B.size(-1), "nrhs");
-  auto lda = std::max<int>(1, m);
+  // This allows to pass rectangular A and B when left = True
+  auto m = cuda_int_cast(left ? A.size(-1) : B.size(-2), "m");
+  auto n = cuda_int_cast(B.size(-1), "n");
+  auto lda = std::max<int>(1, cuda_int_cast(A.size(-2), "lda"));
+  auto ldb = std::max<int>(1, cuda_int_cast(B.size(-2), "ldb"));
 
   auto alpha = scalar_t{1};
 
@@ -158,13 +167,12 @@ static void apply_triangular_solve_batched(Tensor& A, Tensor& B, bool upper, boo
   auto B_ptr_array_data = reinterpret_cast<scalar_t**>(B_ptr_array.data_ptr());
 
   auto handle = at::cuda::getCurrentCUDABlasHandle();
-  at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, n, nrhs, &alpha, A_ptr_array_data, lda, B_ptr_array_data, lda, batch_size);
+  at::cuda::blas::trsmBatched(handle, side, uplo, trans, diag, m, n, &alpha, A_ptr_array_data, lda, B_ptr_array_data, ldb, batch_size);
 }
 
-void triangular_solve_batched_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular) {
-  (void)infos; // unused
+void triangular_solve_batched_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cuda", [&]{
-    apply_triangular_solve_batched<scalar_t>(A, B, upper, transpose, conjugate_transpose, unitriangular);
+    apply_triangular_solve_batched<scalar_t>(A, B, left, upper, transpose, unitriangular);
   });
 }
 
@@ -1297,8 +1305,9 @@ void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor&
   }
 }
 
-void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans) {
+void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(b.scalar_type(), "lu_solve_cusolver", [&] {
+    const auto trans = to_cublas(transpose);
     int n = cuda_int_cast(lu.size(-2), "n");
     int nrhs = cuda_int_cast(b.size(-1), "nrhs");
     auto batch_size = batchCount(lu);
index 1a30187..72d2f65 100644 (file)
@@ -36,10 +36,10 @@ namespace at {
 namespace native {
 
 void geqrf_batched_cublas(const Tensor& input, const Tensor& tau);
-void triangular_solve_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular);
-void triangular_solve_batched_cublas(Tensor& A, Tensor& B, Tensor& infos, bool upper, bool transpose, bool conjugate_transpose, bool unitriangular);
+void triangular_solve_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular);
+void triangular_solve_batched_cublas(Tensor& A, Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular);
 void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos);
-void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans);
+void lu_solve_batched_cublas(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose);
 
 #ifdef USE_CUSOLVER
 
@@ -60,7 +60,7 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other,
 Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
 
 void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
-void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, cublasOperation_t trans);
+void lu_solve_looped_cusolver(const Tensor& b, const Tensor& lu, const Tensor& pivots, TransposeType transpose);
 
 void lu_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
 
index 5912111..6ff0ce5 100644 (file)
@@ -4926,17 +4926,6 @@ class TestLinalg(TestCase):
             run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular)  # broadcasting A
             run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular)  # broadcasting A & b
 
-    @onlyCPU
-    @skipCPUIfNoLapack
-    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
-    def test_triangular_solve_singular(self, device, dtype):
-        b = torch.rand(3, 1, dtype=dtype, device=device)
-        A = torch.eye(3, 3, dtype=dtype, device=device)
-        A[-1, -1] = 0  # Now A is singular
-        err_str = r"triangular_solve: The diagonal element 3 is zero"
-        with self.assertRaisesRegex(RuntimeError, err_str):
-            torch.triangular_solve(b, A)
-
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
index fe50bad..d0728dd 100644 (file)
@@ -9621,6 +9621,10 @@ with the default keyword arguments.
 batches of 2D matrices. If the inputs are batches, then returns
 batched outputs `X`
 
+If the diagonal of :attr:`A` contains zeros or elements that are very close to zero and
+:attr:`unitriangular`\ `= False` (default) or if the input matrix is badly conditioned,
+the result may contain `NaN` s.
+
 Supports input of float, double, cfloat and cdouble data types.
 
 Args:
index 00bbeb3..72af4fc 100644 (file)
@@ -8156,7 +8156,7 @@ op_db: List[OpInfo] = [
            supports_out=False,
            sample_inputs_func=sample_inputs_legacy_solve,
            check_batched_gradgrad=False,
-           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
+           decorators=[skipCUDAIfNoMagma]),
     UnaryUfuncInfo('trunc',
                    aliases=('fix', ),
                    ref=np.trunc,