Remove TH/THC link for cholesky_solve (#15691)
authorvishwakftw <cs15btech11043@iith.ac.in>
Fri, 4 Jan 2019 14:18:35 +0000 (06:18 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 4 Jan 2019 14:24:17 +0000 (06:24 -0800)
Summary:
Changelog:
- Remove TH/THC binding
- Port single matrix case to ATen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15691

Differential Revision: D13579317

Pulled By: soumith

fbshipit-source-id: 63a55606c656396e777e8e6828acd2ef88ed1543

aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/TH/generic/THLapack.cpp
aten/src/TH/generic/THLapack.h
aten/src/TH/generic/THTensorLapack.cpp
aten/src/TH/generic/THTensorLapack.h
aten/src/THC/generic/THCTensorMathMagma.cu
aten/src/THC/generic/THCTensorMathMagma.h

index 8ed943c..74cf9f4 100644 (file)
     - THTensor* self
 ]]
 [[
-  name: _th_potrs_single
-  cname: potrs
-  types:
-    - Float
-    - Double
-  backends:
-    - CPU
-    - CUDA
-  variants:
-    - function
-  return: argument 0
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - THTensor* self
-    - THTensor* input2
-    - arg: bool upper
-      if_true: U
-      if_false: L
-      default: U
-]]
-[[
   name: _th_potri
   cname: potri
   types:
index 15c6048..a01f33e 100644 (file)
@@ -263,43 +263,51 @@ static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, std::vector<i
 
   auto A_data = A.data<scalar_t>();
   auto b_data = b.data<scalar_t>();
-  auto A_mat_stride = matrixStride(A);
-  auto b_mat_stride = matrixStride(b);
-
-  auto batch_size = batchCount(A);
   auto n = A.size(-2);
   auto nrhs = b.size(-1);
 
-  for (int64_t i = 0; i < batch_size; i++) {
-    int info;
-    scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
-    scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
-    lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
-    infos[i] = info;
-    if (info != 0) {
-      return;
+  int info;
+  if (b.dim() == 2) {
+    lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_data, n, b_data, n, &info);
+    infos[0] = info;
+  } else {
+    auto A_mat_stride = matrixStride(A);
+    auto b_mat_stride = matrixStride(b);
+    auto batch_size = batchCount(A);
+    for (int64_t i = 0; i < batch_size; i++) {
+      scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
+      scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
+      lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+      infos[i] = info;
+      if (info != 0) {
+        return;
+      }
     }
   }
 #endif
 }
 
 Tensor _cholesky_solve_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
-  std::vector<int64_t> infos(batchCount(self), 0);
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
+  std::vector<int64_t> infos(batchCount(self), 0);
   AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky_solve", [&]{
     apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, infos);
   });
-  batchCheckErrors(infos, "cholesky_solve");
+  if (self.dim() > 2) {
+    batchCheckErrors(infos, "cholesky_solve");
+  } else {
+    singleCheckErrors(infos[0], "cholesky_solve");
+  }
   return self_working_copy;
 }
 
 // Supports arbitrary batch dimensions for self and A
 Tensor cholesky_solve(const Tensor& self, const Tensor& A, bool upper) {
-  if (self.dim() <= 2 && A.dim() <= 2) {
-    return at::legacy::th::_th_potrs_single(self, A, upper);
-  }
-
+  AT_CHECK(self.dim() >= 2,
+           "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+  AT_CHECK(A.dim() >= 2,
+           "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
   Tensor self_broadcasted, A_broadcasted;
   std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
   return at::_cholesky_solve_helper(self_broadcasted, A_broadcasted, upper);
@@ -309,7 +317,8 @@ Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A,
   AT_CHECK(self.dim() == 2 && A.dim() == 2,
            "torch.cholesky_solve() with the `out` keyword does not support batching. "
            "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
-  return at::legacy::th::_th_potrs_single_out(result, self, A, upper);
+  result = at::_cholesky_solve_helper(self, A, upper);
+  return result;
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
index f7c09af..a6f1dd2 100644 (file)
@@ -51,6 +51,13 @@ void magmaGetriBatched(
 }
 
 template<class scalar_t>
+void magmaCholeskySolve(
+    magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, scalar_t* dA, magma_int_t ldda,
+    scalar_t* dB, magma_int_t lddb, magma_int_t* info) {
+  AT_ERROR("cholesky_solve only takes float or double Tensors");
+}
+
+template<class scalar_t>
 void magmaCholeskySolveBatched(
     magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
     scalar_t** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
@@ -134,6 +141,20 @@ void magmaGetriBatched<float>(
 }
 
 template<>
+void magmaCholeskySolve<double>(
+    magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double* dA, magma_int_t ldda,
+    double* dB, magma_int_t lddb, magma_int_t* info) {
+  magma_dpotrs_gpu(uplo, n, nrhs, dA, ldda, dB, lddb, info);
+}
+
+template<>
+void magmaCholeskySolve<float>(
+    magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, float* dA, magma_int_t ldda,
+    float* dB, magma_int_t lddb, magma_int_t* info) {
+  magma_spotrs_gpu(uplo, n, nrhs, dA, ldda, dB, lddb, info);
+}
+
+template<>
 void magmaCholeskySolveBatched<double>(
     magma_uplo_t uplo, magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
     double** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
@@ -326,32 +347,38 @@ AT_ERROR("cholesky_solve: MAGMA library not found in "
 
   auto A_data = A.data<scalar_t>();
   auto b_data = b.data<scalar_t>();
-  auto A_mat_stride = matrixStride(A);
-  auto b_mat_stride = matrixStride(b);
-
-  magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
   magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
   magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
 
-  magma_int_t info_tmp;
-  scalar_t** A_array;
-  scalar_t** b_array;
+  int info_tmp;
+  if (b.dim() == 2) {
+    magmaCholeskySolve<scalar_t>(uplo, n, nrhs, A_data, n,
+                                 b_data, n, &info_tmp);
+    info = info_tmp;
+  } else {
+    auto A_mat_stride = matrixStride(A);
+    auto b_mat_stride = matrixStride(b);
+    magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
 
-  ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
-  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
+    scalar_t** A_array;
+    scalar_t** b_array;
 
-  // Set up the created arrays
-  for (int64_t i = 0; i < batch_size; i++) {
-    A_array[i] = &A_data[i * A_mat_stride];
-    b_array[i] = &b_data[i * b_mat_stride];
-  }
+    ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
+    ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
 
-  MAGMAQueue magma_queue(b.get_device());
-  magmaCholeskySolveBatched<scalar_t>(
-      uplo, n, nrhs, A_array, n, b_array, n,
-      info_tmp, batch_size, magma_queue);
+    // Set up the created arrays
+    for (int64_t i = 0; i < batch_size; i++) {
+      A_array[i] = &A_data[i * A_mat_stride];
+      b_array[i] = &b_data[i * b_mat_stride];
+    }
 
-  info = info_tmp;
+    MAGMAQueue magma_queue(b.get_device());
+    magmaCholeskySolveBatched<scalar_t>(
+        uplo, n, nrhs, A_array, n, b_array, n,
+        info_tmp, batch_size, magma_queue);
+
+    info = info_tmp;
+  }
 #endif
 }
 
index a0c3fb4..28a4f29 100644 (file)
@@ -21,8 +21,6 @@ TH_EXTERNC void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, in
 TH_EXTERNC void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
 TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
 TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
-TH_EXTERNC void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
-TH_EXTERNC void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
 TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
 TH_EXTERNC void dgeqrf_(int *m, int *n, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
 TH_EXTERNC void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
@@ -149,20 +147,6 @@ void THLapack_(getri)(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, in
 #endif
 }
 
-/* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */
-void THLapack_(potrs)(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info)
-{
-#ifdef  USE_LAPACK
-#if defined(TH_REAL_IS_DOUBLE)
-  dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
-#else
-  spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
-#endif
-#else
-  THError("potrs: Lapack library not found in compile time\n");
-#endif
-}
-
 /* Cholesky factorization based Matrix Inverse */
 void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info)
 {
index 8942abe..5c65140 100644 (file)
@@ -21,8 +21,6 @@ TH_API void THLapack_(getri)(int n, scalar_t *a, int lda, int *ipiv, scalar_t *w
 /* Positive Definite matrices */
 /* Matrix inverse based on Cholesky factorization */
 TH_API void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info);
-/* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */
-TH_API void THLapack_(potrs)(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info);
 /* Cholesky factorization with complete pivoting. */
 TH_API void THLapack_(pstrf)(char uplo, int n, scalar_t *a, int lda, int *piv, int *rank, scalar_t tol, scalar_t *work, int *info);
 
index ec127f9..1fbea8e 100644 (file)
@@ -600,54 +600,6 @@ void THTensor_(copyUpLoTriangle)(THTensor *a, const char *uplo)
   }
 }
 
-void THTensor_(potrs)(THTensor *rb_, THTensor *b, THTensor *a, const char *uplo)
-{
-  int free_b = 0;
-  if (b == NULL) b = rb_;
-
-  THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 2, "A should have 2 dimensions, but has %d",
-      THTensor_nDimensionLegacyAll(a));
-  THArgCheck(THTensor_nDimensionLegacyAll(b) == 1 || THTensor_nDimensionLegacyAll(b) == 2, 1, "B should have 1 or 2 "
-      "dimensions, but has %d", THTensor_nDimensionLegacyAll(b));
-  THArgCheck(a->size(0) == a->size(1), 2, "A should be square, but is %ldx%ld",
-      a->size(0), a->size(1));
-  THArgCheck(a->size(0) == b->size(0), 2, "A,B size incompatible - A has %ld "
-      "rows, B has %ld", a->size(0), b->size(0));
-
-  if (THTensor_nDimensionLegacyAll(b) == 1) {
-    b = THTensor_(newWithStorage2d)(THTensor_getStoragePtr(b), b->storage_offset(), b->size(0),
-            b->stride(0), 1, 0);
-    free_b = 1;
-  }
-
-  int n, nrhs, lda, ldb, info;
-  THTensor *ra__; // working version of A matrix to be passed into lapack TRTRS
-  THTensor *rb__; // working version of B matrix to be passed into lapack TRTRS
-
-  ra__ = THTensor_(cloneColumnMajor)(NULL, a);
-  rb__ = THTensor_(cloneColumnMajor)(rb_, b);
-
-  n    = (int)ra__->size(0);
-  nrhs = (int)rb__->size(1);
-  lda  = n;
-  ldb  = n;
-
-  THLapack_(potrs)(uplo[0], n, nrhs, ra__->data<scalar_t>(),
-                   lda, rb__->data<scalar_t>(), ldb, &info);
-
-
-  THLapackCheckWithCleanup("Lapack Error in %s : A(%d,%d) is zero, singular A",
-                           THCleanup(
-                               c10::raw::intrusive_ptr::decref(ra__);
-                               c10::raw::intrusive_ptr::decref(rb__);
-                               if (free_b) c10::raw::intrusive_ptr::decref(b);),
-                           "potrs", info, info);
-
-  if (free_b) c10::raw::intrusive_ptr::decref(b);
-  c10::raw::intrusive_ptr::decref(ra__);
-  THTensor_(freeCopyTo)(rb__, rb_);
-}
-
 void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo)
 {
   if (a == NULL) a = ra_;
index dfb8171..2444307 100644 (file)
@@ -10,7 +10,6 @@ TH_API void THTensor_(gesdd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTens
 TH_API void THTensor_(gesdd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a,
                               const char *some, const char* compute_uv);
 TH_API void THTensor_(getri)(THTensor *ra_, THTensor *a);
-TH_API void THTensor_(potrs)(THTensor *rb_, THTensor *b_, THTensor *a_,  const char *uplo);
 TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo);
 TH_API void THTensor_(qr)(THTensor *rq_, THTensor *rr_, THTensor *a);
 TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a);
index 8cf7b8a..418bb7e 100644 (file)
@@ -539,38 +539,6 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char
 #endif
 }
 
-void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, THCTensor *a, const char *uplo)
-{
-#ifdef USE_MAGMA
-  THArgCheck(a->size(0) == a->size(1), 2, "A should be square");
-
-  int64_t n = a->size(0);
-  int64_t nrhs = b->size(1);
-  magma_uplo_t ul = uplo[0] == 'U' ?  MagmaUpper : MagmaLower;
-
-  THCTensor *b_ = THCTensor_(newColumnMajor)(state, rb_, b);
-  scalar_t *b_data = THCTensor_(data)(state, b_);
-  THCTensor *a_ = THCTensor_(newColumnMajor)(state, a, a);
-  scalar_t *a_data = THCTensor_(data)(state, a_);
-
-  int info;
-#if defined(THC_REAL_IS_FLOAT)
-  magma_spotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info);
-#else
-  magma_dpotrs_gpu(ul, n, nrhs, a_data, n, b_data, n, &info);
-#endif
-
-  // check error value
-  if (info < 0)
-    THError("MAGMA potrs : Argument %d : illegal value", -info);
-
-  THCTensor_(freeCopyTo)(state, b_, rb_);
-  THCTensor_(free)(state, a_);
-#else
-  THError(NoMagma(potrs));
-#endif
-}
-
 void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_)
 {
 #ifdef USE_MAGMA
index 548458b..e3870f2 100644 (file)
@@ -16,7 +16,6 @@ THC_API void THCTensor_(gesdd2)(THCState *state, THCTensor *ru_, THCTensor *rs_,
                                 const char *some, const char* compute_uv);
 THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
 THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
-THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo);
 THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_);
 THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a);